diff --git a/plugins/framework/src/fms_acceleration/model_patcher.py b/plugins/framework/src/fms_acceleration/model_patcher.py index 4a53f8d5..02b28eee 100644 --- a/plugins/framework/src/fms_acceleration/model_patcher.py +++ b/plugins/framework/src/fms_acceleration/model_patcher.py @@ -19,6 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import importlib import inspect +import warnings # Third Party import pandas as pd @@ -268,6 +269,8 @@ def register(rule: ModelPatcherRule): @staticmethod def did_rule_trigger(module: torch.nn.Module, module_name: str): + + active_rule_name, active_rule = None, None for name, rule in ModelPatcher.rules.items(): # if there is no trigger @@ -275,9 +278,18 @@ def did_rule_trigger(module: torch.nn.Module, module_name: str): continue if rule.trigger.is_triggered(module, module_name): - return name, rule - - return None, None + # if no rule is active yet, assign the current rule as active + if active_rule is None: + active_rule_name = name + active_rule = rule + # otherwise, if there is already an active rule, raise a warning + # that subsequent compatible forward rules will be ignored for simple forward patches + # forwardbuilders are handled when they are decomposed into new simple forward rules + elif rule.forward is not None: + warnings.warn(f"rule {rule.rule_id} is ignored on {module_name} as an earlier rule {active_rule.rule_id} has been applied") + #raise Exception(f"rule {rule.rule_id} is ignored on {module_name} as an earlier rule has been applied") + + return active_rule_name, active_rule @staticmethod def _import_and_reload(model: torch.nn.Module): @@ -326,7 +338,18 @@ def _import_and_reload(model: torch.nn.Module): elif _target.startswith(module_path): _no_reload.append(rule) - assert len(_with_reload) <= 1, "can only have at most one rule with reload" + # If there are multiple reload targets, + # ensure that their paths do not conflict as reloading the same module might reset patches + if 
len(_with_reload)>1: + # sort ascending target path length + _with_reload = sorted(_with_reload, key=lambda _rule: len(_rule.import_and_maybe_reload[2]), reverse=False) + for rule_s in _with_reload: + for rule_l in _with_reload[1:]: + # if the target path of rule_s is a prefix of rule_l's target path, raise an error + _, _, _path_s = rule_s.import_and_maybe_reload + _, _, _path_l = rule_l.import_and_maybe_reload + assert not _path_l.startswith(_path_s), \ + f"Attempting to reload same path `{_path_s}` multiple times in {rule_s.rule_id} and {rule_l.rule_id}" # handle those with reload first for rule in _with_reload + _no_reload: @@ -444,6 +467,15 @@ def _patch_forwards( def patch(model: torch.nn.Module, **kwargs): # NOTE: for a set of rules, this patch function should be called # only once. We do not have any checks for this at the moment + + # 1. Iterate over all ModelPatcher rules + # 2. For import_and_maybe_reload rules, an assertion + # is thrown if reload target paths conflict (one is a prefix of another) + # 3. For _patch_forwards, ensure that the trigger check + # module or callable function is unique across all rules + # otherwise, a warning is raised as later rules could patch + # the forwards over previous patches + try: ModelPatcher._import_and_reload(model.get_base_model()) except AttributeError: diff --git a/plugins/framework/src/fms_acceleration/utils/test_utils.py b/plugins/framework/src/fms_acceleration/utils/test_utils.py index 4931e166..3952b528 100644 --- a/plugins/framework/src/fms_acceleration/utils/test_utils.py +++ b/plugins/framework/src/fms_acceleration/utils/test_utils.py @@ -186,7 +186,5 @@ def instantiate_model_patcher(): from fms_acceleration.model_patcher import ModelPatcher old_registrations = ModelPatcher.rules ModelPatcher.rules = {} - try: - yield - finally: - ModelPatcher.rules = old_registrations \ No newline at end of file + yield + ModelPatcher.rules = old_registrations \ No newline at end of file diff --git 
a/plugins/framework/tests/model_patcher_fixtures/module4/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module4/__init__.py new file mode 100644 index 00000000..8f77a8ad --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module4/__init__.py @@ -0,0 +1,8 @@ +from .module4_1 import mod_4_function +from .module5 import Module5Class, mod_5_function +import torch + +class Module4Class(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.attribute = Module5Class() diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py b/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py new file mode 100644 index 00000000..aa7a9700 --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py @@ -0,0 +1,2 @@ +def mod_4_function(): + return "unpatched_mod_function" \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py new file mode 100644 index 00000000..7652a2b7 --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py @@ -0,0 +1 @@ +from .module5_1 import Module5Class, mod_5_function \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py b/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py new file mode 100644 index 00000000..dfba5e17 --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py @@ -0,0 +1,8 @@ +import torch + +class Module5Class(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + +def mod_5_function(): + return "unpatched_mod_function" \ No newline at end of file diff --git a/plugins/framework/tests/test_model_patcher.py b/plugins/framework/tests/test_model_patcher.py index 8dbd7795..f26c20c7 100644 --- 
a/plugins/framework/tests/test_model_patcher.py +++ b/plugins/framework/tests/test_model_patcher.py @@ -32,9 +32,7 @@ ) from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures -from .model_patcher_fixtures import module1 -from .model_patcher_fixtures import module2 -from fms_acceleration.utils.test_utils import instantiate_model_patcher +from .model_patcher_fixtures import module1, module2, module4 MOD_CLS_A = create_module_class("MOD_CLS_A") MOD_SUBCLS_A = create_module_class("MOD_SUBCLS_A", parent_class=MOD_CLS_A) @@ -279,6 +277,12 @@ def test_patch_target_module_replaces_module_or_function_correctly(): - attribute: mod_1_function - module2: - Module2Class: + + - module4: + - Module4Class: + - attribute: mod_1_function + + """ PatchedModuleClass = create_module_class( @@ -378,5 +382,3 @@ def patched_mod_function(): "tests.model_patcher_fixtures.module1.module3.module3_1", ) assert module1.module3.Module3Class().attribute() == "patched_mod_function" - - diff --git a/plugins/framework/tests/test_model_patcher2.py b/plugins/framework/tests/test_model_patcher2.py new file mode 100644 index 00000000..66081655 --- /dev/null +++ b/plugins/framework/tests/test_model_patcher2.py @@ -0,0 +1,245 @@ +# Third Party +import pytest # pylint: disable=import-error +import torch + +# First Party +from fms_acceleration.model_patcher import ( + ModelPatcher, + ModelPatcherRule, + ModelPatcherTrigger, + patch_target_module, + ModelPatcherTriggerType, + ModelPatcherHistory, + combine_functions, + combine_triggers, +) + +from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures +from .model_patcher_fixtures import module1, module2, module4 +from fms_acceleration.utils.test_utils import instantiate_model_patcher + +from .test_model_patcher import DUMMY_RULE_ID + +# Test patching of model attribute +def test_simple_forward_rule_with_mp_replaces_old_forward(): # pylint: disable=redefined-outer-name + """ + 
model_patcher_fixtures: + - module1: + - module1_1: + - Module2Class: + - attribute: Module2Class + - mod_1_function + - module3: + - module3_1 + - Module3Class: + - attribute: mod_1_function + - module2: + - Module2Class: + + - module4: + - Module4Class(torch.nn.Module): + - attribute: mod_1_function + """ + + def patched_forward_function(X): + return "patched_forward_function" + + # 1. Create an instance of Module4Class as model + # 2. Add a submodule to Module4Class + # 3. Create and register rule to patch forward of submodule class + # 4. Patch model + # 5. Ensure that model's submodule forward is replaced + with isolate_test_module_fixtures(): + with instantiate_model_patcher(): + model = module4.Module4Class() + SubModule1 = create_module_class( + "SubModule1", + namespaces={"forward": lambda self: "unpatched_forward_function"} + ) + model.add_module("submodule_1", SubModule1()) + rule = ModelPatcherRule( + rule_id=DUMMY_RULE_ID, + trigger=ModelPatcherTrigger(check=SubModule1), + forward=patched_forward_function, + ) + ModelPatcher.register(rule) + ModelPatcher.patch(model) + + assert model.submodule_1.forward() == "patched_forward_function" + +def test_import_and_maybe_reload_rule_with_mp_replaces_old_attribute(): + # 1. Register rule replacing module5.module5_1.Module5Class with a patched_mod_function + # reload_target is test.model_patcher.fixtures.module4 + # 2. Patch module4.Module4Class with ModelPatcher + # 3. 
check patched module exist in module4.Module4Class.attribute + PatchedModuleClass = create_module_class( + "PatchedModClass", + ) + + + with isolate_test_module_fixtures(): + with instantiate_model_patcher(): + model = module4.Module4Class() + ModelPatcher.register( + ModelPatcherRule( + rule_id=DUMMY_RULE_ID, + import_and_maybe_reload=( + "tests.model_patcher_fixtures.module4.module5.Module5Class", + PatchedModuleClass, + "tests.model_patcher_fixtures.module4", + ), + ) + ) + ModelPatcher.patch(model) + assert isinstance(module4.Module4Class().attribute, PatchedModuleClass) + +# TODO forward builder test + + +def test_mp_throws_error_with_multiple_reloads_on_same_target(): + """ + Simulate a case where two rules attempt to reload on the same target prefix + + example: + - Rule 1 target path 1: x.y.z + - Rule 2 target path 2: x.y + + this might reverse the patch on Rule 1 and needs to be caught + + model_patcher_fixtures: + - module1: + - module1_1: + - Module2Class: + - attribute: Module2Class + - mod_1_function + - module3: + - module3_1 + - Module3Class: + - attribute: mod_1_function + - module2: + - Module2Class: + + - module4: + - Module4Class(torch.nn.Module): + - attribute: mod_1_function + - module4_1 + - mod_4_function + - module5: + - module5_1 + - Module5Class + - module_5_function + + """ + + PatchedModuleClass = create_module_class( + "PatchedModuleClass", + ) + + def patched_mod_function(): + return "patched_mod_function" + + # Demonstrate that the 2nd patch overwrites the 1st patch if the reload module paths are the same + with isolate_test_module_fixtures(): + # 1st patch on a function + patch_target_module( + "tests.model_patcher_fixtures.module4.module5.module5_1.mod_5_function", + patched_mod_function, + "tests.model_patcher_fixtures.module4.module5", + ) + + assert module4.module5.mod_5_function() == "patched_mod_function" + + # 2nd patch on a class that has a target path that reloads module5 as well + patch_target_module( + 
"tests.model_patcher_fixtures.module4.module5.module5_1.Module5Class", + PatchedModuleClass, + "tests.model_patcher_fixtures.module4.module5" + ) + + assert isinstance(module4.module5.Module5Class(), PatchedModuleClass) + assert module4.module5.mod_5_function() == "unpatched_mod_function" + + # Ensure that an assertion is raised if target paths share the same root path + with pytest.raises( + AssertionError, + ): + with isolate_test_module_fixtures(): + with instantiate_model_patcher(): + # 1. Initialize a model with module path tests.model_patcher_fixtures.module4 + model = module4.Module4Class() + + # 2. Simulate patching a function in module4.module5.module5_1 + ModelPatcher.register( + ModelPatcherRule( + rule_id=f"{DUMMY_RULE_ID}.2", + import_and_maybe_reload=( + "tests.model_patcher_fixtures.module4.module5.module5_1.mod_5_function", + patched_mod_function, + "tests.model_patcher_fixtures.module4.module5.module5_1", + ), + ) + ) + + # 3. Simulate patching a class in module4.module5.module5_1 + ModelPatcher.register( + ModelPatcherRule( + rule_id=f"{DUMMY_RULE_ID}.1", + import_and_maybe_reload=( + "tests.model_patcher_fixtures.module4.module5.module5_1.Module5Class", + PatchedModuleClass, + "tests.model_patcher_fixtures.module4", + ), + ) + ) + + # while ModelPatcher is patching different objects, repeated reloads on same path is risky + # since module4 is a parent of module5, reloading module4 again might affect the previous patch. + # To prevent this we throw an exception if the shorter target path is a prefix of the + # longer target path + ModelPatcher.patch(model) + + +def test_mp_throws_warning_with_multiple_patches(): + """ + Ensure for each module, only one forward patch is implemented on it. 
+ The patch implementation checks if there are multiple forward patch rules that are applied to the module, + only the 1st forward patch rule is applied, the others will be ignored and a warning will be raised + + In the case of a list of new rules generated by `forwardbuilder`, it will be handled similarly since + it decomposes to multiple single forward patch rules downstream. + """ + with pytest.warns( + UserWarning, + ): + with isolate_test_module_fixtures(): + with instantiate_model_patcher(): + # 1. Create a model + # 2. Create a submodule to patch on + # 3. Create 1st rule to patch submodule forward function + # 4. Create 2nd rule to patch submodule forward function again + # 5. Throws warning that any subsequent forward patches after the 1st patch is ignored + + model = module4.Module4Class() + SubModule1 = create_module_class( + "SubModule1", + namespaces={"forward": lambda self: "unpatched_forward_function"} + ) + model.add_module("submodule_1", SubModule1()) + + ModelPatcher.register( + ModelPatcherRule( + rule_id=DUMMY_RULE_ID+".1", + trigger=ModelPatcherTrigger(check=SubModule1), + forward=lambda self: "patched_forward_function", + ) + ) + ModelPatcher.register( + ModelPatcherRule( + rule_id=DUMMY_RULE_ID+".2", + trigger=ModelPatcherTrigger(check=SubModule1), + forward=lambda self: "patched_forward_function_2", + ) + ) + ModelPatcher.patch(model) + + # TODO test on forward builder cases