diff options
Diffstat (limited to 'layer/softmax_ce.lua')
-rw-r--r-- | layer/softmax_ce.lua | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/layer/softmax_ce.lua b/layer/softmax_ce.lua new file mode 100644 index 0000000..37d2864 --- /dev/null +++ b/layer/softmax_ce.lua @@ -0,0 +1,32 @@ +local SoftmaxCELayer = nerv.class("nerv.SoftmaxCELayer", "nerv.Layer") + +function SoftmaxCELayer:__init(id, global_conf) + self.id = id + self.gconf = global_conf +end + +function SoftmaxCELayer:init() + self.total_ce = 0.0 + self.total_frames = 0 +end + +function SoftmaxCELayer:update(bp_err, input, output) + -- no params, therefore do nothing +end + +function SoftmaxCELayer:propagate(input, output) + local soutput = input[0]:create() -- temporary value for calc softmax + self.soutput = soutput + soutput:softmax(input[0]) + local ce = soutput:create() + ce:log_elem(soutput) + ce:mul_elem(ce, input[1]) + -- add total ce + self.total_ce = self.total_ce - ce:rowsum():colsum()[0] + self.total_frames = self.total_frames + soutput:nrow() +end + +function SoftmaxCELayer:back_propagate(next_bp_err, bp_err, input, output) + -- softmax output - label + next_bp_err[0]:add(self.soutput, input[1], 1.0, -1.0) +end |