aboutsummaryrefslogtreecommitdiff
path: root/nerv/layer/dropout.lua
diff options
context:
space:
mode:
authorTed Yin <Determinant@users.noreply.github.com>2016-03-12 13:17:38 +0800
committerTed Yin <Determinant@users.noreply.github.com>2016-03-12 13:17:38 +0800
commit2b03555ea53a47e87d03a79feb866c868d424f01 (patch)
tree63cd01ee70d056d3a159a1e7d9aa4ea6f414d213 /nerv/layer/dropout.lua
parente8b1007d99691c08dd1b71f5733eb3cd2827dc64 (diff)
parent2660af7f6a9ac243a8ad38bf3375ef0fd292bf52 (diff)
Merge pull request #31 from liuq901/master
modfiy param generate & rewrite LSTM layer
Diffstat (limited to 'nerv/layer/dropout.lua')
-rw-r--r--nerv/layer/dropout.lua11
1 files changed, 5 insertions, 6 deletions
diff --git a/nerv/layer/dropout.lua b/nerv/layer/dropout.lua
index 1a379c9..de0fb64 100644
--- a/nerv/layer/dropout.lua
+++ b/nerv/layer/dropout.lua
@@ -2,8 +2,7 @@ local DropoutLayer = nerv.class("nerv.DropoutLayer", "nerv.Layer")
function DropoutLayer:__init(id, global_conf, layer_conf)
nerv.Layer.__init(self, id, global_conf, layer_conf)
- self.rate = layer_conf.dropout_rate or global_conf.dropout_rate
- if self.rate == nil then
+ if self.gconf.dropout_rate == nil then
nerv.warning("[DropoutLayer:propagate] dropout rate is not set")
end
self:check_dim_len(1, 1) -- two inputs: nn output and label
@@ -41,12 +40,12 @@ function DropoutLayer:propagate(input, output, t)
if t == nil then
t = 1
end
- if self.rate then
+ if self.gconf.dropout_rate then
self.mask[t]:rand_uniform()
-- since we will lose a portion of the actvations, we multiply the
-- activations by 1 / (1 - rate) to compensate
- self.mask[t]:thres_mask(self.mask[t], self.rate,
- 0, 1 / (1.0 - self.rate))
+ self.mask[t]:thres_mask(self.mask[t], self.gconf.dropout_rate,
+ 0, 1 / (1.0 - self.gconf.dropout_rate))
output[1]:mul_elem(input[1], self.mask[t])
else
output[1]:copy_fromd(input[1])
@@ -61,7 +60,7 @@ function DropoutLayer:back_propagate(bp_err, next_bp_err, input, output, t)
if t == nil then
t = 1
end
- if self.rate then
+ if self.gconf.dropout_rate then
next_bp_err[1]:mul_elem(bp_err[1], self.mask[t])
else
next_bp_err[1]:copy_fromd(bp_err[1])