From 94fd63769d2705860feeef1a3c7f84f7d404e2db Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Thu, 6 Apr 2023 22:08:28 +0800 Subject: [PATCH] revise convert base model's unit test --- .../test_controlnet/test_controlnet_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_models/test_editors/test_controlnet/test_controlnet_utils.py b/tests/test_models/test_editors/test_controlnet/test_controlnet_utils.py index d915dda639..c76bd3904a 100644 --- a/tests/test_models/test_editors/test_controlnet/test_controlnet_utils.py +++ b/tests/test_models/test_editors/test_controlnet/test_controlnet_utils.py @@ -23,6 +23,12 @@ def check_state_dict(s1, s2): assert (s1[k] == s2[k]).all() +def parameters(model): + state_dict = model.state_dict() + for v in state_dict.values(): + yield v + + def test_change_base_model(): control_state_dict = make_state_dict(dict(k1=1, k2=2, k3=3)) target_control_state_dict = make_state_dict(dict(k1=1.5, k2=2.5, k3=3)) @@ -33,7 +39,9 @@ def test_change_base_model(): controlnet = MagicMock() basemodel = MagicMock() currmodel = MagicMock() - + controlnet.parameters = MagicMock(return_value=parameters(controlnet)) + basemodel.parameters = MagicMock(return_value=parameters(basemodel)) + currmodel.parameters = MagicMock(return_value=parameters(currmodel)) controlnet.state_dict = MagicMock(return_value=control_state_dict) basemodel.state_dict = MagicMock(return_value=base_state_dict) currmodel.state_dict = MagicMock(return_value=curr_state_dict)