-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ad200ae
commit 5ad21de
Showing
3 changed files
with
48 additions
and
174 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,124 +1,50 @@ | ||
import click | ||
from GANDLF.entrypoints import append_copyright_to_help | ||
from GANDLF.cli.huggingface_hub_handler import push_to_model_hub, download_from_hub | ||
from GANDLF.cli import copyrightMessage | ||
from argparse import ArgumentParser, RawTextHelpFormatter | ||
|
||
|
||
@click.command() | ||
@click.option( | ||
"--upload/--download", | ||
"-u/-d", | ||
required=True, | ||
help="Upload or download to/from a Huggingface Repo", | ||
) | ||
@click.option( | ||
"--repo-id", | ||
"-rid", | ||
required=True, | ||
help="Downloading/Uploading: A user or an organization name and a repo name separated by a /", | ||
) | ||
@click.option( | ||
"--token", | ||
"-tk", | ||
help="Downloading/Uploading: A token to be used for the download/upload", | ||
) | ||
@click.option( | ||
"--revision", | ||
"-rv", | ||
help="Downloading/Uploading: git revision id which can be a branch name, a tag, or a commit hash", | ||
) | ||
@click.option( | ||
"--cache-dir", | ||
"-cdir", | ||
help="Downloading: path to the folder where cached files are stored", | ||
type=click.Path(exists=True, file_okay=False, dir_okay=True), | ||
) | ||
@click.option( | ||
"--local-dir", | ||
"-ldir", | ||
help="Downloading: if provided, the downloaded file will be placed under this directory", | ||
type=click.Path(exists=True, file_okay=False, dir_okay=True), | ||
) | ||
@click.option( | ||
"--force-download", | ||
"-fd", | ||
is_flag=True, | ||
help="Downloading: Whether the file should be downloaded even if it already exists in the local cache", | ||
) | ||
@click.option( | ||
"--folder-path", | ||
"-fp", | ||
help="Uploading: Path to the folder to upload on the local file system", | ||
type=click.Path(exists=True, file_okay=False, dir_okay=True), | ||
) | ||
@click.option( | ||
"--path-in-repo", | ||
"-pir", | ||
help="Uploading: Relative path of the directory in the repo. Will default to the root folder of the repository", | ||
) | ||
@click.option( | ||
"--commit-message", | ||
"-cr", | ||
help='Uploading: The summary / title / first line of the generated commit. Defaults to: f"Upload {path_in_repo} with huggingface_hub"', | ||
) | ||
@click.option( | ||
"--commit-description", | ||
"-cd", | ||
help="Uploading: The description of the generated commit", | ||
) | ||
@click.option( | ||
"--repo-type", | ||
"-rt", | ||
help='Uploading: Set to "dataset" or "space" if uploading to a dataset or space, "model" if uploading to a model. Default is model', | ||
) | ||
@click.option( | ||
"--allow-patterns", | ||
"-ap", | ||
help="Uploading: If provided, only files matching at least one pattern are uploaded.", | ||
) | ||
@click.option( | ||
"--ignore-patterns", | ||
"-ip", | ||
help="Uploading: If provided, files matching any of the patterns are not uploaded.", | ||
) | ||
@click.option( | ||
"--delete-patterns", | ||
"-dp", | ||
help="Uploading: If provided, remote files matching any of the patterns will be deleted from the repo while committing new files. This is useful if you don't know which files have already been uploaded.", | ||
from huggingface_hub.commands.delete_cache import DeleteCacheCommand | ||
from huggingface_hub.commands.download import DownloadCommand | ||
from huggingface_hub.commands.env import EnvironmentCommand | ||
from huggingface_hub.commands.lfs import LfsCommands | ||
from huggingface_hub.commands.scan_cache import ScanCacheCommand | ||
from huggingface_hub.commands.tag import TagCommands | ||
from huggingface_hub.commands.upload import UploadCommand | ||
from huggingface_hub.commands.user import UserCommands | ||
|
||
description = """Hugging Face Hub: Streamline model management with upload, download, and more\n\n""" | ||
|
||
|
||
@click.command( | ||
context_settings=dict(ignore_unknown_options=True), add_help_option=False | ||
) | ||
@append_copyright_to_help | ||
def new_way( | ||
upload: bool, | ||
repo_id: str, | ||
token: str, | ||
revision: str, | ||
cache_dir: str, | ||
local_dir: str, | ||
force_download: bool, | ||
folder_path: str, | ||
path_in_repo: str, | ||
commit_message: str, | ||
commit_description: str, | ||
repo_type: str, | ||
allow_patterns: str, | ||
ignore_patterns: str, | ||
delete_patterns: str, | ||
): | ||
"""Manages model transfers to and from the Hugging Face Hub""" | ||
if upload: | ||
push_to_model_hub( | ||
repo_id, | ||
folder_path, | ||
path_in_repo, | ||
commit_message, | ||
commit_description, | ||
token, | ||
repo_type, | ||
revision, | ||
allow_patterns, | ||
ignore_patterns, | ||
delete_patterns, | ||
) | ||
else: | ||
download_from_hub( | ||
repo_id, revision, cache_dir, local_dir, force_download, token | ||
) | ||
@click.argument("args", nargs=-1, type=click.UNPROCESSED) | ||
def new_way(args): | ||
"""Hugging Face Hub: Streamline model management with upload, download, and more""" | ||
|
||
parser = ArgumentParser( | ||
"gandlf hf", | ||
usage="gandlf hf <command> [<args>]", | ||
description=description + copyrightMessage, | ||
formatter_class=RawTextHelpFormatter, | ||
) | ||
|
||
commands_parser = parser.add_subparsers(help="gandlf hf command helpers") | ||
|
||
EnvironmentCommand.register_subcommand(commands_parser) | ||
UserCommands.register_subcommand(commands_parser) | ||
UploadCommand.register_subcommand(commands_parser) | ||
DownloadCommand.register_subcommand(commands_parser) | ||
LfsCommands.register_subcommand(commands_parser) | ||
ScanCacheCommand.register_subcommand(commands_parser) | ||
DeleteCacheCommand.register_subcommand(commands_parser) | ||
TagCommands.register_subcommand(commands_parser) | ||
|
||
args = parser.parse_args(args) | ||
|
||
if not hasattr(args, "func"): | ||
parser.print_help() | ||
exit(1) | ||
|
||
service = args.func(args) | ||
service.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters