From f984cdcf98d8ddc8b95b603779d206fbff9d76cf Mon Sep 17 00:00:00 2001 From: Swapnil Jikar <112884653+WizKnight@users.noreply.github.com> Date: Thu, 26 Sep 2024 20:04:25 +0530 Subject: [PATCH] Feature: switch visibility with update_repo_settings #2537 (#2541) * Enhance `update_repo_settings` to manage repo visibility * Enhance `update_repo_settings` to manage repo visibility * Enhance `update_repo_settings` to manage repo visibility * Enhance `update_repo_settings` to manage repo visibility * Enhance `update_repo_settings` to manage repo visibility * Apply suggestions from code review --------- Co-authored-by: Lucain --- docs/source/en/guides/repository.md | 4 +-- src/huggingface_hub/hf_api.py | 42 ++++++++++++++++++++++------- tests/test_file_download.py | 1 + tests/test_hf_api.py | 36 ++++++++++++------------- tests/test_snapshot_download.py | 3 ++- 5 files changed, 55 insertions(+), 31 deletions(-) diff --git a/docs/source/en/guides/repository.md b/docs/source/en/guides/repository.md index 77bf797e61..adfc881298 100644 --- a/docs/source/en/guides/repository.md +++ b/docs/source/en/guides/repository.md @@ -151,8 +151,8 @@ Some settings are specific to Spaces (hardware, environment variables,...). To c A repository can be public or private. A private repository is only visible to you or members of the organization in which the repository is located. Change a repository to private as shown in the following: ```py ->>> from huggingface_hub import update_repo_visibility ->>> update_repo_visibility(repo_id=repo_id, private=True) +>>> from huggingface_hub import update_repo_settings +>>> update_repo_settings(repo_id=repo_id, private=True) ``` ### Setup gated access diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index ace2505162..98d556de42 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -3524,6 +3524,7 @@ def delete_repo( if not missing_ok: raise + @_deprecate_method(version="0.29", message="Please use `update_repo_settings` instead.") @validate_hf_hub_args def update_repo_visibility( self, @@ -3535,6 +3536,8 @@ def update_repo_visibility( ) -> Dict[str, bool]: """Update the visibility setting of a repository. + Deprecated. Use `update_repo_settings` instead. + Args: repo_id (`str`, *optional*): A namespace (user or an organization) and a repo name separated by a `/`. @@ -3581,29 +3584,34 @@ def update_repo_settings( self, repo_id: str, *, - gated: Literal["auto", "manual", False] = False, + gated: Optional[Literal["auto", "manual", False]] = None, + private: Optional[bool] = None, token: Union[str, bool, None] = None, repo_type: Optional[str] = None, ) -> None: """ - Update the gated settings of a repository. - To give more control over how repos are used, the Hub allows repo authors to enable **access requests** for their repos. + Update the settings of a repository, including gated access and visibility. + + To give more control over how repos are used, the Hub allows repo authors to enable + access requests for their repos, and also to set the visibility of the repo to private. Args: repo_id (`str`): A namespace (user or an organization) and a repo name separated by a /. gated (`Literal["auto", "manual", False]`, *optional*): - The gated release status for the repository. + The gated status for the repository. If set to `None` (default), the `gated` setting of the repository won't be updated. * "auto": The repository is gated, and access requests are automatically approved or denied based on predefined criteria. * "manual": The repository is gated, and access requests require manual approval. - * False (default): The repository is not gated, and anyone can access it. + * False : The repository is not gated, and anyone can access it. + private (`bool`, *optional*): + Whether the model repo should be private. token (`Union[str, bool, None]`, *optional*): A valid user access token (string). Defaults to the locally saved token, which is the recommended method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication). To disable authentication, pass False. repo_type (`str`, *optional*): - The type of the repository to update settings from (`"model"`, `"dataset"` or `"space"`. + The type of the repository to update settings from (`"model"`, `"dataset"` or `"space"`). Defaults to `"model"`. Raises: @@ -3613,22 +3621,38 @@ def update_repo_settings( If repo_type is not one of the values in constants.REPO_TYPES. [`~utils.HfHubHTTPError`]: If the request to the Hugging Face Hub API fails. + [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. """ - if gated not in ["auto", "manual", False]: - raise ValueError(f"Invalid gated status, must be one of 'auto', 'manual', or False. Got '{gated}'.") if repo_type not in constants.REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: repo_type = constants.REPO_TYPE_MODEL # default repo type + # Check if both gated and private are None + if gated is None and private is None: + raise ValueError("At least one of 'gated' or 'private' must be provided.") + # Build headers headers = self._build_hf_headers(token=token) + # Prepare the JSON payload for the PUT request + payload: Dict = {} + + if gated is not None: + if gated not in ["auto", "manual", False]: + raise ValueError(f"Invalid gated status, must be one of 'auto', 'manual', or False. Got '{gated}'.") + payload["gated"] = gated + + if private is not None: + payload["private"] = private + r = get_session().put( url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/settings", headers=headers, - json={"gated": gated}, + json=payload, ) hf_raise_for_status(r) diff --git a/tests/test_file_download.py b/tests/test_file_download.py index 3496f82349..9e265f4b5d 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -161,6 +161,7 @@ def test_download_from_a_gated_repo_with_hf_hub_download(self, repo_url: RepoUrl repo_id=repo_url.repo_id, filename=".gitattributes", token=OTHER_TOKEN, cache_dir=tmpdir ) + @expect_deprecation("update_repo_visibility") @use_tmp_repo() def test_download_regular_file_from_private_renamed_repo(self, repo_url: RepoUrl) -> None: """Regression test for #1999. diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index d5acf70ccc..a028066847 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -20,7 +20,6 @@ import types import unittest import uuid -import warnings from collections.abc import Iterable from concurrent.futures import Future from dataclasses import fields @@ -93,6 +92,7 @@ DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT, ENDPOINT_PRODUCTION, SAMPLE_DATASET_IDENTIFIER, + expect_deprecation, repo_name, require_git_lfs, rmtree_with_retry, @@ -124,18 +124,6 @@ def setUpClass(cls): cls._api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) -def test_repo_id_no_warning(): - # tests that passing repo_id as positional arg doesn't raise any warnings - # for {create, delete}_repo and update_repo_visibility - api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) - - with warnings.catch_warnings(record=True) as record: - repo_id = api.create_repo(repo_name()).repo_id - api.update_repo_visibility(repo_id, private=True) - api.delete_repo(repo_id) - assert not len(record) - - class HfApiRepoFileExistsTest(HfApiCommonTest): def setUp(self) -> None: super().setUp() @@ -210,6 +198,7 @@ def test_delete_repo_error_message(self): def test_delete_repo_missing_ok(self) -> None: self._api.delete_repo("repo-that-does-not-exist", missing_ok=True) + @expect_deprecation("update_repo_visibility") def test_create_update_and_delete_repo(self): repo_id = self._api.create_repo(repo_id=repo_name()).repo_id res = self._api.update_repo_visibility(repo_id=repo_id, private=True) @@ -218,6 +207,7 @@ def test_create_update_and_delete_repo(self): assert not res["private"] self._api.delete_repo(repo_id=repo_id) + @expect_deprecation("update_repo_visibility") def test_create_update_and_delete_model_repo(self): repo_id = self._api.create_repo(repo_id=repo_name(), repo_type=constants.REPO_TYPE_MODEL).repo_id res = self._api.update_repo_visibility(repo_id=repo_id, private=True, repo_type=constants.REPO_TYPE_MODEL) @@ -226,6 +216,7 @@ def test_create_update_and_delete_model_repo(self): assert not res["private"] self._api.delete_repo(repo_id=repo_id, repo_type=constants.REPO_TYPE_MODEL) + @expect_deprecation("update_repo_visibility") def test_create_update_and_delete_dataset_repo(self): repo_id = self._api.create_repo(repo_id=repo_name(), repo_type=constants.REPO_TYPE_DATASET).repo_id res = self._api.update_repo_visibility(repo_id=repo_id, private=True, repo_type=constants.REPO_TYPE_DATASET) @@ -234,6 +225,7 @@ def test_create_update_and_delete_dataset_repo(self): assert not res["private"] self._api.delete_repo(repo_id=repo_id, repo_type=constants.REPO_TYPE_DATASET) + @expect_deprecation("update_repo_visibility") def test_create_update_and_delete_space_repo(self): with pytest.raises(ValueError, match=r"No space_sdk provided.*"): self._api.create_repo(repo_id=repo_name(), repo_type=constants.REPO_TYPE_SPACE, space_sdk=None) @@ -286,9 +278,11 @@ def test_update_repo_settings(self, repo_url: RepoUrl): repo_id = repo_url.repo_id for gated_value in ["auto", "manual", False]: - self._api.update_repo_settings(repo_id=repo_id, gated=gated_value) - info = self._api.model_info(repo_id, expand="gated") - assert info.gated == gated_value + for private_value in [True, False]: # Test both private and public settings + self._api.update_repo_settings(repo_id=repo_id, gated=gated_value, private=private_value) + info = self._api.model_info(repo_id) + assert info.gated == gated_value + assert info.private == private_value # Verify the private setting @use_tmp_repo(repo_type="dataset") def test_update_dataset_repo_settings(self, repo_url: RepoUrl): @@ -296,9 +290,13 @@ def test_update_dataset_repo_settings(self, repo_url: RepoUrl): repo_type = repo_url.repo_type for gated_value in ["auto", "manual", False]: - self._api.update_repo_settings(repo_id=repo_id, repo_type=repo_type, gated=gated_value) - info = self._api.dataset_info(repo_id, expand="gated") - assert info.gated == gated_value + for private_value in [True, False]: + self._api.update_repo_settings( + repo_id=repo_id, repo_type=repo_type, gated=gated_value, private=private_value + ) + info = self._api.dataset_info(repo_id) + assert info.gated == gated_value + assert info.private == private_value class CommitApiTest(HfApiCommonTest): diff --git a/tests/test_snapshot_download.py b/tests/test_snapshot_download.py index 17f4c561b3..641e7980ab 100644 --- a/tests/test_snapshot_download.py +++ b/tests/test_snapshot_download.py @@ -8,7 +8,7 @@ from huggingface_hub.utils import SoftTemporaryDirectory from .testing_constants import TOKEN -from .testing_utils import OfflineSimulationMode, offline, repo_name +from .testing_utils import OfflineSimulationMode, expect_deprecation, offline, repo_name class SnapshotDownloadTests(unittest.TestCase): @@ -95,6 +95,7 @@ def test_download_model(self): # folder name contains the revision's commit sha. self.assertTrue(self.first_commit_hash in storage_folder) + @expect_deprecation("update_repo_visibility") def test_download_private_model(self): self.api.update_repo_visibility(repo_id=self.repo_id, private=True)