Skip to content

Commit

Permalink
Fix CI for metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
albertvillanova committed Aug 13, 2024
1 parent af818af commit 83e5c05
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@
"jiwer",
"langdetect",
"mauve-text",
"nltk",
"nltk<3.8.2",
"rouge_score",
"sacrebleu",
"sacremoses",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_inspect_dataset(path, tmp_path):
@pytest.mark.filterwarnings("ignore:metric_module_factory is deprecated:FutureWarning")
@pytest.mark.parametrize("path", ["accuracy"])
def test_inspect_metric(path, tmp_path):
inspect_metric(path, tmp_path, trust_remote_code=True)
inspect_metric(path, tmp_path, trust_remote_code=True, revision="2.21")
script_name = path + ".py"
assert script_name in os.listdir(tmp_path)
assert "__pycache__" not in os.listdir(tmp_path)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def test_GithubMetricModuleFactory_with_internal_import(self):
# "squad_v2" requires additional imports (internal)
factory = GithubMetricModuleFactory(
"squad_v2",
revision="2.21",
download_config=self.download_config,
dynamic_modules_path=self.dynamic_modules_path,
trust_remote_code=True,
Expand All @@ -464,6 +465,7 @@ def test_GithubMetricModuleFactory_with_external_import(self):
# "bleu" requires additional imports (external from github)
factory = GithubMetricModuleFactory(
"bleu",
revision="2.21",
download_config=self.download_config,
dynamic_modules_path=self.dynamic_modules_path,
trust_remote_code=True,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from datasets.features import Features, Sequence, Value
from datasets.metric import Metric, MetricInfo

from .utils import require_tf, require_torch
from .utils import require_numpy1_on_windows, require_tf, require_torch


class DummyMetric(Metric):
Expand Down Expand Up @@ -433,6 +433,7 @@ def test_input_numpy(self):
self.assertDictEqual(expected_results, metric.compute())
del metric

@require_numpy1_on_windows
@require_torch
def test_input_torch(self):
import torch
Expand Down

0 comments on commit 83e5c05

Please sign in to comment.