Skip to content

Commit

Permalink
Simplify state_dict test expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Feb 13, 2024
1 parent c18c7f8 commit 9a53df5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions tests/test_adaption_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_save_pretrained_regression(self) -> None:
assert state_dict.keys() == state_dict_from_pretrained.keys()

# Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
assert len(list(state_dict.keys())) == 4
assert len(state_dict) == 4

# check if tensors equal
for key in state_dict.keys():
Expand Down Expand Up @@ -180,7 +180,7 @@ def test_save_pretrained(self) -> None:
assert state_dict.keys() == state_dict_from_pretrained.keys()

# Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
assert len(list(state_dict.keys())) == 4
assert len(state_dict) == 4

# check if tensors equal
for key in state_dict.keys():
Expand Down Expand Up @@ -228,7 +228,7 @@ def test_save_pretrained_selected_adapters(self) -> None:
assert state_dict.keys() == state_dict_from_pretrained.keys()

# Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
assert len(list(state_dict.keys())) == 4
assert len(state_dict) == 4

# check if tensors equal
for key in state_dict.keys():
Expand Down
4 changes: 2 additions & 2 deletions tests/test_multitask_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_save_pretrained(self) -> None:
assert state_dict.keys() == state_dict_from_pretrained.keys()

# Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
assert len(list(state_dict.keys())) == 3
assert len(state_dict) == 3

# check if tensors equal
for key in state_dict.keys():
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_save_pretrained_regression(self) -> None:
assert state_dict.keys() == state_dict_from_pretrained.keys()

# Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
assert len(list(state_dict.keys())) == 3
assert len(state_dict) == 3

# check if tensors equal
for key in state_dict.keys():
Expand Down

0 comments on commit 9a53df5

Please sign in to comment.