From 785b971df3ee213befa98997b9ad7c78e2396f3f Mon Sep 17 00:00:00 2001 From: 1000850000 user Date: Tue, 23 Jul 2024 11:42:35 +0000 Subject: [PATCH] adding MP Rule unit tests --- .../src/fms_acceleration/model_patcher.py | 14 ++ .../tests/model_patcher_fixtures/__init__.py | 0 .../module1/__init__.py | 2 + .../module1/module1_1.py | 8 ++ .../module1/module3/__init__.py | 1 + .../module1/module3/module3_1.py | 5 + .../tests/model_patcher_fixtures/module2.py | 3 + plugins/framework/tests/test_model_patcher.py | 134 +++++++++++++++++- 8 files changed, 163 insertions(+), 4 deletions(-) create mode 100644 plugins/framework/tests/model_patcher_fixtures/__init__.py create mode 100644 plugins/framework/tests/model_patcher_fixtures/module1/__init__.py create mode 100644 plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py create mode 100644 plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py create mode 100644 plugins/framework/tests/model_patcher_fixtures/module1/module3/module3_1.py create mode 100644 plugins/framework/tests/model_patcher_fixtures/module2.py diff --git a/plugins/framework/src/fms_acceleration/model_patcher.py b/plugins/framework/src/fms_acceleration/model_patcher.py index 149f1650..0b35154a 100644 --- a/plugins/framework/src/fms_acceleration/model_patcher.py +++ b/plugins/framework/src/fms_acceleration/model_patcher.py @@ -31,6 +31,7 @@ def patch_target_module( to_patch: str, replace_with: Any, target_module: str = None, + force_target_module_check: bool = True, ): to_patch = to_patch.split(".") assert len(to_patch) > 1, "must have an object to patch" @@ -42,6 +43,11 @@ def patch_target_module( setattr(source, obj_name_to_patch, replace_with) if target_module is not None: + # if target module is a parent package of to_patch, + # it will reload the old object over the patch + if force_target_module_check: + assert target_module not in to_patch, \ + "argument target_module cannot have same root path as to_patch" # reload and this should get the patched object target_module = importlib.import_module(target_module) importlib.reload(target_module) @@ -196,6 +202,14 @@ def __post_init__(self): "forward_builder." ) + # if self.import_and_maybe_reload is not None and self.import_and_maybe_reload[2] in self.import_and_maybe_reload[0]: + # raise ValueError( + # f"Rule '{self.rule_id}' import_and_maybe_reload specified has argument 3 in the same path " + # "as argument 1. The path to reload has to be different from object to be patched." + # ) + + + # helpful to keep a history of all patching that has been done @dataclass diff --git a/plugins/framework/tests/model_patcher_fixtures/__init__.py b/plugins/framework/tests/model_patcher_fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py new file mode 100644 index 00000000..546e2bed --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py @@ -0,0 +1,2 @@ +from .module3 import Module3Class +from .module1_1 import Module1Class, mod_1_function \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py b/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py new file mode 100644 index 00000000..2959805f --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py @@ -0,0 +1,8 @@ +from ..module2 import Module2Class + +class Module1Class: + def __init__(self) -> None: + self.attribute = Module2Class() + +def mod_1_function(): + return 1 \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py new file mode 100644 index 00000000..9aa0c47d --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py @@ -0,0 +1 @@ +from .module3_1 import Module3Class \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/module3/module3_1.py b/plugins/framework/tests/model_patcher_fixtures/module1/module3/module3_1.py new file mode 100644 index 00000000..b4d27c66 --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module1/module3/module3_1.py @@ -0,0 +1,5 @@ +from ..module1_1 import mod_1_function + +class Module3Class: + def __init__(self) -> None: + self.attribute = mod_1_function diff --git a/plugins/framework/tests/model_patcher_fixtures/module2.py b/plugins/framework/tests/model_patcher_fixtures/module2.py new file mode 100644 index 00000000..1ab99494 --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module2.py @@ -0,0 +1,3 @@ +class Module2Class: + def __init__(self) -> None: + pass diff --git a/plugins/framework/tests/test_model_patcher.py b/plugins/framework/tests/test_model_patcher.py index a18db72f..7fd3a853 100644 --- a/plugins/framework/tests/test_model_patcher.py +++ b/plugins/framework/tests/test_model_patcher.py @@ -30,9 +30,10 @@ combine_functions, combine_triggers, ) -from .model_patcher_test_utils import ( - create_module_class, -) + +from .model_patcher_test_utils import create_module_class +from .model_patcher_fixtures import module1 +from .model_patcher_fixtures import module2 from fms_acceleration.utils.test_utils import instantiate_model_patcher MOD_CLS_A = create_module_class("MOD_CLS_A") @@ -178,7 +179,7 @@ def check_module(module): # Scenario 3: # Static check to ensure additional constraint is checked # 1. create an instance of ModClassA as model - # 2A. register 2 submodules instances of ModClassB, Submodule_1 and SubModule2 + # 2. register 2 submodules instances of ModClassB, Submodule_1 and SubModule_2 # 3. create a trigger that checks for an instance of module_B and `submodule_1` module name # 4. for each module in model, ensure returns true if trigger detects module, # otherwise it should return false @@ -240,3 +241,128 @@ def test_combine_mp_triggers_produces_correct_output(target_module, trigger_chec logic=logic, ).is_triggered(target_module) is expected_result + +def test_mp_rule_raises_error_when_arguments_incorrectly_configured(): + "Ensure MP rule is throws appropriate error when wrong argument combinations are passed" + # Test mp rule construction raises with multiple arguments + with pytest.raises( + ValueError, + match="must only have only one of forward, " \ + "foward builder, or import_and_maybe_reload, specified." + ): + ModelPatcherRule( + rule_id=DUMMY_RULE_ID, + forward=lambda self, X: X, + import_and_maybe_reload=(), + forward_builder=lambda self, X: X, + ) + + # Test mp rule construction raises with trigger and import_and_reload + with pytest.raises( + ValueError, + match="has import_and_maybe_reload specified, " \ + "and trigger must be None." + ): + ModelPatcherRule( + rule_id=DUMMY_RULE_ID, + trigger=ModelPatcherTrigger(check=torch.nn.Module), + import_and_maybe_reload=(), + ) + + # Test that rule construction raises if forward_builder_args are provided without a forward_builder + with pytest.raises( + ValueError, + match="has forward_builder_args but no " \ + "forward_builder." + ): + ModelPatcherRule( + rule_id=DUMMY_RULE_ID, + forward_builder_args=[] + ) + + # Test that rule construction fails if a valid import_and_maybe_reload + # has a reload target in the patch of object to be patched + # with pytest.raises( + # ValueError, + # match="import_and_maybe_reload specified has argument 3 in the same path " \ + # "as argument 1. The path to reload has to be different from object to be patched." + # ): + # ModelPatcherRule( + # rule_id=DUMMY_RULE_ID, + # import_and_maybe_reload=( + # "tests.model_patcher_test_utils.MOD_CLS_TO_PATCH", + # torch.nn.Module, + # "tests", + # ) + # ) + +def test_patch_target_module_replaces_module_or_function_correctly(): + """ + Test patching of standalone file functions + + Fixture Structure: + model_patcher_fixtures + __init__.py + module1 + __init__.py + module1_1.py + module3 + __init__.py + module3_1.py + module2.py + """ + + # 1. Create a patched module, ModClassPatched + # 2. Call patch_target_module on modules2 in model_patcher_fixtures + # 3. check that modules1 has been reloaded with patched module + + PatchedModClass = create_module_class( + "PatchedModClass", + ) + + def patched_mod_function(): + return "patched_mod_function" + + # S1 - module1_1 has function mod_1_function + # 1. Replace module1_1.mod_1_function with new function + # 2. Ensure patch_target_module replaces with a new function + patch_target_module( + "tests.model_patcher_fixtures.module1.module1_1.mod_1_function", + patched_mod_function, + "tests.model_patcher_fixtures.module1", + force_target_module_check=False, + ) + assert module1.mod_1_function() == patched_mod_function() + + # S2 - module1_1.Module1Class has an attribute module2.Module2Class + # 1. Replace Module2Class with new class and reload module1_1 + # 2. Ensure patch_target_module replaces the attribute with a new attr class + patch_target_module( + "tests.model_patcher_fixtures.module2.Module2Class", + PatchedModClass, + "tests.model_patcher_fixtures.module1.module1_1" + ) + assert isinstance(module1.Module1Class().attribute, PatchedModClass) + + # S3 - module1.module3.module3_1 is a submodule of module1 + # 1. Replace module1.module3.module3_1.Module3Class with a new class + # 2. Show that reloading module1 does not replace the class + patch_target_module( + "tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class", + PatchedModClass, + "tests.model_patcher_fixtures.module1", + force_target_module_check=False, + ) + assert not isinstance(module1.module3.Module3Class(), PatchedModClass) + + # S4 - module1.module3 submodule has a dependency on parent module1.mod_1_function + # 1. Replace the module1.module1_1.mod_1_function with a new function + # 2. Ensure the function is replaced after reloading module1.module3 submodule + patch_target_module( + "tests.model_patcher_fixtures.module1.module1_1.mod_1_function", + patched_mod_function, + "tests.model_patcher_fixtures.module1.module3.module3_1", + ) + assert module1.module3.Module3Class().attribute() == patched_mod_function() + +