Skip to content

Commit

Permalink
Prevent backward compatibility breaking: return commit link
Browse files Browse the repository at this point in the history
Thanks @Wauplin for the help
  • Loading branch information
tomaarsen committed Dec 13, 2023
1 parent 1d03b12 commit e3eb4ab
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 26 deletions.
28 changes: 18 additions & 10 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def save_to_hub(self,
:param train_datasets: Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.
:param organization: Deprecated. Organization in which you want to push your model or tokenizer (you must be a member of this organization).
:return: The URL to visualize the uploaded model on the Hugging Face hub.
:return: The URL of the commit of your model in the repository on the Hugging Face Hub.
"""
if organization:
if "/" not in repo_id:
Expand All @@ -495,20 +495,28 @@ def save_to_hub(self,
exist_ok=exist_ok,
)
if local_model_path:
return api.upload_folder(
folder_url = api.upload_folder(
repo_id=repo_id,
folder_path=local_model_path,
commit_message=commit_message
)
else:
with tempfile.TemporaryDirectory() as tmp_dir:
create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, 'README.md'))
self.save(tmp_dir, model_name=repo_url.repo_id, create_model_card=create_model_card, train_datasets=train_datasets)
folder_url = api.upload_folder(
repo_id=repo_id,
folder_path=tmp_dir,
commit_message=commit_message
)

refs = api.list_repo_refs(repo_id=repo_id)
for branch in refs.branches:
if branch.name == "main":
return f"https://huggingface.co/{repo_id}/commit/{branch.target_commit}"
# Fallback: this isn't expected to ever be reached, as the loop above should find the "main" branch.
return folder_url

with tempfile.TemporaryDirectory() as tmp_dir:
create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, 'README.md'))
self.save(tmp_dir, model_name=repo_url.repo_id, create_model_card=create_model_card, train_datasets=train_datasets)
return api.upload_folder(
repo_id=repo_id,
folder_path=tmp_dir,
commit_message=commit_message
)

def smart_batching_collate(self, batch):
"""
Expand Down
63 changes: 47 additions & 16 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import tempfile
import pytest

from huggingface_hub import HfApi, RepoUrl
from huggingface_hub import HfApi, RepoUrl, GitRefs, GitRefInfo
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling
Expand Down Expand Up @@ -67,37 +67,68 @@ def test_to() -> None:
model._target_device = "cpu"
assert model.device.type == "cpu", "Ensure that setting `_target_device` doesn't crash."


def test_save_to_hub(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
def mock_create_repo(self, repo_id, **kwargs):
return RepoUrl(f"https://huggingface.co/{repo_id}")

mock_upload_folder_kwargs = {}

def mock_upload_folder(self, **kwargs):
return kwargs
nonlocal mock_upload_folder_kwargs
mock_upload_folder_kwargs = kwargs

def mock_list_repo_refs(self, repo_id=None, **kwargs):
try:
git_ref_info = GitRefInfo(name="main", ref="refs/heads/main", target_commit="123456")
except TypeError:
git_ref_info = GitRefInfo(dict(name="main", ref="refs/heads/main", targetCommit="123456"))
return GitRefs(branches=[git_ref_info], converts=[], tags=[])

monkeypatch.setattr(HfApi, "create_repo", mock_create_repo)
monkeypatch.setattr(HfApi, "upload_folder", mock_upload_folder)
monkeypatch.setattr(HfApi, "list_repo_refs", mock_list_repo_refs)

model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
kwargs = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors")
assert kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"

with pytest.raises(ValueError, match="Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`."):
url = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors")
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
mock_upload_folder_kwargs.clear()

with pytest.raises(
ValueError, match="Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`."
):
model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="unrelated")

caplog.clear()
with caplog.at_level(logging.WARNING):
kwargs = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="sentence-transformers-testing")
assert kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
url = model.save_to_hub(
"sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="sentence-transformers-testing"
)
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
assert len(caplog.record_tuples) == 1
assert caplog.record_tuples[0][2] == "Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id=\"sentence-transformers-testing/stsb-bert-tiny-safetensors\"` instead."
assert (
caplog.record_tuples[0][2]
== 'Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.'
)
mock_upload_folder_kwargs.clear()

caplog.clear()
with caplog.at_level(logging.WARNING):
kwargs = model.save_to_hub("stsb-bert-tiny-safetensors", organization="sentence-transformers-testing")
assert kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
url = model.save_to_hub("stsb-bert-tiny-safetensors", organization="sentence-transformers-testing")
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
assert len(caplog.record_tuples) == 1
assert caplog.record_tuples[0][2] == "Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id=\"sentence-transformers-testing/stsb-bert-tiny-safetensors\"` instead."

kwargs = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors", local_model_path="my_fake_local_model_path")
assert kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert kwargs["folder_path"] == "my_fake_local_model_path"
assert (
caplog.record_tuples[0][2]
== 'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.'
)
mock_upload_folder_kwargs.clear()

url = model.save_to_hub(
"sentence-transformers-testing/stsb-bert-tiny-safetensors", local_model_path="my_fake_local_model_path"
)
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert mock_upload_folder_kwargs["folder_path"] == "my_fake_local_model_path"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"

0 comments on commit e3eb4ab

Please sign in to comment.