Skip to content

Commit

Permalink
refactor: add warning that --ignore-patterns is ignored
Browse files Browse the repository at this point in the history
  • Loading branch information
KeijiBranshi committed Nov 19, 2024
1 parent c2dee3a commit f91607b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
18 changes: 18 additions & 0 deletions tests/torchtune/_cli/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,24 @@ def test_download_from_kaggle_warn_when_output_dir_provided(
output = capsys.readouterr().out
assert "Successfully downloaded model repo" in output

def test_download_from_kaggle_warn_when_ignore_patterns_provided(
self, capsys, monkeypatch, mocker, tmpdir
):
model = "metaresearch/llama-3.2/pytorch/1b"
testargs = f'tune download {model} --source kaggle --ignore-patterns "*.glob-pattern"'.split()
monkeypatch.setattr(sys, "argv", testargs)
# mock out kagglehub.model_download to get around key storage
mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir)

with pytest.warns(
UserWarning,
match="--ignore-patterns flag is not supported for Kaggle model downloads",
):
runpy.run_path(TUNE_PATH, run_name="__main__")

output = capsys.readouterr().out
assert "Successfully downloaded model repo" in output

# tests when --kaggle-username and --kaggle-api-key are provided as CLI args
def test_download_from_kaggle_when_credentials_provided(
self, capsys, monkeypatch, mocker
Expand Down
7 changes: 6 additions & 1 deletion torchtune/_cli/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,16 @@ def _download_from_kaggle(self, args: argparse.Namespace) -> None:
self._set_kaggle_credentials(args)

# kagglehub doesn't currently support `local_dir` and `ignore_patterns` like huggingface_hub
if args.output_dir is not None:
if args.output_dir:
warn(
"--output-dir flag is not supported for Kaggle model downloads. "
"This argument will be ignored."
)
if args.ignore_patterns:
warn(
"--ignore-patterns flag is not supported for Kaggle model downloads. "
"This argument will be ignored."
)

try:
output_dir = model_download(model_handle)
Expand Down

0 comments on commit f91607b

Please sign in to comment.