Commit 094d77c

Internal.
PiperOrigin-RevId: 392341778
T5 Team authored and t5-copybara committed Aug 23, 2021
1 parent 6f2b1ad commit 094d77c
Showing 2 changed files with 21 additions and 2 deletions.
10 changes: 8 additions & 2 deletions t5/evaluation/metrics.py
@@ -247,7 +247,10 @@ def f1_score_with_invalid(targets, predictions):
   return {"f1": 100 * sklearn.metrics.f1_score(targets, predictions)}
 
 
-def mean_group_metric(metric_fn, group_key="group", value_key="value"):
+def mean_group_metric(metric_fn,
+                      group_key="group",
+                      value_key="value",
+                      return_subgroup_scores=False):
   """Returns a metric that averages `metric_fn` on sub-groups of results.
 
   The sub-groups are defined by aggregating results (targets and predictions)
@@ -262,6 +265,7 @@ def mean_group_metric(metric_fn, group_key="group", value_key="value"):
     metric_fn: function, the metric to compute on the subgroups.
     group_key: string, the key for the grouping value in the target dictionary.
     value_key: string, the key for the value in the dictionaries.
+    return_subgroup_scores: If true, include the scores for each sub-group.
   """
   def my_metric(targets, predictions):
     """Computes mean of `metric_fn` over subgroups of results."""
@@ -271,9 +275,11 @@ def my_metric(targets, predictions):
       grouped_values[g][0].append(targ[value_key])
       grouped_values[g][1].append(pred[value_key])
     group_scores = collections.defaultdict(list)
-    for (targets, predictions) in grouped_values.values():
+    for group, (targets, predictions) in grouped_values.items():
       for metric, score in metric_fn(targets, predictions).items():
         group_scores[metric].append(score)
+        if return_subgroup_scores:
+          group_scores["%s-%s" % (group, metric)].append(score)
     return {metric: np.mean(scores) for metric, scores in group_scores.items()}
   return my_metric

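As a quick sanity check on the new flag, here is a minimal usage sketch (assuming `t5.evaluation.metrics` is importable; the inputs mirror the test added in metrics_test.py below):

    from t5.evaluation import metrics

    # Average per-group accuracy, and also report each group's own score.
    metric_fn = metrics.mean_group_metric(
        metrics.accuracy, return_subgroup_scores=True)

    # Group "a" gets 1 of 2 right (50%); group "b" gets 0 of 1 right (0%).
    # The top-level "accuracy" is the unweighted mean over groups: 25.0.
    print(metric_fn(
        [{"group": "a", "value": 0},
         {"group": "a", "value": 1},
         {"group": "b", "value": 0}],
        [{"value": 0}, {"value": 0}, {"value": 1}]))
    # -> {"accuracy": 25.0, "a-accuracy": 50.0, "b-accuracy": 0.0}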
13 changes: 13 additions & 0 deletions t5/evaluation/metrics_test.py
@@ -199,6 +199,19 @@ def test_mean_group_metric(self):
              {"value": 1}]),
         {"accuracy": 25.})
 
+  def test_mean_group_metric_with_subgroups(self):
+    metric_fn = metrics.mean_group_metric(
+        metrics.accuracy, return_subgroup_scores=True)
+    self.assertDictClose(
+        metric_fn(
+            [{"group": "a", "value": 0},
+             {"group": "a", "value": 1},
+             {"group": "b", "value": 0}],
+            [{"value": 0},
+             {"value": 0},
+             {"value": 1}]),
+        {"accuracy": 25.0, "a-accuracy": 50.0, "b-accuracy": 0.0})
+
   def test_multirc_f1_over_all_answers(self):
     metric_fn = metrics.multirc_f1_over_all_answers
     self.assertDictClose(
