summaryrefslogtreecommitdiff
path: root/layer/softmax_ce.lua
diff options
context:
space:
mode:
Diffstat (limited to 'layer/softmax_ce.lua')
-rw-r--r--layer/softmax_ce.lua21
1 files changed, 11 insertions, 10 deletions
diff --git a/layer/softmax_ce.lua b/layer/softmax_ce.lua
index 7888540..daf891e 100644
--- a/layer/softmax_ce.lua
+++ b/layer/softmax_ce.lua
@@ -12,13 +12,15 @@ function SoftmaxCELayer:__init(id, global_conf, layer_conf)
self:check_dim_len(2, -1) -- two inputs: nn output and label
end
-function SoftmaxCELayer:init()
+function SoftmaxCELayer:init(batch_size)
if not self.compressed and (self.dim_in[1] ~= self.dim_in[2]) then
nerv.error("mismatching dimensions of previous network output and labels")
end
self.total_ce = 0.0
self.total_correct = 0
self.total_frames = 0
+ self.softmax = self.gconf.cumat_type(batch_size, self.dim_in[1])
+ self.ce = self.softmax:create()
end
function SoftmaxCELayer:update(bp_err, input, output)
@@ -26,12 +28,11 @@ function SoftmaxCELayer:update(bp_err, input, output)
end
function SoftmaxCELayer:propagate(input, output)
- local soutput = input[1]:create() -- temporary value for calc softmax
- self.soutput = soutput
- local classified = soutput:softmax(input[1])
- local ce = soutput:create()
- ce:log_elem(soutput)
+ local softmax = self.softmax
+ local ce = self.ce
+ local classified = softmax:softmax(input[1])
local label = input[2]
+ ce:log_elem(softmax)
if self.compressed then
label = label:decompress(input[1]:ncol())
end
@@ -42,26 +43,26 @@ function SoftmaxCELayer:propagate(input, output)
end
-- add total ce
self.total_ce = self.total_ce - ce:colsum()[0]
- self.total_frames = self.total_frames + soutput:nrow()
+ self.total_frames = self.total_frames + softmax:nrow()
-- TODO: add colsame for uncompressed label
if self.compressed then
self.total_correct = self.total_correct + classified:colsame(input[2])[0]
end
end
-function SoftmaxCELayer:back_propagate(next_bp_err, bp_err, input, output)
+function SoftmaxCELayer:back_propagate(bp_err, next_bp_err, input, output)
-- softmax output - label
local label = input[2]
if self.compressed then
label = label:decompress(input[1]:ncol())
end
local nbe = next_bp_err[1]
- nbe:add(self.soutput, label, 1.0, -1.0)
+ nbe:add(self.softmax, label, 1.0, -1.0)
if bp_err[1] ~= nil then
nbe:scale_rows_by_col(bp_err[1])
end
end
function SoftmaxCELayer:get_params()
- return {}
+ return nerv.ParamRepo({})
end