summaryrefslogtreecommitdiff
path: root/layer/softmax_ce.lua
diff options
context:
space:
mode:
author	Determinant <[email protected]>	2015-05-28 17:01:10 +0800
committer	Determinant <[email protected]>	2015-05-28 17:01:10 +0800
commite934b616496940bfe0924ca1992035d2346baa62 (patch)
tree6ed5398d9123cc2cbfd2b09ac1aed74db42299c4 /layer/softmax_ce.lua
parente4dedc2992149d245ea65132131253072d3276b8 (diff)
add softmax + ce layer; test_dnn_layers produces the same result as TNet
Diffstat (limited to 'layer/softmax_ce.lua')
-rw-r--r--	layer/softmax_ce.lua	32
1 file changed, 32 insertions, 0 deletions
diff --git a/layer/softmax_ce.lua b/layer/softmax_ce.lua
new file mode 100644
index 0000000..37d2864
--- /dev/null
+++ b/layer/softmax_ce.lua
@@ -0,0 +1,32 @@
+local SoftmaxCELayer = nerv.class("nerv.SoftmaxCELayer", "nerv.Layer")
+
+function SoftmaxCELayer:__init(id, global_conf)
+ self.id = id
+ self.gconf = global_conf
+end
+
+function SoftmaxCELayer:init()
+ self.total_ce = 0.0
+ self.total_frames = 0
+end
+
+function SoftmaxCELayer:update(bp_err, input, output)
+ -- no params, therefore do nothing
+end
+
+function SoftmaxCELayer:propagate(input, output)
+ local soutput = input[0]:create() -- temporary value for calc softmax
+ self.soutput = soutput
+ soutput:softmax(input[0])
+ local ce = soutput:create()
+ ce:log_elem(soutput)
+ ce:mul_elem(ce, input[1])
+ -- add total ce
+ self.total_ce = self.total_ce - ce:rowsum():colsum()[0]
+ self.total_frames = self.total_frames + soutput:nrow()
+end
+
+function SoftmaxCELayer:back_propagate(next_bp_err, bp_err, input, output)
+ -- softmax output - label
+ next_bp_err[0]:add(self.soutput, input[1], 1.0, -1.0)
+end