Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove default trust_remote_code=True #6954

Merged
merged 6 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/datasets/commands/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from argparse import ArgumentParser
from pathlib import Path
from shutil import copyfile, rmtree
from typing import Generator
from typing import Generator, Optional

import datasets.config
from datasets.builder import DatasetBuilder
Expand All @@ -29,6 +29,7 @@ def _test_command_factory(args):
args.force_redownload,
args.clear_cache,
args.num_proc,
args.trust_remote_code,
)


Expand Down Expand Up @@ -67,6 +68,9 @@ def register_subcommand(parser: ArgumentParser):
help="Remove downloaded files and cached datasets after each config test",
)
test_parser.add_argument("--num_proc", type=int, default=None, help="Number of processes")
test_parser.add_argument(
"--trust_remote_code", action="store_true", help="whether to trust the code execution of the load script"
)
# aliases
test_parser.add_argument("--save_infos", action="store_true", help="alias to save_info")
test_parser.add_argument("dataset", type=str, help="Name of the dataset to download")
Expand All @@ -84,6 +88,7 @@ def __init__(
force_redownload: bool,
clear_cache: bool,
num_proc: int,
trust_remote_code: Optional[bool],
):
self._dataset = dataset
self._name = name
Expand All @@ -95,6 +100,7 @@ def __init__(
self._force_redownload = force_redownload
self._clear_cache = clear_cache
self._num_proc = num_proc
self._trust_remote_code = trust_remote_code
if clear_cache and not cache_dir:
print(
"When --clear_cache is used, specifying a cache directory is mandatory.\n"
Expand All @@ -111,7 +117,7 @@ def run(self):
print("Both parameters `config` and `all_configs` can't be used at once.")
exit(1)
path, config_name = self._dataset, self._name
module = dataset_module_factory(path)
module = dataset_module_factory(path, trust_remote_code=self._trust_remote_code)
builder_cls = import_main_class(module.module_path)
n_builders = len(builder_cls.BUILDER_CONFIGS) if self._all_configs and builder_cls.BUILDER_CONFIGS else 1

Expand Down
2 changes: 1 addition & 1 deletion src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@
HF_DATASETS_MULTITHREADING_MAX_WORKERS = 16

# Remote dataset scripts support
__HF_DATASETS_TRUST_REMOTE_CODE = os.environ.get("HF_DATASETS_TRUST_REMOTE_CODE", "1")
__HF_DATASETS_TRUST_REMOTE_CODE = os.environ.get("HF_DATASETS_TRUST_REMOTE_CODE", "ask")
HF_DATASETS_TRUST_REMOTE_CODE: Optional[bool] = (
True
if __HF_DATASETS_TRUST_REMOTE_CODE.upper() in ENV_VARS_TRUE_VALUES
Expand Down
5 changes: 4 additions & 1 deletion src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -2007,7 +2007,10 @@ def metric_module_factory(
raise FileNotFoundError(f"Couldn't find a metric script at {relative_to_absolute_path(path)}")
elif os.path.isfile(combined_path):
return LocalMetricModuleFactory(
combined_path, download_mode=download_mode, dynamic_modules_path=dynamic_modules_path
combined_path,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
trust_remote_code=trust_remote_code,
).get_module()
elif is_relative_path(path) and path.count("/") == 0:
try:
Expand Down
7 changes: 5 additions & 2 deletions tests/commands/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
"force_redownload",
"clear_cache",
"num_proc",
"trust_remote_code",
],
defaults=[None, None, None, False, False, False, False, False, None],
defaults=[None, None, None, False, False, False, False, False, None, None],
)


Expand All @@ -32,7 +33,9 @@ def is_1percent_close(source, target):

@pytest.mark.integration
def test_test_command(dataset_loading_script_dir):
args = _TestCommandArgs(dataset=dataset_loading_script_dir, all_configs=True, save_infos=True)
args = _TestCommandArgs(
dataset=dataset_loading_script_dir, all_configs=True, save_infos=True, trust_remote_code=True
)
test_command = TestCommand(*args)
test_command.run()
dataset_readme_path = os.path.join(dataset_loading_script_dir, "README.md")
Expand Down
4 changes: 2 additions & 2 deletions tests/features/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,9 +604,9 @@ def test_load_dataset_with_audio_feature(streaming, jsonl_audio_dataset_path, sh
@pytest.mark.integration
def test_dataset_with_audio_feature_loaded_from_cache():
# load first time
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", trust_remote_code=True)
# load from cache
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", trust_remote_code=True, split="validation")
assert isinstance(ds, Dataset)


Expand Down
4 changes: 3 additions & 1 deletion tests/features/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,9 @@ def test_load_dataset_with_image_feature(shared_datadir, data_dir, dataset_loadi
import PIL.Image

image_path = str(shared_datadir / "test_image_rgb.jpg")
dset = load_dataset(dataset_loading_script_dir, split="train", data_dir=data_dir, streaming=streaming)
dset = load_dataset(
dataset_loading_script_dir, split="train", data_dir=data_dir, streaming=streaming, trust_remote_code=True
)
item = dset[0] if not streaming else next(iter(dset))
assert item.keys() == {"image", "caption"}
assert isinstance(item["image"], PIL.Image.Image)
Expand Down
9 changes: 7 additions & 2 deletions tests/test_hf_gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def test_dataset_info_available(self, dataset, config_name, revision):
config_name,
revision=revision,
cache_dir=tmp_dir,
trust_remote_code=True,
)

dataset_info_url = "/".join(
Expand All @@ -88,7 +89,7 @@ def test_dataset_info_available(self, dataset, config_name, revision):
@pytest.mark.integration
def test_as_dataset_from_hf_gcs(tmp_path_factory):
tmp_dir = tmp_path_factory.mktemp("test_hf_gcp") / "test_wikipedia_simple"
builder = load_dataset_builder("wikipedia", "20220301.frr", cache_dir=tmp_dir)
builder = load_dataset_builder("wikipedia", "20220301.frr", cache_dir=tmp_dir, trust_remote_code=True)
# use the HF cloud storage, not the original download_and_prepare that uses apache-beam
builder._download_and_prepare = None
builder.download_and_prepare(try_from_hf_gcs=True)
Expand All @@ -99,7 +100,11 @@ def test_as_dataset_from_hf_gcs(tmp_path_factory):
@pytest.mark.integration
def test_as_streaming_dataset_from_hf_gcs(tmp_path):
builder = load_dataset_builder(
"wikipedia", "20220301.frr", revision="4d013bdd32c475c8536aae00a56efc774f061649", cache_dir=tmp_path
"wikipedia",
"20220301.frr",
revision="4d013bdd32c475c8536aae00a56efc774f061649",
cache_dir=tmp_path,
trust_remote_code=True,
)
ds = builder.as_streaming_dataset()
assert ds
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_
with patch.object(datasets.hub.HfApi, "create_commit", return_value=commit_info) as mock_create_commit:
with patch.object(datasets.hub.HfApi, "create_branch") as mock_create_branch:
with patch.object(datasets.hub.HfApi, "list_repo_tree", return_value=[]): # not needed
_ = convert_to_parquet(repo_id, token=hf_token)
_ = convert_to_parquet(repo_id, token=hf_token, trust_remote_code=True)
# mock_create_branch
assert mock_create_branch.called
assert mock_create_branch.call_count == 2
Expand Down
6 changes: 3 additions & 3 deletions tests/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_inspect_dataset(path, tmp_path):
@pytest.mark.filterwarnings("ignore:metric_module_factory is deprecated:FutureWarning")
@pytest.mark.parametrize("path", ["accuracy"])
def test_inspect_metric(path, tmp_path):
inspect_metric(path, tmp_path)
inspect_metric(path, tmp_path, trust_remote_code=True)
script_name = path + ".py"
assert script_name in os.listdir(tmp_path)
assert "__pycache__" not in os.listdir(tmp_path)
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_get_dataset_config_info_error(path, config_name, expected_exception):
],
)
def test_get_dataset_config_names(path, expected):
config_names = get_dataset_config_names(path)
config_names = get_dataset_config_names(path, trust_remote_code=True)
assert config_names == expected


Expand All @@ -97,7 +97,7 @@ def test_get_dataset_config_names(path, expected):
],
)
def test_get_dataset_default_config_name(path, expected):
default_config_name = get_dataset_default_config_name(path)
default_config_name = get_dataset_default_config_name(path, trust_remote_code=True)
if expected:
assert default_config_name == expected
else:
Expand Down
Loading
Loading