about summary refs log tree commit diff
path: root/layer/softmax_ce.lua
diff options
context:
space:
mode:
Diffstat (limited to 'layer/softmax_ce.lua')
-rw-r--r--  layer/softmax_ce.lua  32
1 file changed, 32 insertions, 0 deletions
diff --git a/layer/softmax_ce.lua b/layer/softmax_ce.lua
new file mode 100644
index 0000000..37d2864
--- /dev/null
+++ b/layer/softmax_ce.lua
@@ -0,0 +1,32 @@
+local SoftmaxCELayer = nerv.class("nerv.SoftmaxCELayer", "nerv.Layer") -- softmax + cross-entropy loss layer, subclass of nerv.Layer
+
+function SoftmaxCELayer:__init(id, global_conf) -- construct layer: store its id and a reference to the shared global config
+    self.id = id
+    self.gconf = global_conf -- kept as-is; contents of global_conf are not used in this file -- TODO confirm what callers rely on
+end
+
+function SoftmaxCELayer:init() -- reset the running cross-entropy statistics (call before a new pass)
+    self.total_ce = 0.0 -- accumulated cross-entropy, summed in propagate()
+    self.total_frames = 0 -- accumulated frame (row) count, incremented in propagate()
+end
+
+function SoftmaxCELayer:update(bp_err, input, output) -- parameter-update hook required by the Layer interface
+    -- no params, therefore do nothing
+end
+
+function SoftmaxCELayer:propagate(input, output) -- forward pass: softmax of input[0], accumulate CE against labels in input[1]
+    local soutput = input[0]:create() -- temporary value for calc softmax
+    self.soutput = soutput -- cached so back_propagate() can reuse it; back_propagate is only valid after this runs
+    soutput:softmax(input[0])
+    local ce = soutput:create() -- scratch matrix, same shape as the softmax output
+    ce:log_elem(soutput) -- ce := log(softmax(input[0])) element-wise
+    ce:mul_elem(ce, input[1]) -- ce := label * log(softmax) element-wise; presumably input[1] is a one-hot/probability label matrix -- verify against caller
+    -- add total ce
+    self.total_ce = self.total_ce - ce:rowsum():colsum()[0] -- subtract the grand total: CE = -sum(label * log(softmax)); [0] suggests 0-based matrix indexing
+    self.total_frames = self.total_frames + soutput:nrow() -- one frame per row of the batch
+end
+
+function SoftmaxCELayer:back_propagate(next_bp_err, bp_err, input, output) -- backward pass; requires propagate() to have cached self.soutput
+    -- softmax output - label
+    next_bp_err[0]:add(self.soutput, input[1], 1.0, -1.0) -- combined softmax+CE gradient: 1.0*soutput + (-1.0)*input[1]
+end