Introduces seqio.CollectingMetric class
PiperOrigin-RevId: 466117278
KEHANG authored and t5-copybara committed Apr 13, 2023
1 parent 421f9c3 commit 5d10d43
Showing 2 changed files with 321 additions and 2 deletions.
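The commit adds two SQuAD metric implementations built on seqio's metric interfaces: ShardedSquad, a flax-dataclass seqio.metrics.Metric whose per-shard F1/EM results are merged via a count-weighted average, and PassthroughSquad, a seqio.metrics.CollectingMetric that collects raw model outputs (with masks and 2D shard indices) and computes SQuAD scores once over the full dataset.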
110 changes: 109 additions & 1 deletion t5/evaluation/metrics.py
@@ -24,20 +24,28 @@
import itertools
import re
import string
-from typing import Dict, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union

from absl import logging
import editdistance
import flax
import jax.numpy as jnp
import numpy as np
import sacrebleu
import scipy.stats
import seqio
import sklearn.metrics
from t5.evaluation import qa_utils
import tensorflow.compat.v2 as tf

from rouge_score import rouge_scorer
from rouge_score import scoring


ModelOutputType = seqio.metrics.ModelOutputType
CollectingMetric = seqio.metrics.CollectingMetric


def bleu(targets, predictions, tokenizer="intl"):
"""Computes BLEU score.
@@ -643,3 +651,103 @@ def edit_distance(targets, predictions, lower=True):
          "mean_edit": np.mean(edit_distances),
          "median_edit": np.median(edit_distances),
          "sum_edit": sum(edit_distances)}


@flax.struct.dataclass
class ShardedSquad(seqio.metrics.Metric):
  """Implements SQuAD metrics, maximizing over answers per question."""

  f1: float = 0.0
  em: float = 0.0
  count: int = 0
  model_output_type: ModelOutputType = ModelOutputType.PREDICTION

  @classmethod
  def empty(cls) -> "ShardedSquad":
    return cls(f1=0.0, em=0.0, count=0)

  @classmethod
  def from_model_output(
      cls,
      inputs: Sequence[Mapping[str, Any]],
      model_output: np.ndarray,
      features: Mapping[str, seqio.Feature],
      target_field_name: str = "targets",
      mask: Optional[np.ndarray] = None,
      indices_2d: Optional[np.ndarray] = None) -> "ShardedSquad":

    del indices_2d
    if mask is None:
      mask = jnp.ones((len(inputs),))

    # Postprocesses the targets here.
    postprocessed_targets = [[
        tf.compat.as_text(answers) for answers in example["answers"]
    ] for example, included in zip(inputs, mask) if included]

    # Decodes the predictions here.
    vocab = features[target_field_name].vocabulary
    predictions = [
        vocab.decode(tokens)
        for tokens, included in zip(model_output, mask)
        if included
    ]

    squad_result = squad(targets=postprocessed_targets, predictions=predictions)
    return cls(f1=squad_result["f1"], em=squad_result["em"], count=mask.sum())

  def merge(self, other: "ShardedSquad") -> "ShardedSquad":
    """Returns a `ShardedSquad` that accumulates `self` and `other`.

    Args:
      other: A `ShardedSquad` whose intermediate values should be accumulated
        onto the values of `self`. Note that in a distributed setting, `other`
        will typically be the output of a `jax.lax` parallel operator and thus
        have a dimension added to the dataclass returned by
        `.from_model_output()`.

    Returns:
      A new `ShardedSquad` that accumulates the values from both `self` and
      `other`.
    """
    # Weights each shard's scores by its example count so that shards of
    # different sizes contribute proportionally.
    count = self.count + other.count
    f1 = (self.f1 * self.count + other.f1 * other.count) / count
    em = (self.em * self.count + other.em * other.count) / count

    return type(self)(f1=f1, em=em, count=count)

  def compute(self):
    return {"f1": self.f1, "em": self.em}


@flax.struct.dataclass
class PassthroughSquad(CollectingMetric):
  """Implements SQuAD metrics, maximizing over answers per question."""

  model_output_type: ModelOutputType = ModelOutputType.PREDICTION

  def actual_compute(self, task_dataset_as_numpy, task_output_features,
                     target_field_name: str = "targets"):
    # Postprocesses the targets here.
    postprocessed_targets = [[
        tf.compat.as_text(answers) for answers in example["answers"]
    ] for example in task_dataset_as_numpy]

    # The model outputs are processed in the steps below.
    # Step 1: removes padded examples using the mask.
    indices_2d = self.values["indices_2d"][self.values["mask"] == 1]
    model_output = self.values["model_output"][self.values["mask"] == 1]
    assert len(postprocessed_targets) == len(indices_2d)

    # Step 2: sorts the model outputs by 2d-indices, namely (shard_id,
    # index_within_shard), to align with the targets.
    permutation = np.lexsort((indices_2d[:, 1], indices_2d[:, 0]))
    model_output = [
        model_output[permutation[i]] for i in range(len(permutation))
    ]

    # Decodes the predictions here.
    target_vocab = task_output_features[target_field_name].vocabulary
    predictions = [
        target_vocab.decode(tokens) for tokens in model_output
    ]

    return squad(postprocessed_targets, predictions), None
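Below is a minimal standalone sketch (not part of this commit) of the reordering step in PassthroughSquad.actual_compute above: collected model outputs arrive tagged with (shard_id, index_within_shard) pairs, and np.lexsort restores dataset order so predictions line up with targets. The index values and prediction strings are made up for illustration.

import numpy as np

# Outputs collected from two shards, out of dataset order; the strings
# are placeholders standing in for decoded model outputs.
indices_2d = np.array([[1, 0], [0, 1], [0, 0], [1, 1]])
model_output = ["pred_c", "pred_b", "pred_a", "pred_d"]

# np.lexsort sorts by the last key first, so shard_id dominates and
# index_within_shard breaks ties -- the same call actual_compute makes.
permutation = np.lexsort((indices_2d[:, 1], indices_2d[:, 0]))
print([model_output[i] for i in permutation])
# -> ['pred_a', 'pred_b', 'pred_c', 'pred_d']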
213 changes: 212 additions & 1 deletion t5/evaluation/metrics_test.py
@@ -14,9 +14,12 @@

"""Tests for t5.evaluation.metrics."""

from unittest import mock

from absl.testing import absltest
import numpy as np
import seqio
import sklearn.metrics

from t5.evaluation import metrics
from t5.evaluation import test_utils

@@ -706,5 +709,213 @@ def test_edit_distance(self):
        })


def mock_decode(self, ids):
  """Decodes `ids` by inverting MockVocabulary's encode dict, skipping pad 0s."""
  decode_dict = {v: k for k, v in self._encode_dict.items()}
  words = [decode_dict[token] for token in ids if token != 0]
  return " ".join(words)


class PassthroughSquadTest(test_utils.BaseMetricsTest):

  def test_same(self):
    ref = "this is a string"
    inputs = [{"answers": ["", ref]}, {"answers": [ref, ref]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "this": 2,
              "is": 3,
              "a": 4,
              "string": 5
          }, vocab_size=10)

      model_output = np.array([[2, 3, 4, 5], [2, 3, 4, 5]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.PassthroughSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.actual_compute(inputs, features)[0],
                           {"em": 100, "f1": 100})

  def test_different(self):
    ref = "this is a string"
    inputs = [{"answers": [ref, ref]}, {"answers": [ref, ref]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "this": 2,
              "is": 3,
              "a": 4,
              "string": 5,
              "": 6
          }, vocab_size=10)

      model_output = np.array([[6], [6]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.PassthroughSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.actual_compute(inputs, features)[0],
                           {"em": 0, "f1": 0})

  def test_big(self):
    inputs = [
        {"answers": ["big moose", "hippo"]},
        {"answers": ["correct1"]},
        {"answers": ["correct2.1", "correct2.2"]},
        {"answers": ["a", "b"]},
    ]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "‘a": 2,
              "big": 3,
              "Moose!‘": 4,
              "wrong": 5,
              "correct2.2": 6,
              "c": 7
          }, vocab_size=10)

      model_output = np.array([[2, 3, 4], [5, 0, 0], [6, 0, 0], [7, 0, 0]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.PassthroughSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.actual_compute(inputs, features)[0],
                           {"em": 25., "f1": 35.}, places=2)

  def test_small(self):
    inputs = [{"answers": ["abc abd", "$$$$"]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary({"abd": 2}, vocab_size=10)

      model_output = np.array([[2]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.PassthroughSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.actual_compute(inputs, features)[0],
                           {"f1": 100 * 2.0 / 3.0, "em": 0.})


class ShardedSquadTest(test_utils.BaseMetricsTest):

  def test_same(self):
    ref = "this is a string"
    inputs = [{"answers": ["", ref]}, {"answers": [ref, ref]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "this": 2,
              "is": 3,
              "a": 4,
              "string": 5
          }, vocab_size=10)

      model_output = np.array([[2, 3, 4, 5], [2, 3, 4, 5]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.ShardedSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.compute(), {"em": 100, "f1": 100})

  def test_different(self):
    ref = "this is a string"
    inputs = [{"answers": [ref, ref]}, {"answers": [ref, ref]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "this": 2,
              "is": 3,
              "a": 4,
              "string": 5,
              "": 6
          }, vocab_size=10)

      model_output = np.array([[6], [6]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.ShardedSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.compute(), {"em": 0, "f1": 0})

  def test_big(self):
    inputs = [
        {"answers": ["big moose", "hippo"]},
        {"answers": ["correct1"]},
        {"answers": ["correct2.1", "correct2.2"]},
        {"answers": ["a", "b"]},
    ]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "‘a": 2,
              "big": 3,
              "Moose!‘": 4,
              "wrong": 5,
              "correct2.2": 6,
              "c": 7
          }, vocab_size=10)

      model_output = np.array([[2, 3, 4], [5, 0, 0], [6, 0, 0], [7, 0, 0]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.ShardedSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.compute(), {"em": 25., "f1": 35.}, places=2)

  def test_small(self):
    inputs = [{"answers": ["abc abd", "$$$$"]}]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary({"abd": 2}, vocab_size=10)

      model_output = np.array([[2]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric = metrics.ShardedSquad.from_model_output(
          inputs, model_output, features)
      self.assertDictClose(metric.compute(), {"f1": 100 * 2.0 / 3.0, "em": 0.})

  def test_batch_update(self):
    inputs1 = [
        {"answers": ["big moose", "hippo"]},
        {"answers": ["correct1"]}
    ]
    inputs2 = [
        {"answers": ["correct2.1", "correct2.2"]},
        {"answers": ["a", "b"]},
    ]

    with mock.patch.object(
        seqio.test_utils.MockVocabulary, "decode", new=mock_decode):
      vocabulary = seqio.test_utils.MockVocabulary(
          {
              "‘a": 2,
              "big": 3,
              "Moose!‘": 4,
              "wrong": 5,
              "correct2.2": 6,
              "c": 7
          }, vocab_size=10)

      model_output1 = np.array([[2, 3, 4], [5, 0, 0]])
      model_output2 = np.array([[6], [7]])
      features = {"targets": seqio.Feature(vocabulary)}
      metric1 = metrics.ShardedSquad.from_model_output(
          inputs1, model_output1, features)
      metric2 = metrics.ShardedSquad.from_model_output(
          inputs2, model_output2, features)
      metric = metric1.merge(metric2)
      self.assertDictClose(metric.compute(), {"em": 25., "f1": 35.}, places=2)


if __name__ == "__main__":
  absltest.main()
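For intuition, here is a minimal sketch (not part of the commit) of the count-weighted merge that test_batch_update exercises, assuming ShardedSquad is importable from t5.evaluation.metrics; the per-shard scores are made up.

from t5.evaluation import metrics

# Two hypothetical shards of two examples each, with per-shard scores.
shard_a = metrics.ShardedSquad(f1=80.0, em=50.0, count=2)
shard_b = metrics.ShardedSquad(f1=40.0, em=0.0, count=2)

merged = shard_a.merge(shard_b)
# f1 = (80 * 2 + 40 * 2) / 4 = 60.0; em = (50 * 2 + 0 * 2) / 4 = 25.0
print(merged.compute())  # {'f1': 60.0, 'em': 25.0}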
