aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nerv/layer/softmax.lua7
1 files changed, 6 insertions, 1 deletions
diff --git a/nerv/layer/softmax.lua b/nerv/layer/softmax.lua
index f7a5163..81dcacc 100644
--- a/nerv/layer/softmax.lua
+++ b/nerv/layer/softmax.lua
@@ -28,7 +28,12 @@ function SoftmaxLayer:propagate(input, output)
end
function SoftmaxLayer:back_propagate(bp_err, next_bp_err, input, output)
- nerv.error_method_not_implemented()
+ local nbe = next_bp_err[1]
+ nbe:mul_elem(bp_err[1], output[1])
+ local offset = nbe:rowsum()
+ nbe:copy_from(bp_err[1])
+ nbe:add_row(offset, -1.0)
+ nbe:mul_elem(nbe, output[1])
end
function SoftmaxLayer:get_params()