diff --git a/plugins/framework/src/fms_acceleration/model_patcher.py b/plugins/framework/src/fms_acceleration/model_patcher.py index f3825fde..149f1650 100644 --- a/plugins/framework/src/fms_acceleration/model_patcher.py +++ b/plugins/framework/src/fms_acceleration/model_patcher.py @@ -74,7 +74,7 @@ class ModelPatcherTrigger: # the trigger operation check: Union[ - torch.nn.Module, # trigger on isinstance + Type[torch.nn.Module], # trigger on isinstance Callable[[torch.nn.Module], bool], # trigger on callable ] @@ -87,7 +87,7 @@ class ModelPatcherTrigger: def is_triggered( self, - module: Type[torch.nn.Module], + module: torch.nn.Module, module_name: str = None, ): "Check if trigger returns truthful." @@ -488,13 +488,13 @@ def combine_triggers(*triggers: ModelPatcherTrigger, logic: str = "OR"): # NOTE: this can be probably simplified def _or_logic(*args, **kwargs): for trig in triggers: - if trig.is_triggered(*args, **kwargs) is True: + if trig.is_triggered(*args, **kwargs): return True return False def _and_logic(*args, **kwargs): for trig in triggers: - if trig.is_triggered(*args, **kwargs) is False: + if not trig.is_triggered(*args, **kwargs): return False return True diff --git a/plugins/framework/tests/model_patcher_test_utils.py b/plugins/framework/tests/model_patcher_test_utils.py index bbc235ab..3f624542 100644 --- a/plugins/framework/tests/model_patcher_test_utils.py +++ b/plugins/framework/tests/model_patcher_test_utils.py @@ -1,7 +1,7 @@ import torch from typing import Dict, Any, Type -def create_dummy_module( +def create_module_class( class_name: str, namespaces: Dict[str, Any] = {}, parent_class: Type = torch.nn.Module diff --git a/plugins/framework/tests/test_model_patcher.py b/plugins/framework/tests/test_model_patcher.py index 49f44bb4..a18db72f 100644 --- a/plugins/framework/tests/test_model_patcher.py +++ b/plugins/framework/tests/test_model_patcher.py @@ -31,13 +31,13 @@ combine_triggers, ) from .model_patcher_test_utils import ( - create_dummy_module, + create_module_class, ) from fms_acceleration.utils.test_utils import instantiate_model_patcher -MODULE_A = create_dummy_module("A") -MODULE_SUB_A = create_dummy_module("sub_A", parent_class=MODULE_A) -MODULE_B = create_dummy_module("B") +MOD_CLS_A = create_module_class("MOD_CLS_A") +MOD_SUBCLS_A = create_module_class("MOD_SUBCLS_A", parent_class=MOD_CLS_A) +MOD_CLS_B = create_module_class("MOD_CLS_B") def returns_false(*args, **kwargs): "falsy function" @@ -103,7 +103,7 @@ def test_mp_trigger_constructs_with_check_and_trigger_type_args(): # def test_mp_trigger_constructs_with_all_specified_args(): # "Test construction of trigger with check, type and module_name arguments" -# # check that trigger constructs +# # check that trigger constructs # ModelPatcherTrigger( # check=MODULE_A, # type=ModelPatcherTriggerType.module, @@ -123,84 +123,82 @@ def test_mp_trigger_constructs_with_check_and_trigger_type_args(): def test_mp_trigger_correctly_triggers(): "Test for correctnness of trigger behaviour" - module_A = create_dummy_module( - "module_A", + ModClassA = create_module_class( + "ModClassA", + namespaces={"attr_1": None}, ) - module_B = create_dummy_module( - "module_B", + ModClassB = create_module_class( + "ModClassB", ) - subclass_A = create_dummy_module( - "subclass_A", - parent_class=module_A, + ModSubClassA = create_module_class( + "ModSubClassA", + parent_class=ModClassA, ) # Scenario 1: # if check is a Callable, is_triggered result must be equal to the boolean output of check - # 1. create function to check that returns true is is instance of module_A, otherwise return False - # 2. create trigger that checks using above function + # 1. create function to check that returns true if module has attribute `attr_1`, otherwise return False + # 2. create trigger that checks the above function # 3. create a subclass of module_A and ensure is_triggered returns True # 4. create a module_B and ensure is_triggered returns False def check_module(module): - if isinstance(module, module_A): + if hasattr(module, "attr_1"): return True return False assert ModelPatcherTrigger(check=check_module).is_triggered( - subclass_A(), + ModClassA(), ) is True assert ModelPatcherTrigger(check=check_module).is_triggered( - module_B(), + ModClassB(), ) is False # Scenario 2: # Ensure return True, if is not an instance of ModelPatcherTrigger.check - # 1. create trigger that checks for module_A - # 2. create a subclass of module_A and check is_triggered returns True - # 3. create a module_B and check is_triggered returns False - assert ModelPatcherTrigger(check=module_A).is_triggered( - subclass_A(), + # 1. create trigger that checks for ModClassA + # 2. create an instance of ModClassA and check is_triggered returns True + # 3. create a subclass instance of ModClassA and check is_triggered returns True + # 4. create an instance of ModClassB and check is_triggered returns False + assert ModelPatcherTrigger(check=ModClassA).is_triggered( + ModClassA(), + ) is True + + assert ModelPatcherTrigger(check=ModClassA).is_triggered( + ModSubClassA(), ) is True # Ensure returns False, if is not an instance of ModelPatcherTrigger.check - assert ModelPatcherTrigger(check=module_A).is_triggered( - module_B(), + assert ModelPatcherTrigger(check=ModClassA).is_triggered( + ModClassB(), ) is False # Scenario 3: # Static check to ensure additional constraint is checked - # 1. create an instance of module_B as model - # 2. register 2 submodules that inherit from module_B, submodule_1 and submodule_2 - # 2. create a trigger that checks for an instance of module_B and `submodule_1` module name - # 3. for each module in model, ensure returns true if trigger detects module, + # 1. create an instance of ModClassA as model + # 2A. register 2 submodules instances of ModClassB, Submodule_1 and SubModule2 + # 3. create a trigger that checks for an instance of module_B and `submodule_1` module name + # 4. for each module in model, ensure returns true if trigger detects module, # otherwise it should return false + # Create model - model = module_A() + model = ModClassA() # register submodules - submodule_A = create_dummy_module( - "submodule_1", - parent_class=module_B, - ) - submodule_B = create_dummy_module( - "submodule_2", - parent_class=module_B, - ) - model.add_module("submodule_1", submodule_A()) - model.add_module("submodule_2", submodule_B()) + model.add_module("submodule_1", ModClassB()) + model.add_module("submodule_2", ModClassB()) # create trigger with search criteria - trigger = ModelPatcherTrigger(check=module_B, module_name="submodule_1") + trigger = ModelPatcherTrigger(check=ModClassB, module_name="submodule_1") # iterate through modules in model for name, module in model.named_modules(): if name == "submodule_1": - # assert that is_triggered returns true when module is found + # assert that is_triggered returns true when module is found assert trigger.is_triggered(module, name) is True else: - # assert that is_triggered otherwise returns false + # assert that is_triggered otherwise returns false assert trigger.is_triggered(module, name) is False - # Each test instance has # - target_module, # - tuple of trigger check arguments @@ -212,21 +210,21 @@ def check_module(module): # 4. Otherwise, ensure that the combined_trigger returns the expected result on the target module @pytest.mark.parametrize( "target_module,trigger_checks,logic,expected_result", [ - [MODULE_SUB_A(), (returns_true, MODULE_B), "OR", True], # True False - [MODULE_SUB_A(), (MODULE_B, returns_false), "OR", False], # False False - [MODULE_SUB_A(), (MODULE_A, returns_true), "OR", True], # True True - [MODULE_SUB_A(), (returns_true, MODULE_B), "AND", False], # True False - [MODULE_SUB_A(), (MODULE_B, returns_false), "AND", False], # False False - [MODULE_SUB_A(), (MODULE_A, returns_true), "AND", True], # True True + [MOD_SUBCLS_A(), (returns_true, MOD_CLS_B), "OR", True], + [MOD_SUBCLS_A(), (MOD_CLS_B, returns_false), "OR", False], + [MOD_SUBCLS_A(), (MOD_CLS_A, returns_true), "OR", True], + [MOD_CLS_B(), (returns_false, MOD_CLS_A), "AND", False], + [MOD_CLS_B(), (MOD_CLS_B, returns_false), "AND", False], + [MOD_CLS_B(), (MOD_CLS_B, returns_true), "AND", True], [ - MODULE_SUB_A(), (MODULE_B, MODULE_A), "NOR", + MOD_SUBCLS_A(), (MOD_CLS_B, MOD_CLS_A), "NOR", (AssertionError, "Only `AND`, `OR` logic implemented for combining triggers") ], ]) -def test_correct_output_combine_mp_triggers(target_module, trigger_checks, logic, expected_result): +def test_combine_mp_triggers_produces_correct_output(target_module, trigger_checks, logic, expected_result): triggers = [ModelPatcherTrigger(check=check) for check in trigger_checks] - # if expected_result is a tuple of (Exception, Exception_message) + # if expected_result is a tuple of (Exception, Exception_message) if isinstance(expected_result, tuple): with pytest.raises( expected_result[0],