diff --git a/.gitignore b/.gitignore index 149eca463..8c8b2456a 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,7 @@ nr_*/ /docs/make.bat /docs/Makefile /examples/training/quora_duplicate_questions/quora-IR-dataset/ -build \ No newline at end of file +build + +htmlcov +.coverage \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 3196b6c2a..4dd678748 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ -transformers>=4.6.0,<5.0.0 -tokenizers>=0.10.3 +transformers>=4.32.0,<5.0.0 tqdm torch>=1.6.0 numpy @@ -7,5 +6,5 @@ scikit-learn scipy nltk sentencepiece -huggingface-hub +huggingface-hub>=0.15.1 Pillow \ No newline at end of file diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index ad642b853..cd645caf7 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -9,7 +9,7 @@ import numpy as np from numpy import ndarray import transformers -from huggingface_hub import HfApi, HfFolder, Repository +from huggingface_hub import HfApi import torch from torch import nn, Tensor, device from torch.optim import Optimizer @@ -468,8 +468,9 @@ def _create_model_card(self, path: str, model_name: Optional[str] = None, train_ fOut.write(model_card.strip()) def save_to_hub(self, - repo_name: str, + repo_id: str, organization: Optional[str] = None, + token: Optional[str] = None, private: Optional[bool] = None, commit_message: str = "Add new SentenceTransformer model.", local_model_path: Optional[str] = None, @@ -479,90 +480,61 @@ def save_to_hub(self, """ Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository. - :param repo_name: Repository name for your model in the Hub. - :param organization: Organization in which you want to push your model or tokenizer (you must be a member of this organization). + :param repo_id: Repository name for your model in the Hub, including the user or organization. + :param token: An authentication token (See https://huggingface.co/settings/token) :param private: Set to true, for hosting a prive model :param commit_message: Message to commit while pushing. :param local_model_path: Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded :param exist_ok: If true, saving to an existing repository is OK. If false, saving only to a new repository is possible :param replace_model_card: If true, replace an existing model card in the hub with the automatically created model card :param train_datasets: Datasets used to train the model. If set, the datasets will be added to the model card in the Hub. - :return: The url of the commit of your model in the given repository. - """ - token = HfFolder.get_token() - if token is None: - raise ValueError("You must login to the Hugging Face hub on this computer by typing `transformers-cli login`.") + :param organization: Deprecated. Organization in which you want to push your model or tokenizer (you must be a member of this organization). - if '/' in repo_name: - splits = repo_name.split('/', maxsplit=1) - if organization is None or organization == splits[0]: - organization = splits[0] - repo_name = splits[1] + :return: The url of the commit of your model in the repository on the Hugging Face Hub. + """ + if organization: + if "/" not in repo_id: + logger.warning( + f"Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id=\"{organization}/{repo_id}\"` instead." + ) + repo_id = f"{organization}/{repo_id}" + elif repo_id.split("/")[0] != organization: + raise ValueError("Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`.") else: - raise ValueError("You passed and invalid repository name: {}.".format(repo_name)) + logger.warning( + f"Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id=\"{repo_id}\"` instead." + ) - endpoint = "https://huggingface.co" - repo_id = repo_name - if organization: - repo_id = f"{organization}/{repo_id}" - repo_url = HfApi(endpoint=endpoint).create_repo( + api = HfApi(token=token) + repo_url = api.create_repo( + repo_id=repo_id, + private=private, + repo_type=None, + exist_ok=exist_ok, + ) + if local_model_path: + folder_url = api.upload_folder( repo_id=repo_id, - token=token, - private=private, - repo_type=None, - exist_ok=exist_ok, + folder_path=local_model_path, + commit_message=commit_message ) - full_model_name = repo_url[len(endpoint)+1:].strip("/") - - with tempfile.TemporaryDirectory() as tmp_dir: - # First create the repo (and clone its content if it's nonempty). - logger.info("Create repository and clone it if it exists") - repo = Repository(tmp_dir, clone_from=repo_url) - - # If user provides local files, copy them. - if local_model_path: - copy_tree(local_model_path, tmp_dir) - else: # Else, save model directly into local repo. + else: + with tempfile.TemporaryDirectory() as tmp_dir: create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, 'README.md')) - self.save(tmp_dir, model_name=full_model_name, create_model_card=create_model_card, train_datasets=train_datasets) - - #Find files larger 5M and track with git-lfs - large_files = [] - for root, dirs, files in os.walk(tmp_dir): - for filename in files: - file_path = os.path.join(root, filename) - rel_path = os.path.relpath(file_path, tmp_dir) - - if os.path.getsize(file_path) > (5 * 1024 * 1024): - large_files.append(rel_path) - - if len(large_files) > 0: - logger.info("Track files with git lfs: {}".format(", ".join(large_files))) - repo.lfs_track(large_files) - - logger.info("Push model to the hub. This might take a while") - push_return = repo.push_to_hub(commit_message=commit_message) - - def on_rm_error(func, path, exc_info): - # path contains the path of the file that couldn't be removed - # let's just assume that it's read-only and unlink it. - try: - os.chmod(path, stat.S_IWRITE) - os.unlink(path) - except: - pass - - # Remove .git folder. On Windows, the .git folder might be read-only and cannot be deleted - # Hence, try to set write permissions on error - try: - for f in os.listdir(tmp_dir): - shutil.rmtree(os.path.join(tmp_dir, f), onerror=on_rm_error) - except Exception as e: - logger.warning("Error when deleting temp folder: {}".format(str(e))) - pass + self.save(tmp_dir, model_name=repo_url.repo_id, create_model_card=create_model_card, train_datasets=train_datasets) + folder_url = api.upload_folder( + repo_id=repo_id, + folder_path=tmp_dir, + commit_message=commit_message + ) + refs = api.list_repo_refs(repo_id=repo_id) + for branch in refs.branches: + if branch.name == "main": + return f"https://huggingface.co/{repo_id}/commit/{branch.target_commit}" + # This isn't expected to ever be reached. + return folder_url - return push_return def smart_batching_collate(self, batch): """ diff --git a/setup.py b/setup.py index baff91e65..34a507450 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ packages=find_packages(), python_requires=">=3.8.0", install_requires=[ - 'transformers>=4.6.0,<5.0.0', + 'transformers>=4.32.0,<5.0.0', 'tqdm', 'torch>=1.6.0', 'numpy', @@ -27,7 +27,7 @@ 'scipy', 'nltk', 'sentencepiece', - 'huggingface-hub>=0.4.0', + 'huggingface-hub>=0.15.1', 'Pillow' ], classifiers=[ diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index e0d3acf7a..a05ba1cba 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -3,64 +3,132 @@ """ +import logging from pathlib import Path import tempfile +import pytest +from huggingface_hub import HfApi, RepoUrl, GitRefs, GitRefInfo import torch from sentence_transformers import SentenceTransformer from sentence_transformers.models import Transformer, Pooling -import unittest - - -class TestSentenceTransformer(unittest.TestCase): - def test_load_with_safetensors(self): - with tempfile.TemporaryDirectory() as cache_folder: - safetensors_model = SentenceTransformer( - "sentence-transformers-testing/stsb-bert-tiny-safetensors", - cache_folder=cache_folder, - ) - - # Only the safetensors file must be loaded - pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin")) - self.assertEqual(0, len(pytorch_files), msg="PyTorch model file must not be downloaded.") - safetensors_files = list(Path(cache_folder).glob("**/model.safetensors")) - self.assertEqual(1, len(safetensors_files), msg="Safetensors model file must be downloaded.") - - with tempfile.TemporaryDirectory() as cache_folder: - transformer = Transformer( - "sentence-transformers-testing/stsb-bert-tiny-safetensors", - cache_dir=cache_folder, - model_args={"use_safetensors": False}, - ) - pooling = Pooling(transformer.get_word_embedding_dimension()) - pytorch_model = SentenceTransformer(modules=[transformer, pooling]) - - # Only the pytorch file must be loaded - pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin")) - self.assertEqual(1, len(pytorch_files), msg="PyTorch model file must be downloaded.") - safetensors_files = list(Path(cache_folder).glob("**/model.safetensors")) - self.assertEqual(0, len(safetensors_files), msg="Safetensors model file must not be downloaded.") - - sentences = ["This is a test sentence", "This is another test sentence"] - self.assertTrue( - torch.equal(safetensors_model.encode(sentences, convert_to_tensor=True), pytorch_model.encode(sentences, convert_to_tensor=True)), - msg="Ensure that Safetensors and PyTorch loaded models result in identical embeddings", + + +def test_load_with_safetensors() -> None: + with tempfile.TemporaryDirectory() as cache_folder: + safetensors_model = SentenceTransformer( + "sentence-transformers-testing/stsb-bert-tiny-safetensors", + cache_folder=cache_folder, + ) + + # Only the safetensors file must be loaded + pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin")) + assert 0 == len(pytorch_files), "PyTorch model file must not be downloaded." + safetensors_files = list(Path(cache_folder).glob("**/model.safetensors")) + assert 1 == len(safetensors_files), "Safetensors model file must be downloaded." + + with tempfile.TemporaryDirectory() as cache_folder: + transformer = Transformer( + "sentence-transformers-testing/stsb-bert-tiny-safetensors", + cache_dir=cache_folder, + model_args={"use_safetensors": False}, ) + pooling = Pooling(transformer.get_word_embedding_dimension()) + pytorch_model = SentenceTransformer(modules=[transformer, pooling]) + + # Only the pytorch file must be loaded + pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin")) + assert 1 == len(pytorch_files), "PyTorch model file must be downloaded." + safetensors_files = list(Path(cache_folder).glob("**/model.safetensors")) + assert 0 == len(safetensors_files), "Safetensors model file must not be downloaded." + + sentences = ["This is a test sentence", "This is another test sentence"] + assert torch.equal( + safetensors_model.encode(sentences, convert_to_tensor=True), + pytorch_model.encode(sentences, convert_to_tensor=True), + ), "Ensure that Safetensors and PyTorch loaded models result in identical embeddings" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.") +def test_to() -> None: + model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", device="cpu") + + test_device = torch.device("cuda") + assert model.device.type == "cpu" + assert test_device.type == "cuda" + + model.to(test_device) + assert model.device.type == "cuda", "The model device should have updated" - @unittest.skipUnless(torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.") - def test_to(self): - model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", device="cpu") + model.encode("Test sentence") + assert model.device.type == "cuda", "Encoding shouldn't change the device" - test_device = torch.device("cuda") - self.assertEqual(model.device.type, "cpu") - self.assertEqual(test_device.type, "cuda") + assert model._target_device == model.device, "Prevent backwards compatibility failure for _target_device" + model._target_device = "cpu" + assert model.device.type == "cpu", "Ensure that setting `_target_device` doesn't crash." - model.to(test_device) - self.assertEqual(model.device.type, "cuda", msg="The model device should have updated") - model.encode("Test sentence") - self.assertEqual(model.device.type, "cuda", msg="Encoding shouldn't change the device") +def test_save_to_hub(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None: + def mock_create_repo(self, repo_id, **kwargs): + return RepoUrl(f"https://huggingface.co/{repo_id}") + + mock_upload_folder_kwargs = {} + + def mock_upload_folder(self, **kwargs): + nonlocal mock_upload_folder_kwargs + mock_upload_folder_kwargs = kwargs + + def mock_list_repo_refs(self, repo_id=None, **kwargs): + try: + git_ref_info = GitRefInfo(name="main", ref="refs/heads/main", target_commit="123456") + except TypeError: + git_ref_info = GitRefInfo(dict(name="main", ref="refs/heads/main", targetCommit="123456")) + return GitRefs(branches=[git_ref_info], converts=[], tags=[]) + + monkeypatch.setattr(HfApi, "create_repo", mock_create_repo) + monkeypatch.setattr(HfApi, "upload_folder", mock_upload_folder) + monkeypatch.setattr(HfApi, "list_repo_refs", mock_list_repo_refs) + + model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors") + url = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors") + assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors" + assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456" + mock_upload_folder_kwargs.clear() + + with pytest.raises( + ValueError, match="Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`." + ): + model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="unrelated") + + caplog.clear() + with caplog.at_level(logging.WARNING): + url = model.save_to_hub( + "sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="sentence-transformers-testing" + ) + assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors" + assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456" + assert len(caplog.record_tuples) == 1 + assert ( + caplog.record_tuples[0][2] + == 'Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.' + ) + mock_upload_folder_kwargs.clear() + + caplog.clear() + with caplog.at_level(logging.WARNING): + url = model.save_to_hub("stsb-bert-tiny-safetensors", organization="sentence-transformers-testing") + assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors" + assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456" + assert len(caplog.record_tuples) == 1 + assert ( + caplog.record_tuples[0][2] + == 'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.' + ) + mock_upload_folder_kwargs.clear() - self.assertEqual(model._target_device, model.device, msg="Prevent backwards compatibility failure for _target_device") - model._target_device = "cpu" - self.assertEqual(model.device.type, "cpu", msg="Ensure that setting `_target_device` doesn't crash.") \ No newline at end of file + url = model.save_to_hub( + "sentence-transformers-testing/stsb-bert-tiny-safetensors", local_model_path="my_fake_local_model_path" + ) + assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors" + assert mock_upload_folder_kwargs["folder_path"] == "my_fake_local_model_path" + assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"