Skip to content

Commit

Permalink
Fixing a regression with return_all_scores introduced in #17606 (#1…
Browse files Browse the repository at this point in the history
…7906)

Fixing a regression with `return_all_scores` introduced in #17606

- The legacy test actually tested `return_all_scores=False` (the actual
  default) instead of `return_all_scores=True` (the actual weird case).

This commit adds the correct legacy test and fixes it.

Tmp legacy tests.

Actually fix the regression (also contains lists)

Less diffed code.
  • Loading branch information
Narsil authored Jun 28, 2022
1 parent 5f1e67a commit 776855c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/transformers/pipelines/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def __call__(self, *args, **kwargs):
If `top_k` is used, one such dictionary is returned per label.
"""
result = super().__call__(*args, **kwargs)
if isinstance(args[0], str) and isinstance(result, dict):
# TODO try and retrieve it in a nicer way from _sanitize_parameters.
_legacy = "top_k" not in kwargs
if isinstance(args[0], str) and _legacy:
# This pipeline is odd, and return a list when single item is run
return [result]
else:
Expand Down
23 changes: 23 additions & 0 deletions tests/pipelines/test_pipelines_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,29 @@ def test_small_model_pt(self):
outputs = text_classifier("This is great !", return_all_scores=False)
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])

outputs = text_classifier("This is great !", return_all_scores=True)
self.assertEqual(
nested_simplify(outputs), [[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]]
)

outputs = text_classifier(["This is great !", "Something else"], return_all_scores=True)
self.assertEqual(
nested_simplify(outputs),
[
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}],
],
)

outputs = text_classifier(["This is great !", "Something else"], return_all_scores=False)
self.assertEqual(
nested_simplify(outputs),
[
{"label": "LABEL_0", "score": 0.504},
{"label": "LABEL_0", "score": 0.504},
],
)

@require_torch
def test_accepts_torch_device(self):
import torch
Expand Down

0 comments on commit 776855c

Please sign in to comment.