Skip to content

Commit

Permalink
Merge pull request apache#81 from oist/multi_aces
Browse files Browse the repository at this point in the history
adds MultiACE or ACE per class
  • Loading branch information
pluskid committed Apr 16, 2016
2 parents 79bfdbe + d85549b commit 3bff587
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,60 @@ function update!(metric :: ACE, labels :: Vector{NDArray}, preds :: Vector{NDArr
end
end

#=doc
.. class:: MultiACE
Averaged cross-entropy for classification. This also know als logloss.
This variant keeps track of the different losses per class.
Calculated the averaged cross entropy for multi-dimentions output.
=#
type MultiACE <: AbstractEvalMetric
aces :: Vector{Float64}
counts :: Vector{Int}

MultiACE(nclasses) = new(Base.zeros(nclasses), Base.zeros(Int, nclasses))
end

function get(metric :: MultiACE)
aces = [(symbol("ACE_$(i-0)"), - metric.aces[i] / metric.counts[i]) for i in 1:length(metric.aces)]
push!(aces, (:ACE, - Base.sum(metric.aces) / Base.sum(metric.counts)))
return aces
end

function reset!(metric :: MultiACE)
metric.aces = Base.zero(metric.aces)
metric.counts = Base.zero(metric.counts)
end

function _update_single_output(metric :: MultiACE, label :: NDArray, pred :: NDArray)
@nd_as_jl ro=(label,pred) begin
# Samples are stored in the last dimension
@assert size(label, ndims(label)) == size(pred, ndims(pred))
@assert ndims(pred) == 4

labels = reshape(label, size(pred, 1, 2)..., 1, size(pred, 4))
for sample in 1:size(labels, 4)
for j in 1:size(labels, 2)
for i in 1:size(labels, 1)
label = labels[i, j, 1, sample]

# Cross-entropy reduces to -(ln(p_1)*0 + ln(p_2)*1) for classification
# Since we can only target labels right now this is the only thing we can do.
target = Int(label) + 1 # klasses are 0...k-1 => julia indexing
p_k = pred[i, j, target, sample]

metric.aces[target] += log(p_k)
metric.counts[target] += 1
end
end
end
end
end

function update!(metric :: MultiACE, labels :: Vector{NDArray}, preds :: Vector{NDArray})
@assert length(labels) == length(preds)
for i = 1:length(labels)
_update_single_output(metric, labels[i], preds[i])
end
end

0 comments on commit 3bff587

Please sign in to comment.