From a2dc287cbef5311cf1a32ad4e3685f4052db227c Mon Sep 17 00:00:00 2001
From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Date: Fri, 7 Jun 2024 14:20:29 +0200
Subject: [PATCH] Remove default `trust_remote_code=True` (#6954)

* remove default trust_remote_code=True

* fix tests

* fix tests

* again

* again

* style

---
 src/datasets/commands/test.py | 10 +++-
 src/datasets/config.py        |  2 +-
 src/datasets/load.py          |  5 +-
 tests/commands/test_test.py   |  7 ++-
 tests/features/test_audio.py  |  4 +-
 tests/features/test_image.py  |  4 +-
 tests/test_hf_gcp.py          |  9 +++-
 tests/test_hub.py             |  2 +-
 tests/test_inspect.py         |  6 +--
 tests/test_load.py            | 89 ++++++++++++++++++++++-------------
 tests/test_metric_common.py   |  6 ++-
 tests/test_warnings.py        | 11 +++--
 12 files changed, 101 insertions(+), 54 deletions(-)

diff --git a/src/datasets/commands/test.py b/src/datasets/commands/test.py
index da82427e935..986bf490a1d 100644
--- a/src/datasets/commands/test.py
+++ b/src/datasets/commands/test.py
@@ -3,7 +3,7 @@
 from argparse import ArgumentParser
 from pathlib import Path
 from shutil import copyfile, rmtree
-from typing import Generator
+from typing import Generator, Optional
 
 import datasets.config
 from datasets.builder import DatasetBuilder
@@ -29,6 +29,7 @@ def _test_command_factory(args):
         args.force_redownload,
         args.clear_cache,
         args.num_proc,
+        args.trust_remote_code,
     )
 
 
@@ -67,6 +68,9 @@ def register_subcommand(parser: ArgumentParser):
             help="Remove downloaded files and cached datasets after each config test",
         )
         test_parser.add_argument("--num_proc", type=int, default=None, help="Number of processes")
+        test_parser.add_argument(
+            "--trust_remote_code", action="store_true", help="whether to trust the code execution of the load script"
+        )
         # aliases
         test_parser.add_argument("--save_infos", action="store_true", help="alias to save_info")
         test_parser.add_argument("dataset", type=str, help="Name of the dataset to download")
@@ -84,6 +88,7 @@ def __init__(
         force_redownload: bool,
         clear_cache: bool,
         num_proc: int,
+        trust_remote_code: Optional[bool],
    ):
         self._dataset = dataset
         self._name = name
@@ -95,6 +100,7 @@ def __init__(
         self._force_redownload = force_redownload
         self._clear_cache = clear_cache
         self._num_proc = num_proc
+        self._trust_remote_code = trust_remote_code
         if clear_cache and not cache_dir:
             print(
                 "When --clear_cache is used, specifying a cache directory is mandatory.\n"
@@ -111,7 +117,7 @@ def run(self):
             print("Both parameters `config` and `all_configs` can't be used at once.")
             exit(1)
         path, config_name = self._dataset, self._name
-        module = dataset_module_factory(path)
+        module = dataset_module_factory(path, trust_remote_code=self._trust_remote_code)
         builder_cls = import_main_class(module.module_path)
         n_builders = len(builder_cls.BUILDER_CONFIGS) if self._all_configs and builder_cls.BUILDER_CONFIGS else 1

diff --git a/src/datasets/config.py b/src/datasets/config.py
index 9668dfbd91e..c5b0f3ded40 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -186,7 +186,7 @@
 HF_DATASETS_MULTITHREADING_MAX_WORKERS = 16
 
 # Remote dataset scripts support
-__HF_DATASETS_TRUST_REMOTE_CODE = os.environ.get("HF_DATASETS_TRUST_REMOTE_CODE", "1")
+__HF_DATASETS_TRUST_REMOTE_CODE = os.environ.get("HF_DATASETS_TRUST_REMOTE_CODE", "ask")
 HF_DATASETS_TRUST_REMOTE_CODE: Optional[bool] = (
     True
     if __HF_DATASETS_TRUST_REMOTE_CODE.upper() in ENV_VARS_TRUE_VALUES
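The `config.py` hunk above flips the library default from "always trust" to a tri-state setting. Here is a minimal sketch of how that default plausibly resolves; the hunk's context window cuts off before the `False`/`None` branches, so those are completed by assumption, and the value sets are defined locally for self-containment:

    import os
    from typing import Optional

    # Assumed value sets, mirroring the ENV_VARS_TRUE_VALUES constant referenced in the hunk.
    ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
    ENV_VARS_FALSE_VALUES = {"0", "OFF", "NO", "FALSE"}

    def read_trust_remote_code_default() -> Optional[bool]:
        raw = os.environ.get("HF_DATASETS_TRUST_REMOTE_CODE", "ask").upper()
        if raw in ENV_VARS_TRUE_VALUES:
            return True  # old default: loading scripts run without confirmation
        if raw in ENV_VARS_FALSE_VALUES:
            return False  # loading scripts are always refused
        return None  # new default ("ask"): callers must opt in explicitly
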
diff --git a/src/datasets/load.py b/src/datasets/load.py
index 824817843fd..2dab9f7a7e6 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -2007,7 +2007,10 @@ def metric_module_factory(
         raise FileNotFoundError(f"Couldn't find a metric script at {relative_to_absolute_path(path)}")
     elif os.path.isfile(combined_path):
         return LocalMetricModuleFactory(
-            combined_path, download_mode=download_mode, dynamic_modules_path=dynamic_modules_path
+            combined_path,
+            download_mode=download_mode,
+            dynamic_modules_path=dynamic_modules_path,
+            trust_remote_code=trust_remote_code,
         ).get_module()
     elif is_relative_path(path) and path.count("/") == 0:
         try:
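With this `load.py` change, the metric factories forward `trust_remote_code` instead of assuming it. A usage sketch matching the calls exercised in the updated tests below (note that `metric_module_factory` is deprecated and emits a `FutureWarning`):

    import os

    import datasets.load

    # Loading a local metric script now requires the explicit opt-in.
    metric_module = datasets.load.metric_module_factory(
        os.path.join("metrics", "accuracy"), trust_remote_code=True
    )
    print(metric_module.module_path)
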
load_dataset_builder("wikipedia", "20220301.frr", cache_dir=tmp_dir) + builder = load_dataset_builder("wikipedia", "20220301.frr", cache_dir=tmp_dir, trust_remote_code=True) # use the HF cloud storage, not the original download_and_prepare that uses apache-beam builder._download_and_prepare = None builder.download_and_prepare(try_from_hf_gcs=True) @@ -99,7 +100,11 @@ def test_as_dataset_from_hf_gcs(tmp_path_factory): @pytest.mark.integration def test_as_streaming_dataset_from_hf_gcs(tmp_path): builder = load_dataset_builder( - "wikipedia", "20220301.frr", revision="4d013bdd32c475c8536aae00a56efc774f061649", cache_dir=tmp_path + "wikipedia", + "20220301.frr", + revision="4d013bdd32c475c8536aae00a56efc774f061649", + cache_dir=tmp_path, + trust_remote_code=True, ) ds = builder.as_streaming_dataset() assert ds diff --git a/tests/test_hub.py b/tests/test_hub.py index 49b6d504099..9a413ef8e86 100644 --- a/tests/test_hub.py +++ b/tests/test_hub.py @@ -62,7 +62,7 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_ with patch.object(datasets.hub.HfApi, "create_commit", return_value=commit_info) as mock_create_commit: with patch.object(datasets.hub.HfApi, "create_branch") as mock_create_branch: with patch.object(datasets.hub.HfApi, "list_repo_tree", return_value=[]): # not needed - _ = convert_to_parquet(repo_id, token=hf_token) + _ = convert_to_parquet(repo_id, token=hf_token, trust_remote_code=True) # mock_create_branch assert mock_create_branch.called assert mock_create_branch.call_count == 2 diff --git a/tests/test_inspect.py b/tests/test_inspect.py index 4f85dabe714..1221a842012 100644 --- a/tests/test_inspect.py +++ b/tests/test_inspect.py @@ -29,7 +29,7 @@ def test_inspect_dataset(path, tmp_path): @pytest.mark.filterwarnings("ignore:metric_module_factory is deprecated:FutureWarning") @pytest.mark.parametrize("path", ["accuracy"]) def test_inspect_metric(path, tmp_path): - inspect_metric(path, tmp_path) + inspect_metric(path, tmp_path, trust_remote_code=True) script_name = path + ".py" assert script_name in os.listdir(tmp_path) assert "__pycache__" not in os.listdir(tmp_path) @@ -79,7 +79,7 @@ def test_get_dataset_config_info_error(path, config_name, expected_exception): ], ) def test_get_dataset_config_names(path, expected): - config_names = get_dataset_config_names(path) + config_names = get_dataset_config_names(path, trust_remote_code=True) assert config_names == expected @@ -97,7 +97,7 @@ def test_get_dataset_config_names(path, expected): ], ) def test_get_dataset_default_config_name(path, expected): - default_config_name = get_dataset_default_config_name(path) + default_config_name = get_dataset_default_config_name(path, trust_remote_code=True) if expected: assert default_config_name == expected else: diff --git a/tests/test_load.py b/tests/test_load.py index c7c413ae10b..2c02e834892 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -441,6 +441,7 @@ def test_HubDatasetModuleFactoryWithScript_with_hub_dataset(self): download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path, revision="861aac88b2c6247dd93ade8b1c189ce714627750", + trust_remote_code=True, ) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None @@ -449,7 +450,10 @@ def test_HubDatasetModuleFactoryWithScript_with_hub_dataset(self): def test_GithubMetricModuleFactory_with_internal_import(self): # "squad_v2" requires additional imports (internal) factory = GithubMetricModuleFactory( - 
"squad_v2", download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path + "squad_v2", + download_config=self.download_config, + dynamic_modules_path=self.dynamic_modules_path, + trust_remote_code=True, ) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None @@ -458,7 +462,10 @@ def test_GithubMetricModuleFactory_with_internal_import(self): def test_GithubMetricModuleFactory_with_external_import(self): # "bleu" requires additional imports (external from github) factory = GithubMetricModuleFactory( - "bleu", download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path + "bleu", + download_config=self.download_config, + dynamic_modules_path=self.dynamic_modules_path, + trust_remote_code=True, ) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None @@ -466,7 +473,10 @@ def test_GithubMetricModuleFactory_with_external_import(self): def test_LocalMetricModuleFactory(self): path = os.path.join(self._metric_loading_script_dir, f"{METRIC_LOADING_SCRIPT_NAME}.py") factory = LocalMetricModuleFactory( - path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path + path, + download_config=self.download_config, + dynamic_modules_path=self.dynamic_modules_path, + trust_remote_code=True, ) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None @@ -474,7 +484,10 @@ def test_LocalMetricModuleFactory(self): def test_LocalDatasetModuleFactoryWithScript(self): path = os.path.join(self._dataset_loading_script_dir, f"{DATASET_LOADING_SCRIPT_NAME}.py") factory = LocalDatasetModuleFactoryWithScript( - path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path + path, + download_config=self.download_config, + dynamic_modules_path=self.dynamic_modules_path, + trust_remote_code=True, ) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None @@ -485,8 +498,7 @@ def test_LocalDatasetModuleFactoryWithScript_dont_trust_remote_code(self): factory = LocalDatasetModuleFactoryWithScript( path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path ) - with patch.object(config, "HF_DATASETS_TRUST_REMOTE_CODE", None): # this will be the default soon - self.assertRaises(ValueError, factory.get_module) + self.assertRaises(ValueError, factory.get_module) factory = LocalDatasetModuleFactoryWithScript( path, download_config=self.download_config, @@ -823,6 +835,7 @@ def test_HubDatasetModuleFactoryWithScript(self): SAMPLE_DATASET_IDENTIFIER, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path, + trust_remote_code=True, ) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None @@ -882,7 +895,10 @@ def test_CachedDatasetModuleFactory(self): def test_CachedDatasetModuleFactory_with_script(self): path = os.path.join(self._dataset_loading_script_dir, f"{DATASET_LOADING_SCRIPT_NAME}.py") factory = LocalDatasetModuleFactoryWithScript( - path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path + path, + download_config=self.download_config, + dynamic_modules_path=self.dynamic_modules_path, + trust_remote_code=True, ) module_factory_result = factory.get_module() for offline_mode in 
@@ -899,7 +915,10 @@ def test_CachedDatasetModuleFactory_with_script(self):
     def test_CachedMetricModuleFactory(self):
         path = os.path.join(self._metric_loading_script_dir, f"{METRIC_LOADING_SCRIPT_NAME}.py")
         factory = LocalMetricModuleFactory(
-            path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
+            path,
+            download_config=self.download_config,
+            dynamic_modules_path=self.dynamic_modules_path,
+            trust_remote_code=True,
         )
         module_factory_result = factory.get_module()
         for offline_mode in OfflineSimulationMode:
@@ -964,7 +983,7 @@ def test_dataset_module_factory(self):
             dummy_code = "MY_DUMMY_VARIABLE = 'hello there'"
             module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name1__", dummy_code)
             dataset_module = datasets.load.dataset_module_factory(
-                module_dir, dynamic_modules_path=self.dynamic_modules_path
+                module_dir, dynamic_modules_path=self.dynamic_modules_path, trust_remote_code=True
             )
             dummy_module = importlib.import_module(dataset_module.module_path)
             self.assertEqual(dummy_module.MY_DUMMY_VARIABLE, "hello there")
@@ -974,7 +993,7 @@ def test_dataset_module_factory(self):
             module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name1__", dummy_code)
             module_path = os.path.join(module_dir, "__dummy_module_name1__.py")
             dataset_module = datasets.load.dataset_module_factory(
-                module_path, dynamic_modules_path=self.dynamic_modules_path
+                module_path, dynamic_modules_path=self.dynamic_modules_path, trust_remote_code=True
             )
             dummy_module = importlib.import_module(dataset_module.module_path)
             self.assertEqual(dummy_module.MY_DUMMY_VARIABLE, "general kenobi")
@@ -1007,13 +1026,13 @@ def test_offline_dataset_module_factory_with_script(self):
             dummy_code = "MY_DUMMY_VARIABLE = 'hello there'"
             module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name2__", dummy_code)
             dataset_module_1 = datasets.load.dataset_module_factory(
-                module_dir, dynamic_modules_path=self.dynamic_modules_path
+                module_dir, dynamic_modules_path=self.dynamic_modules_path, trust_remote_code=True
             )
             time.sleep(0.1)  # make sure there's a difference in the OS update time of the python file
             dummy_code = "MY_DUMMY_VARIABLE = 'general kenobi'"
             module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name2__", dummy_code)
             dataset_module_2 = datasets.load.dataset_module_factory(
-                module_dir, dynamic_modules_path=self.dynamic_modules_path
+                module_dir, dynamic_modules_path=self.dynamic_modules_path, trust_remote_code=True
             )
             for offline_simulation_mode in list(OfflineSimulationMode):
                 with offline(offline_simulation_mode):
@@ -1129,7 +1148,7 @@ def test_load_dataset_builder_with_metadata_configs_pickable(serializer):
 
 
 def test_load_dataset_builder_for_absolute_script_dir(dataset_loading_script_dir, data_dir):
-    builder = datasets.load_dataset_builder(dataset_loading_script_dir, data_dir=data_dir)
+    builder = datasets.load_dataset_builder(dataset_loading_script_dir, data_dir=data_dir, trust_remote_code=True)
     assert isinstance(builder, DatasetBuilder)
     assert builder.name == DATASET_LOADING_SCRIPT_NAME
     assert builder.dataset_name == DATASET_LOADING_SCRIPT_NAME
@@ -1140,7 +1159,7 @@ def test_load_dataset_builder_for_relative_script_dir(dataset_loading_script_dir
     with set_current_working_directory_to_temp_dir():
         relative_script_dir = DATASET_LOADING_SCRIPT_NAME
         shutil.copytree(dataset_loading_script_dir, relative_script_dir)
-        builder = datasets.load_dataset_builder(relative_script_dir, data_dir=data_dir)
+        builder = datasets.load_dataset_builder(relative_script_dir, data_dir=data_dir, trust_remote_code=True)
         assert isinstance(builder, DatasetBuilder)
         assert builder.name == DATASET_LOADING_SCRIPT_NAME
         assert builder.dataset_name == DATASET_LOADING_SCRIPT_NAME
@@ -1149,7 +1168,9 @@ def test_load_dataset_builder_for_relative_script_dir(dataset_loading_script_dir
 
 def test_load_dataset_builder_for_script_path(dataset_loading_script_dir, data_dir):
     builder = datasets.load_dataset_builder(
-        os.path.join(dataset_loading_script_dir, DATASET_LOADING_SCRIPT_NAME + ".py"), data_dir=data_dir
+        os.path.join(dataset_loading_script_dir, DATASET_LOADING_SCRIPT_NAME + ".py"),
+        data_dir=data_dir,
+        trust_remote_code=True,
     )
     assert isinstance(builder, DatasetBuilder)
     assert builder.name == DATASET_LOADING_SCRIPT_NAME
@@ -1198,7 +1219,7 @@ def test_load_dataset_builder_for_community_dataset_with_script():
 @pytest.mark.integration
 def test_load_dataset_builder_for_community_dataset_with_script_no_parquet_export():
     with patch.object(config, "USE_PARQUET_EXPORT", False):
-        builder = datasets.load_dataset_builder(SAMPLE_DATASET_IDENTIFIER)
+        builder = datasets.load_dataset_builder(SAMPLE_DATASET_IDENTIFIER, trust_remote_code=True)
     assert isinstance(builder, DatasetBuilder)
     assert builder.name == SAMPLE_DATASET_IDENTIFIER.split("/")[-1]
     assert builder.dataset_name == SAMPLE_DATASET_IDENTIFIER.split("/")[-1]
@@ -1241,7 +1262,9 @@ def test_load_dataset_builder_fail():
 @pytest.mark.parametrize("keep_in_memory", [False, True])
 def test_load_dataset_local_script(dataset_loading_script_dir, data_dir, keep_in_memory, caplog):
     with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
-        dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, keep_in_memory=keep_in_memory)
+        dataset = load_dataset(
+            dataset_loading_script_dir, data_dir=data_dir, keep_in_memory=keep_in_memory, trust_remote_code=True
+        )
     assert isinstance(dataset, DatasetDict)
     assert all(isinstance(d, Dataset) for d in dataset.values())
     assert len(dataset) == 2
@@ -1249,7 +1272,7 @@ def test_load_dataset_local_script(dataset_loading_script_dir, data_dir, keep_in
 
 
 def test_load_dataset_cached_local_script(dataset_loading_script_dir, data_dir, caplog):
-    dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir)
+    dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, trust_remote_code=True)
     assert isinstance(dataset, DatasetDict)
     assert all(isinstance(d, Dataset) for d in dataset.values())
     assert len(dataset) == 2
@@ -1304,7 +1327,7 @@ def test_load_dataset_cached_from_hub(stream_from_cache, caplog):
 
 
 def test_load_dataset_streaming(dataset_loading_script_dir, data_dir):
-    dataset = load_dataset(dataset_loading_script_dir, streaming=True, data_dir=data_dir)
+    dataset = load_dataset(dataset_loading_script_dir, streaming=True, data_dir=data_dir, trust_remote_code=True)
     assert isinstance(dataset, IterableDatasetDict)
     assert all(isinstance(d, IterableDataset) for d in dataset.values())
     assert len(dataset) == 2
@@ -1586,7 +1609,9 @@ def test_load_dataset_private_zipped_images(hf_private_dataset_repo_zipped_img_d
 def test_load_dataset_then_move_then_reload(dataset_loading_script_dir, data_dir, tmp_path, caplog):
     cache_dir1 = tmp_path / "cache1"
     cache_dir2 = tmp_path / "cache2"
-    dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, split="train", cache_dir=cache_dir1)
+    dataset = load_dataset(
+        dataset_loading_script_dir, data_dir=data_dir, split="train", cache_dir=cache_dir1, trust_remote_code=True
+    )
     fingerprint1 = dataset._fingerprint
     del dataset
     os.rename(cache_dir1, cache_dir2)
@@ -1614,7 +1639,9 @@ def test_load_dataset_builder_then_edit_then_load_again(tmp_path: Path):
 def test_load_dataset_readonly(dataset_loading_script_dir, dataset_loading_script_dir_readonly, data_dir, tmp_path):
     cache_dir1 = tmp_path / "cache1"
     cache_dir2 = tmp_path / "cache2"
-    dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, split="train", cache_dir=cache_dir1)
+    dataset = load_dataset(
+        dataset_loading_script_dir, data_dir=data_dir, split="train", cache_dir=cache_dir1, trust_remote_code=True
+    )
     fingerprint1 = dataset._fingerprint
     del dataset
     # Load readonly dataset and check that the fingerprint is the same.
@@ -1637,7 +1664,7 @@ def test_load_dataset_local_with_default_in_memory(
         expected_in_memory = False
 
     with assert_arrow_memory_increases() if expected_in_memory else assert_arrow_memory_doesnt_increase():
-        dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir)
+        dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, trust_remote_code=True)
     assert (dataset["train"].dataset_size < max_in_memory_dataset_size) is expected_in_memory
@@ -1655,7 +1682,7 @@ def test_load_from_disk_with_default_in_memory(
     else:
         expected_in_memory = False
 
-    dset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, keep_in_memory=True)
+    dset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, keep_in_memory=True, trust_remote_code=True)
     dataset_path = os.path.join(tmp_path, "saved_dataset")
     dset.save_to_disk(dataset_path)
@@ -1745,19 +1772,13 @@ def test_load_dataset_without_script_with_zip(zip_csv_path):
     assert ds["train"][0] == {"col_1": 0, "col_2": 0, "col_3": 0.0}
 
 
-@pytest.mark.parametrize("trust_remote_code, expected", [(False, False), (True, True), (None, True)])
-def test_resolve_trust_remote_code(trust_remote_code, expected):
-    assert resolve_trust_remote_code(trust_remote_code, repo_id="dummy") is expected
-
-
 @pytest.mark.parametrize("trust_remote_code, expected", [(False, False), (True, True), (None, ValueError)])
 def test_resolve_trust_remote_code_future(trust_remote_code, expected):
-    with patch.object(config, "HF_DATASETS_TRUST_REMOTE_CODE", None):  # this will be the default soon
-        if isinstance(expected, bool):
-            resolve_trust_remote_code(trust_remote_code, repo_id="dummy") is expected
-        else:
-            with pytest.raises(expected):
-                resolve_trust_remote_code(trust_remote_code, repo_id="dummy")
+    if isinstance(expected, bool):
+        resolve_trust_remote_code(trust_remote_code, repo_id="dummy") is expected
+    else:
+        with pytest.raises(expected):
+            resolve_trust_remote_code(trust_remote_code, repo_id="dummy")
 
 
 @pytest.mark.integration
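The reworked `test_resolve_trust_remote_code_future` pins down the new contract: `False` and `True` pass through unchanged, while `None` is now an error by default. A behavioral sketch, not the actual implementation (the real function can also prompt interactively when stdin is a TTY):

    from typing import Optional

    def resolve_trust_remote_code_sketch(trust_remote_code: Optional[bool], repo_id: str) -> bool:
        if trust_remote_code is None:
            # Previously this fell back to True; with the "ask" default it raises instead.
            raise ValueError(
                f"The repository for {repo_id} contains custom code which must be executed. "
                "Pass `trust_remote_code=True` to allow it."
            )
        return trust_remote_code
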
diff --git a/tests/test_metric_common.py b/tests/test_metric_common.py
index 21832efc4a9..2a7407e7a43 100644
--- a/tests/test_metric_common.py
+++ b/tests/test_metric_common.py
@@ -99,7 +99,9 @@ class LocalMetricTest(parameterized.TestCase):
     def test_load_metric(self, metric_name):
         doctest.ELLIPSIS_MARKER = "[...]"
         metric_module = importlib.import_module(
-            datasets.load.metric_module_factory(os.path.join("metrics", metric_name)).module_path
+            datasets.load.metric_module_factory(
+                os.path.join("metrics", metric_name), trust_remote_code=True
+            ).module_path
         )
         metric = datasets.load.import_main_class(metric_module.__name__, dataset=False)
         # check parameters
@@ -213,7 +215,7 @@ def predict(self, data, *args, **kwargs):
 
 
 def test_seqeval_raises_when_incorrect_scheme():
-    metric = load_metric(os.path.join("metrics", "seqeval"))
+    metric = load_metric(os.path.join("metrics", "seqeval"), trust_remote_code=True)
     wrong_scheme = "ERROR"
     error_message = f"Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {wrong_scheme}"
     with pytest.raises(ValueError, match=re.escape(error_message)):

diff --git a/tests/test_warnings.py b/tests/test_warnings.py
index eedcbb82ae4..62028507ef4 100644
--- a/tests/test_warnings.py
+++ b/tests/test_warnings.py
@@ -25,10 +25,15 @@ def list_metrics(self):
 
 
 @pytest.mark.parametrize(
-    "func, args", [(load_metric, ("metrics/mse",)), (list_metrics, ()), (inspect_metric, ("metrics/mse", "tmp_path"))]
+    "func, args, kwargs",
+    [
+        (load_metric, ("metrics/mse",), {"trust_remote_code": True}),
+        (list_metrics, (), {}),
+        (inspect_metric, ("metrics/mse", "tmp_path"), {"trust_remote_code": True}),
+    ],
 )
-def test_metric_deprecation_warning(func, args, mock_emitted_deprecation_warnings, mock_hfh, tmp_path):
+def test_metric_deprecation_warning(func, args, kwargs, mock_emitted_deprecation_warnings, mock_hfh, tmp_path):
     if "tmp_path" in args:
         args = tuple(arg if arg != "tmp_path" else tmp_path for arg in args)
     with pytest.warns(FutureWarning, match="https://huggingface.co/docs/evaluate"):
-        func(*args)
+        func(*args, **kwargs)
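Net effect for users, sketched with one of the script-based datasets exercised in the tests above: running loading code from the Hub is no longer implicit.

    from datasets import load_dataset

    # Before this patch: the repo's loading script ran by default.
    # After: omitting trust_remote_code raises a ValueError (or prompts
    # interactively), so the opt-in has to be spelled out.
    ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", trust_remote_code=True)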