From 5ad21ded48025df51dd0ad865b470fb98cc7c535 Mon Sep 17 00:00:00 2001 From: "pranaya.singh" Date: Wed, 19 Jun 2024 14:32:52 +0530 Subject: [PATCH] added huggingface cli --- GANDLF/cli/huggingface_hub_handler.py | 52 ------- GANDLF/entrypoints/hf_hub_integration.py | 166 +++++++---------------- GANDLF/entrypoints/subcommands.py | 4 +- 3 files changed, 48 insertions(+), 174 deletions(-) delete mode 100644 GANDLF/cli/huggingface_hub_handler.py diff --git a/GANDLF/cli/huggingface_hub_handler.py b/GANDLF/cli/huggingface_hub_handler.py deleted file mode 100644 index 4bdc0a5bb..000000000 --- a/GANDLF/cli/huggingface_hub_handler.py +++ /dev/null @@ -1,52 +0,0 @@ -from huggingface_hub import HfApi, snapshot_download -from typing import List, Union - - -def push_to_model_hub( - repo_id: str, - folder_path: str, - path_in_repo: Union[str, None] = None, - commit_message: Union[str, None] = None, - commit_description: Union[str, None] = None, - token: Union[str, None] = None, - repo_type: Union[str, None] = None, - revision: Union[str, None] = None, - allow_patterns: Union[List[str], str, None] = None, - ignore_patterns: Union[List[str], str, None] = None, - delete_patterns: Union[List[str], str, None] = None, -): - api = HfApi(token=token) - - api.create_repo(repo_id, exist_ok=True) - - api.upload_folder( - repo_id=repo_id, - token=token, - folder_path=folder_path, - path_in_repo=path_in_repo, - commit_message=commit_message, - commit_description=commit_description, - repo_type=repo_type, - revision=revision, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - delete_patterns=delete_patterns, - ) - - -def download_from_hub( - repo_id: str, - revision: Union[str, None] = None, - cache_dir: Union[str, None] = None, - local_dir: Union[str, None] = None, - force_download: bool = False, - token: Union[str, None] = None, -): - snapshot_download( - repo_id=repo_id, - revision=revision, - cache_dir=cache_dir, - local_dir=local_dir, - force_download=force_download, - token=token, - ) diff --git a/GANDLF/entrypoints/hf_hub_integration.py b/GANDLF/entrypoints/hf_hub_integration.py index 4c923c15d..d5a3b3c93 100644 --- a/GANDLF/entrypoints/hf_hub_integration.py +++ b/GANDLF/entrypoints/hf_hub_integration.py @@ -1,124 +1,50 @@ import click -from GANDLF.entrypoints import append_copyright_to_help -from GANDLF.cli.huggingface_hub_handler import push_to_model_hub, download_from_hub +from GANDLF.cli import copyrightMessage +from argparse import ArgumentParser, RawTextHelpFormatter -@click.command() -@click.option( - "--upload/--download", - "-u/-d", - required=True, - help="Upload or download to/from a Huggingface Repo", -) -@click.option( - "--repo-id", - "-rid", - required=True, - help="Downloading/Uploading: A user or an organization name and a repo name separated by a /", -) -@click.option( - "--token", - "-tk", - help="Downloading/Uploading: A token to be used for the download/upload", -) -@click.option( - "--revision", - "-rv", - help="Downloading/Uploading: git revision id which can be a branch name, a tag, or a commit hash", -) -@click.option( - "--cache-dir", - "-cdir", - help="Downloading: path to the folder where cached files are stored", - type=click.Path(exists=True, file_okay=False, dir_okay=True), -) -@click.option( - "--local-dir", - "-ldir", - help="Downloading: if provided, the downloaded file will be placed under this directory", - type=click.Path(exists=True, file_okay=False, dir_okay=True), -) -@click.option( - "--force-download", - "-fd", - is_flag=True, - help="Downloading: Whether the file should be downloaded even if it already exists in the local cache", -) -@click.option( - "--folder-path", - "-fp", - help="Uploading: Path to the folder to upload on the local file system", - type=click.Path(exists=True, file_okay=False, dir_okay=True), -) -@click.option( - "--path-in-repo", - "-pir", - help="Uploading: Relative path of the directory in the repo. Will default to the root folder of the repository", -) -@click.option( - "--commit-message", - "-cr", - help='Uploading: The summary / title / first line of the generated commit. Defaults to: f"Upload {path_in_repo} with huggingface_hub"', -) -@click.option( - "--commit-description", - "-cd", - help="Uploading: The description of the generated commit", -) -@click.option( - "--repo-type", - "-rt", - help='Uploading: Set to "dataset" or "space" if uploading to a dataset or space, "model" if uploading to a model. Default is model', -) -@click.option( - "--allow-patterns", - "-ap", - help="Uploading: If provided, only files matching at least one pattern are uploaded.", -) -@click.option( - "--ignore-patterns", - "-ip", - help="Uploading: If provided, files matching any of the patterns are not uploaded.", -) -@click.option( - "--delete-patterns", - "-dp", - help="Uploading: If provided, remote files matching any of the patterns will be deleted from the repo while committing new files. This is useful if you don't know which files have already been uploaded.", +from huggingface_hub.commands.delete_cache import DeleteCacheCommand +from huggingface_hub.commands.download import DownloadCommand +from huggingface_hub.commands.env import EnvironmentCommand +from huggingface_hub.commands.lfs import LfsCommands +from huggingface_hub.commands.scan_cache import ScanCacheCommand +from huggingface_hub.commands.tag import TagCommands +from huggingface_hub.commands.upload import UploadCommand +from huggingface_hub.commands.user import UserCommands + +description = """Hugging Face Hub: Streamline model management with upload, download, and more\n\n""" + + +@click.command( + context_settings=dict(ignore_unknown_options=True), add_help_option=False ) -@append_copyright_to_help -def new_way( - upload: bool, - repo_id: str, - token: str, - revision: str, - cache_dir: str, - local_dir: str, - force_download: bool, - folder_path: str, - path_in_repo: str, - commit_message: str, - commit_description: str, - repo_type: str, - allow_patterns: str, - ignore_patterns: str, - delete_patterns: str, -): - """Manages model transfers to and from the Hugging Face Hub""" - if upload: - push_to_model_hub( - repo_id, - folder_path, - path_in_repo, - commit_message, - commit_description, - token, - repo_type, - revision, - allow_patterns, - ignore_patterns, - delete_patterns, - ) - else: - download_from_hub( - repo_id, revision, cache_dir, local_dir, force_download, token - ) +@click.argument("args", nargs=-1, type=click.UNPROCESSED) +def new_way(args): + """Hugging Face Hub: Streamline model management with upload, download, and more""" + + parser = ArgumentParser( + "gandlf hf", + usage="gandlf hf []", + description=description + copyrightMessage, + formatter_class=RawTextHelpFormatter, + ) + + commands_parser = parser.add_subparsers(help="gandlf hf command helpers") + + EnvironmentCommand.register_subcommand(commands_parser) + UserCommands.register_subcommand(commands_parser) + UploadCommand.register_subcommand(commands_parser) + DownloadCommand.register_subcommand(commands_parser) + LfsCommands.register_subcommand(commands_parser) + ScanCacheCommand.register_subcommand(commands_parser) + DeleteCacheCommand.register_subcommand(commands_parser) + TagCommands.register_subcommand(commands_parser) + + args = parser.parse_args(args) + + if not hasattr(args, "func"): + parser.print_help() + exit(1) + + service = args.func(args) + service.run() diff --git a/GANDLF/entrypoints/subcommands.py b/GANDLF/entrypoints/subcommands.py index f4c6aed75..789cbe263 100644 --- a/GANDLF/entrypoints/subcommands.py +++ b/GANDLF/entrypoints/subcommands.py @@ -12,7 +12,7 @@ from GANDLF.entrypoints.generate_metrics import new_way as generate_metrics_command from GANDLF.entrypoints.debug_info import new_way as debug_info_command from GANDLF.entrypoints.split_csv import new_way as split_csv_command -from GANDLF.entrypoints.hf_hub_integration import new_way as hf +from GANDLF.entrypoints.hf_hub_integration import new_way as hf_command cli_subcommands = { "anonymizer": anonymizer_command, @@ -29,5 +29,5 @@ "generate-metrics": generate_metrics_command, "debug-info": debug_info_command, "split-csv": split_csv_command, - "hf": hf, + "hf": hf_command, }