Skip to content

Commit

Permalink
Fixing the test.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Jun 8, 2022
1 parent 96f3df4 commit fb69da1
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/pipelines/test_pipelines_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ def run_pipeline_test(self, text_classifier, _):

# Forcing to get all results with `top_k=None`
# This is NOT the legacy format
outputs = text_classifier(valid_inputs, top_k=None)
N = len(model.config.id2label.values())
self.assertEqual(
nested_simplify(outputs, top_k=None),
[[{"label": ANY(str), "score": ANY(float)}], [{"label": ANY(str), "score": ANY(float)}]],
nested_simplify(outputs),
[[{"label": ANY(str), "score": ANY(float)}] * N, [{"label": ANY(str), "score": ANY(float)}] * N],
)
self.assertTrue(outputs[0]["label"] in model.config.id2label.values())
self.assertTrue(outputs[1]["label"] in model.config.id2label.values())

valid_inputs = {"text": "HuggingFace is in ", "text_pair": "Paris is in France"}
outputs = text_classifier(valid_inputs)
Expand Down

0 comments on commit fb69da1

Please sign in to comment.