Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed error message for GatedRepoError #1832

Merged
merged 3 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions tests/torchtune/_cli/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,44 @@ def test_download_calls_snapshot(self, capsys, monkeypatch, snapshot_download):

# Make sure it was called twice
assert snapshot_download.call_count == 3

# GatedRepoError without --hf-token (expect prompt for token)
def test_gated_repo_error_no_token(self, capsys, monkeypatch, snapshot_download):
model = "meta-llama/Llama-2-7b"
testargs = f"tune download {model}".split()
monkeypatch.setattr(sys, "argv", testargs)

# Expect GatedRepoError without --hf-token provided
with pytest.raises(SystemExit, match="2"):
runpy.run_path(TUNE_PATH, run_name="__main__")

out_err = capsys.readouterr()
# Check that error message prompts for --hf-token
assert (
"It looks like you are trying to access a gated repository." in out_err.err
)
assert (
"Please ensure you have access to the repository and have provided the proper Hugging Face API token"
in out_err.err
)

# GatedRepoError with --hf-token (should not ask for token)
def test_gated_repo_error_with_token(self, capsys, monkeypatch, snapshot_download):
model = "meta-llama/Llama-2-7b"
testargs = f"tune download {model} --hf-token valid_token".split()
monkeypatch.setattr(sys, "argv", testargs)

# Expect GatedRepoError with --hf-token provided
with pytest.raises(SystemExit, match="2"):
runpy.run_path(TUNE_PATH, run_name="__main__")

out_err = capsys.readouterr()
# Check that error message does not prompt for --hf-token again
assert (
"It looks like you are trying to access a gated repository." in out_err.err
)
assert "Please ensure you have access to the repository." in out_err.err
assert (
"Please ensure you have access to the repository and have provided the proper Hugging Face API token"
not in out_err.err
)
18 changes: 12 additions & 6 deletions torchtune/_cli/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,18 @@ def _download_cmd(self, args: argparse.Namespace) -> None:
token=args.hf_token,
)
except GatedRepoError:
self._parser.error(
"It looks like you are trying to access a gated repository. Please ensure you "
"have access to the repository and have provided the proper Hugging Face API token "
"using the option `--hf-token` or by running `huggingface-cli login`."
"You can find your token by visiting https://huggingface.co/settings/tokens"
)
if args.hf_token:
self._parser.error(
"It looks like you are trying to access a gated repository. Please ensure you "
"have access to the repository."
)
else:
self._parser.error(
"It looks like you are trying to access a gated repository. Please ensure you "
"have access to the repository and have provided the proper Hugging Face API token "
"using the option `--hf-token` or by running `huggingface-cli login`."
"You can find your token by visiting https://huggingface.co/settings/tokens"
)
except RepositoryNotFoundError:
self._parser.error(
f"Repository '{args.repo_id}' not found on the Hugging Face Hub."
Expand Down
Loading