diff options
Diffstat (limited to 'layer/softmax_ce.lua')
-rw-r--r-- | layer/softmax_ce.lua | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/layer/softmax_ce.lua b/layer/softmax_ce.lua index 7888540..daf891e 100644 --- a/layer/softmax_ce.lua +++ b/layer/softmax_ce.lua @@ -12,13 +12,15 @@ function SoftmaxCELayer:__init(id, global_conf, layer_conf) self:check_dim_len(2, -1) -- two inputs: nn output and label end -function SoftmaxCELayer:init() +function SoftmaxCELayer:init(batch_size) if not self.compressed and (self.dim_in[1] ~= self.dim_in[2]) then nerv.error("mismatching dimensions of previous network output and labels") end self.total_ce = 0.0 self.total_correct = 0 self.total_frames = 0 + self.softmax = self.gconf.cumat_type(batch_size, self.dim_in[1]) + self.ce = self.softmax:create() end function SoftmaxCELayer:update(bp_err, input, output) @@ -26,12 +28,11 @@ function SoftmaxCELayer:update(bp_err, input, output) end function SoftmaxCELayer:propagate(input, output) - local soutput = input[1]:create() -- temporary value for calc softmax - self.soutput = soutput - local classified = soutput:softmax(input[1]) - local ce = soutput:create() - ce:log_elem(soutput) + local softmax = self.softmax + local ce = self.ce + local classified = softmax:softmax(input[1]) local label = input[2] + ce:log_elem(softmax) if self.compressed then label = label:decompress(input[1]:ncol()) end @@ -42,26 +43,26 @@ function SoftmaxCELayer:propagate(input, output) end -- add total ce self.total_ce = self.total_ce - ce:colsum()[0] - self.total_frames = self.total_frames + soutput:nrow() + self.total_frames = self.total_frames + softmax:nrow() -- TODO: add colsame for uncompressed label if self.compressed then self.total_correct = self.total_correct + classified:colsame(input[2])[0] end end -function SoftmaxCELayer:back_propagate(next_bp_err, bp_err, input, output) +function SoftmaxCELayer:back_propagate(bp_err, next_bp_err, input, output) -- softmax output - label local label = input[2] if self.compressed then label = label:decompress(input[1]:ncol()) end local nbe = next_bp_err[1] - nbe:add(self.soutput, label, 1.0, -1.0) + nbe:add(self.softmax, label, 1.0, -1.0) if bp_err[1] ~= nil then nbe:scale_rows_by_col(bp_err[1]) end end function SoftmaxCELayer:get_params() - return {} + return nerv.ParamRepo({}) end |