Skip to content

Commit

Permalink
revise convert base model's unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 committed Apr 6, 2023
1 parent ba3c437 commit 94fd637
Showing 1 changed file with 9 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ def check_state_dict(s1, s2):
assert (s1[k] == s2[k]).all()


def parameters(model):
state_dict = model.state_dict()
for v in state_dict.values():
yield v


def test_change_base_model():
control_state_dict = make_state_dict(dict(k1=1, k2=2, k3=3))
target_control_state_dict = make_state_dict(dict(k1=1.5, k2=2.5, k3=3))
Expand All @@ -33,7 +39,9 @@ def test_change_base_model():
controlnet = MagicMock()
basemodel = MagicMock()
currmodel = MagicMock()

controlnet.parameters = MagicMock(return_value=parameters(controlnet))
basemodel.parameters = MagicMock(return_value=parameters(basemodel))
currmodel.parameters = MagicMock(return_value=parameters(currmodel))
controlnet.state_dict = MagicMock(return_value=control_state_dict)
basemodel.state_dict = MagicMock(return_value=base_state_dict)
currmodel.state_dict = MagicMock(return_value=curr_state_dict)
Expand Down

0 comments on commit 94fd637

Please sign in to comment.