diff --git a/hands_on/pyanno_voting/pyanno/tests/test_voting.py b/hands_on/pyanno_voting/pyanno/tests/test_voting.py index f21a10c..a4a420f 100644 --- a/hands_on/pyanno_voting/pyanno/tests/test_voting.py +++ b/hands_on/pyanno_voting/pyanno/tests/test_voting.py @@ -41,3 +41,15 @@ def test_majority_vote_empty_item(): expected = [1, MV, 2] result = voting.majority_vote(annotations) assert result == expected + +from numpy.testing import assert_array_equal +def test_frequency_vote(): + annotations = [ + [1, 1, 2], + [-1, 1, 2], + ] + classes=4 + expected = [0, 0.6, 0.4,0] + result = voting.labels_frequency(annotations,classes) + assert_array_equal(result,expected) + diff --git a/hands_on/pyanno_voting/pyanno/voting.py b/hands_on/pyanno_voting/pyanno/voting.py index d5b5747..acd072b 100644 --- a/hands_on/pyanno_voting/pyanno/voting.py +++ b/hands_on/pyanno_voting/pyanno/voting.py @@ -100,3 +100,12 @@ def labels_frequency(annotations, nclasses): freq[k] is the frequency of elements of class k in `annotations`, i.e. their count over the number of total of observed (non-missing) elements """ + + votes = np.concatenate(annotations) + correct_votes = votes[votes >= 0] + freq = np.zeros(nclasses) + for iClass in range(nclasses): + freq[iClass] = sum(correct_votes == iClass)/len(correct_votes) + + return freq +