Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Jul 18, 2024
1 parent a1a06bc commit 7865b4a
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions tests/utils/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ def setUpClass(cls):
cls._token = TOKEN
HfFolder.save_token(TOKEN)

@staticmethod
def _try_delete_repo(repo_id, token):
try:
# Reset repo
delete_repo(repo_id=repo_id, token=token)
except: # noqa E722
pass

@classmethod
def tearDownClass(cls):
try:
Expand Down Expand Up @@ -125,13 +133,19 @@ def test_push_to_hub(self):
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)

def test_push_to_hub_via_save_pretrained(self):

try:
# Reset repo
delete_repo(token=self._token, repo_id=tmp_repo)
except: # noqa E722
pass
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-config-{Path(tmp_dir).name}"

config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
# Push to hub via save_pretrained
config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)

Expand All @@ -140,13 +154,8 @@ def test_push_to_hub(self):
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always try to delete
try:
# Reset repo
delete_repo(token=self._token, repo_id=tmp_repo)
except: # noqa E722
pass

# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)

def test_push_to_hub_in_organization(self):
config = BertConfig(
Expand Down

0 comments on commit 7865b4a

Please sign in to comment.