aboutsummaryrefslogtreecommitdiff
path: root/layer/softmax_ce.lua
diff options
context:
space:
mode:
Diffstat (limited to 'layer/softmax_ce.lua')
-rw-r--r--layer/softmax_ce.lua20
1 files changed, 15 insertions, 5 deletions
diff --git a/layer/softmax_ce.lua b/layer/softmax_ce.lua
index 09eb3a9..cf98c45 100644
--- a/layer/softmax_ce.lua
+++ b/layer/softmax_ce.lua
@@ -5,6 +5,10 @@ function SoftmaxCELayer:__init(id, global_conf, layer_conf)
self.gconf = global_conf
self.dim_in = layer_conf.dim_in
self.dim_out = layer_conf.dim_out
+ self.compressed = layer_conf.compressed
+ if self.compressed == nil then
+ self.compressed = false
+ end
self:check_dim_len(2, -1) -- two inputs: nn output and label
end
@@ -26,15 +30,21 @@ function SoftmaxCELayer:propagate(input, output)
soutput:softmax(input[1])
local ce = soutput:create()
ce:log_elem(soutput)
- ce:mul_elem(ce, input[2])
--- print(input[1][0])
--- print(soutput[1][0])
- -- add total ce
+ local label = input[2]
+ if self.compressed then
+ label = label:decompress(input[1]:ncol())
+ end
+ ce:mul_elem(ce, label)
+ -- add total ce
self.total_ce = self.total_ce - ce:rowsum():colsum()[0]
self.total_frames = self.total_frames + soutput:nrow()
end
function SoftmaxCELayer:back_propagate(next_bp_err, bp_err, input, output)
-- softmax output - label
- next_bp_err[1]:add(self.soutput, input[2], 1.0, -1.0)
+ local label = input[2]
+ if self.compressed then
+ label = label:decompress(input[1]:ncol())
+ end
+ next_bp_err[1]:add(self.soutput, label, 1.0, -1.0)
end