From 546d6c0e3c32d41c49f4951a4ede553b62d75ba5 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 18 Sep 2023 17:24:26 +0800 Subject: [PATCH] fix #6833 Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 113 +++++++++++++++++++++++++--------- tests/test_bundle_get_data.py | 47 +++++++++++++- 2 files changed, 130 insertions(+), 30 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 6b34627a6a..65d58d0ad9 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -541,11 +541,24 @@ def load( return model +@deprecated_arg_default( + "model_info_url", + None, + "https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json", + since="1.3", + replaced="1.5", +) def _get_all_bundles_info( - repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None + repo: str = "Project-MONAI/model-zoo", + tag: str = "hosting_storage_v1", + auth_token: str | None = None, + model_info_url: str | None = None, ) -> dict[str, dict[str, dict[str, Any]]]: if has_requests: - request_url = f"https://api.github.com/repos/{repo}/releases" + if model_info_url is not None: + request_url = model_info_url + else: + request_url = f"https://api.github.com/repos/{repo}/releases" if auth_token is not None: headers = {"Authorization": f"Bearer {auth_token}"} resp = requests_get(request_url, headers=headers) @@ -558,33 +571,56 @@ def _get_all_bundles_info( bundle_name_pattern = re.compile(r"_v\d*.") bundles_info: dict[str, dict[str, dict[str, Any]]] = {} - for release in releases_list: - if release["tag_name"] == tag: - for asset in release["assets"]: - asset_name = bundle_name_pattern.split(asset["name"])[0] - if asset_name not in bundles_info: - bundles_info[asset_name] = {} - asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "") - bundles_info[asset_name][asset_version] = { - "id": asset["id"], - "name": asset["name"], - "size": asset["size"], - "download_count": asset["download_count"], - "browser_download_url": asset["browser_download_url"], - "created_at": asset["created_at"], - "updated_at": asset["updated_at"], - } - return bundles_info + if model_info_url is not None: + for asset in releases_list.keys(): + asset_name = bundle_name_pattern.split(asset)[0] + if asset_name not in bundles_info: + bundles_info[asset_name] = {} + asset_version = asset.split(f"{asset_name}_v")[-1] + bundles_info[asset_name][asset_version] = { + "name": asset, + "browser_download_url": releases_list[asset]["source"], + } + return bundles_info + else: + for release in releases_list: + if release["tag_name"] == tag: + for asset in release["assets"]: + asset_name = bundle_name_pattern.split(asset["name"])[0] + if asset_name not in bundles_info: + bundles_info[asset_name] = {} + asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "") + bundles_info[asset_name][asset_version] = { + "id": asset["id"], + "name": asset["name"], + "size": asset["size"], + "download_count": asset["download_count"], + "browser_download_url": asset["browser_download_url"], + "created_at": asset["created_at"], + "updated_at": asset["updated_at"], + } + return bundles_info return bundles_info +@deprecated_arg_default( + "model_info_url", + None, + "https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json", + since="1.3", + replaced="1.5", +) def get_all_bundles_list( - repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None + repo: str = "Project-MONAI/model-zoo", + tag: str = "hosting_storage_v1", + auth_token: str | None = None, + model_info_url: str | None = None, ) -> list[tuple[str, str]]: """ Get all bundles names (and the latest versions) that are stored in the release of specified repository - with the provided tag. The default values of arguments correspond to the release of MONAI model zoo. - In order to increase the rate limits of calling Github APIs, you can input your personal access token. + with the provided tag or listed in `model_info_url`. The default values of arguments correspond to the + release of MONAI model zoo. In order to increase the rate limits of calling Github APIs, you can input + your personal access token. Please check the following link for more details about rate limiting: https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting @@ -595,13 +631,14 @@ def get_all_bundles_list( repo: it should be in the form of "repo_owner/repo_name/". tag: the tag name of the release. auth_token: github personal access token. + model_info_url: a JSON file link containing all of the model information. Returns: a list of tuple in the form of (bundle name, latest version). """ - bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token) + bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token, model_info_url=model_info_url) bundles_list = [] for bundle_name in bundles_info: latest_version = sorted(bundles_info[bundle_name].keys())[-1] @@ -610,15 +647,23 @@ def get_all_bundles_list( return bundles_list +@deprecated_arg_default( + "model_info_url", + None, + "https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json", + since="1.3", + replaced="1.5", +) def get_bundle_versions( bundle_name: str, repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None, + model_info_url: str | None = None, ) -> dict[str, list[str] | str]: """ Get the latest version, as well as all existing versions of a bundle that is stored in the release of specified - repository with the provided tag. + repository with the provided tag or listed in `model_info_url`. In order to increase the rate limits of calling Github APIs, you can input your personal access token. Please check the following link for more details about rate limiting: https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting @@ -631,13 +676,14 @@ def get_bundle_versions( repo: it should be in the form of "repo_owner/repo_name/". tag: the tag name of the release. auth_token: github personal access token. + model_info_url: a JSON file link containing all of the model information. Returns: a dictionary that contains the latest version and all versions of a bundle. """ - bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token) + bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token, model_info_url=model_info_url) if bundle_name not in bundles_info: raise ValueError(f"bundle: {bundle_name} is not existing in repo: {repo}.") bundle_info = bundles_info[bundle_name] @@ -646,17 +692,27 @@ def get_bundle_versions( return {"latest_version": all_versions[-1], "all_versions": all_versions} +@deprecated_arg_default( + "model_info_url", + None, + "https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json", + since="1.3", + replaced="1.5", +) def get_bundle_info( bundle_name: str, version: str | None = None, repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: str | None = None, + model_info_url: str | None = None, ) -> dict[str, Any]: """ Get all information (include "id", "name", "size", "download_count", "browser_download_url", "created_at", "updated_at") of a bundle - with the specified bundle name and version. + with the specified bundle name and version which is stored in the release of specified repository with the provided tag. + Since v1.5, it has been deprecated in favor of'model_info_url', which contains only "name" and "browser_download_url" + information about a bundle. In order to increase the rate limits of calling Github APIs, you can input your personal access token. Please check the following link for more details about rate limiting: https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting @@ -670,13 +726,14 @@ def get_bundle_info( repo: it should be in the form of "repo_owner/repo_name/". tag: the tag name of the release. auth_token: github personal access token. + model_info_url: a JSON file link containing all of the model information. Returns: a dictionary that contains the bundle's information. """ - bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token) + bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token, model_info_url=model_info_url) if bundle_name not in bundles_info: raise ValueError(f"bundle: {bundle_name} is not existing.") bundle_info = bundles_info[bundle_name] @@ -685,7 +742,7 @@ def get_bundle_info( if version not in bundle_info: raise ValueError(f"version: {version} of bundle: {bundle_name} is not existing.") - return bundle_info[version] + return bundle_info[version] # type: ignore[no-any-return] @deprecated_arg("runner_id", since="1.1", removed="1.3", new_name="run_id", msg_suffix="please use `run_id` instead.") diff --git a/tests/test_bundle_get_data.py b/tests/test_bundle_get_data.py index a560f3945f..f13963cdca 100644 --- a/tests/test_bundle_get_data.py +++ b/tests/test_bundle_get_data.py @@ -25,7 +25,16 @@ TEST_CASE_2 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": None}] -TEST_CASE_FAKE_TOKEN = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}] +TEST_CASE_FAKE_TOKEN_1 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}] + +TEST_CASE_FAKE_TOKEN_2 = [ + { + "bundle_name": "spleen_ct_segmentation", + "version": "0.1.0", + "auth_token": "ghp_errortoken", + "model_info_url": "https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json", + } +] @skip_if_windows @@ -39,6 +48,16 @@ def test_get_all_bundles_list(self): self.assertTrue(isinstance(output[0], tuple)) self.assertTrue(len(output[0]) == 2) + @skip_if_quick + def test_get_all_bundles_list_model_info_url(self): + with skip_if_downloading_fails(): + output = get_all_bundles_list( + model_info_url="https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json" + ) + self.assertTrue(isinstance(output, list)) + self.assertTrue(isinstance(output[0], tuple)) + self.assertTrue(len(output[0]) == 2) + @parameterized.expand([TEST_CASE_1]) @skip_if_quick def test_get_bundle_versions(self, params): @@ -48,6 +67,18 @@ def test_get_bundle_versions(self, params): self.assertTrue("latest_version" in output and "all_versions" in output) self.assertTrue("0.1.0" in output["all_versions"]) + @parameterized.expand([TEST_CASE_1]) + @skip_if_quick + def test_get_bundle_versions_model_info_url(self, params): + with skip_if_downloading_fails(): + output = get_bundle_versions( + model_info_url="https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json", + **params, + ) + self.assertTrue(isinstance(output, dict)) + self.assertTrue("latest_version" in output and "all_versions" in output) + self.assertTrue("0.1.0" in output["all_versions"]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @skip_if_quick def test_get_bundle_info(self, params): @@ -57,7 +88,19 @@ def test_get_bundle_info(self, params): for key in ["id", "name", "size", "download_count", "browser_download_url"]: self.assertTrue(key in output) - @parameterized.expand([TEST_CASE_FAKE_TOKEN]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @skip_if_quick + def test_get_bundle_info_model_info_url(self, params): + with skip_if_downloading_fails(): + output = get_bundle_info( + model_info_url="https://raw.githubusercontent.com/Project-MONAI/model-zoo/dev/models/model_info.json", + **params, + ) + self.assertTrue(isinstance(output, dict)) + for key in ["name", "browser_download_url"]: + self.assertTrue(key in output) + + @parameterized.expand([TEST_CASE_FAKE_TOKEN_1, TEST_CASE_FAKE_TOKEN_2]) @skip_if_quick def test_fake_token(self, params): with skip_if_downloading_fails():