From 776855c75275484a971d450a1a0dd4c5c8b5727b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 28 Jun 2022 17:24:45 -0400 Subject: [PATCH] Fixing a regression with `return_all_scores` introduced in #17606 (#17906) Fixing a regression with `return_all_scores` introduced in #17606 - The legacy test actually tested `return_all_scores=False` (the actual default) instead of `return_all_scores=True` (the actual weird case). This commit adds the correct legacy test and fixes it. Tmp legacy tests. Actually fix the regression (also contains lists) Less diffed code. --- .../pipelines/text_classification.py | 4 +++- .../test_pipelines_text_classification.py | 23 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/text_classification.py b/src/transformers/pipelines/text_classification.py index 590c87c02201c7..dd8de4c7357f2a 100644 --- a/src/transformers/pipelines/text_classification.py +++ b/src/transformers/pipelines/text_classification.py @@ -136,7 +136,9 @@ def __call__(self, *args, **kwargs): If `top_k` is used, one such dictionary is returned per label. """ result = super().__call__(*args, **kwargs) - if isinstance(args[0], str) and isinstance(result, dict): + # TODO try and retrieve it in a nicer way from _sanitize_parameters. + _legacy = "top_k" not in kwargs + if isinstance(args[0], str) and _legacy: # This pipeline is odd, and return a list when single item is run return [result] else: diff --git a/tests/pipelines/test_pipelines_text_classification.py b/tests/pipelines/test_pipelines_text_classification.py index 9251b299224c52..6bbc84989a211d 100644 --- a/tests/pipelines/test_pipelines_text_classification.py +++ b/tests/pipelines/test_pipelines_text_classification.py @@ -60,6 +60,29 @@ def test_small_model_pt(self): outputs = text_classifier("This is great !", return_all_scores=False) self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) + outputs = text_classifier("This is great !", return_all_scores=True) + self.assertEqual( + nested_simplify(outputs), [[{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}]] + ) + + outputs = text_classifier(["This is great !", "Something else"], return_all_scores=True) + self.assertEqual( + nested_simplify(outputs), + [ + [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}], + [{"label": "LABEL_0", "score": 0.504}, {"label": "LABEL_1", "score": 0.496}], + ], + ) + + outputs = text_classifier(["This is great !", "Something else"], return_all_scores=False) + self.assertEqual( + nested_simplify(outputs), + [ + {"label": "LABEL_0", "score": 0.504}, + {"label": "LABEL_0", "score": 0.504}, + ], + ) + @require_torch def test_accepts_torch_device(self): import torch