From d85549b86e259db6859ccc96da4e0053653d756a Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sat, 16 Apr 2016 23:20:14 +0900 Subject: [PATCH] adds MultiACE or ACE per class --- src/metric.jl | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/src/metric.jl b/src/metric.jl index a22794e9f158..7916d45b639c 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -200,3 +200,60 @@ function update!(metric :: ACE, labels :: Vector{NDArray}, preds :: Vector{NDArr end end +#=doc +.. class:: MultiACE + + Averaged cross-entropy for classification. This also know als logloss. + This variant keeps track of the different losses per class. + + Calculated the averaged cross entropy for multi-dimentions output. +=# +type MultiACE <: AbstractEvalMetric + aces :: Vector{Float64} + counts :: Vector{Int} + + MultiACE(nclasses) = new(Base.zeros(nclasses), Base.zeros(Int, nclasses)) +end + +function get(metric :: MultiACE) + aces = [(symbol("ACE_$(i-0)"), - metric.aces[i] / metric.counts[i]) for i in 1:length(metric.aces)] + push!(aces, (:ACE, - Base.sum(metric.aces) / Base.sum(metric.counts))) + return aces +end + +function reset!(metric :: MultiACE) + metric.aces = Base.zero(metric.aces) + metric.counts = Base.zero(metric.counts) +end + +function _update_single_output(metric :: MultiACE, label :: NDArray, pred :: NDArray) + @nd_as_jl ro=(label,pred) begin + # Samples are stored in the last dimension + @assert size(label, ndims(label)) == size(pred, ndims(pred)) + @assert ndims(pred) == 4 + + labels = reshape(label, size(pred, 1, 2)..., 1, size(pred, 4)) + for sample in 1:size(labels, 4) + for j in 1:size(labels, 2) + for i in 1:size(labels, 1) + label = labels[i, j, 1, sample] + + # Cross-entropy reduces to -(ln(p_1)*0 + ln(p_2)*1) for classification + # Since we can only target labels right now this is the only thing we can do. + target = Int(label) + 1 # klasses are 0...k-1 => julia indexing + p_k = pred[i, j, target, sample] + + metric.aces[target] += log(p_k) + metric.counts[target] += 1 + end + end + end + end +end + +function update!(metric :: MultiACE, labels :: Vector{NDArray}, preds :: Vector{NDArray}) + @assert length(labels) == length(preds) + for i = 1:length(labels) + _update_single_output(metric, labels[i], preds[i]) + end +end