Skip to content

Commit

Permalink
additional changes to trigger unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
achew010 committed Jul 23, 2024
1 parent 252a73c commit 94e217e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 57 deletions.
8 changes: 4 additions & 4 deletions plugins/framework/src/fms_acceleration/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class ModelPatcherTrigger:

# the trigger operation
check: Union[
torch.nn.Module, # trigger on isinstance
Type[torch.nn.Module], # trigger on isinstance
Callable[[torch.nn.Module], bool], # trigger on callable
]

Expand All @@ -87,7 +87,7 @@ class ModelPatcherTrigger:

def is_triggered(
self,
module: Type[torch.nn.Module],
module: torch.nn.Module,
module_name: str = None,
):
"Check if trigger returns truthful."
Expand Down Expand Up @@ -488,13 +488,13 @@ def combine_triggers(*triggers: ModelPatcherTrigger, logic: str = "OR"):
# NOTE: this can be probably simplified
def _or_logic(*args, **kwargs):
for trig in triggers:
if trig.is_triggered(*args, **kwargs) is True:
if trig.is_triggered(*args, **kwargs):
return True
return False

def _and_logic(*args, **kwargs):
for trig in triggers:
if trig.is_triggered(*args, **kwargs) is False:
if not trig.is_triggered(*args, **kwargs):
return False
return True

Expand Down
2 changes: 1 addition & 1 deletion plugins/framework/tests/model_patcher_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from typing import Dict, Any, Type

def create_dummy_module(
def create_module_class(
class_name: str,
namespaces: Dict[str, Any] = {},
parent_class: Type = torch.nn.Module
Expand Down
102 changes: 50 additions & 52 deletions plugins/framework/tests/test_model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@
combine_triggers,
)
from .model_patcher_test_utils import (
create_dummy_module,
create_module_class,
)
from fms_acceleration.utils.test_utils import instantiate_model_patcher

MODULE_A = create_dummy_module("A")
MODULE_SUB_A = create_dummy_module("sub_A", parent_class=MODULE_A)
MODULE_B = create_dummy_module("B")
MOD_CLS_A = create_module_class("MOD_CLS_A")
MOD_SUBCLS_A = create_module_class("MOD_SUBCLS_A", parent_class=MOD_CLS_A)
MOD_CLS_B = create_module_class("MOD_CLS_B")

def returns_false(*args, **kwargs):
"falsy function"
Expand Down Expand Up @@ -103,7 +103,7 @@ def test_mp_trigger_constructs_with_check_and_trigger_type_args():

# def test_mp_trigger_constructs_with_all_specified_args():
# "Test construction of trigger with check, type and module_name arguments"
# # check that trigger constructs
# # check that trigger constructs
# ModelPatcherTrigger(
# check=MODULE_A,
# type=ModelPatcherTriggerType.module,
Expand All @@ -123,84 +123,82 @@ def test_mp_trigger_constructs_with_check_and_trigger_type_args():
def test_mp_trigger_correctly_triggers():
"Test for correctnness of trigger behaviour"

module_A = create_dummy_module(
"module_A",
ModClassA = create_module_class(
"ModClassA",
namespaces={"attr_1": None},
)

module_B = create_dummy_module(
"module_B",
ModClassB = create_module_class(
"ModClassB",
)

subclass_A = create_dummy_module(
"subclass_A",
parent_class=module_A,
ModSubClassA = create_module_class(
"ModSubClassA",
parent_class=ModClassA,
)

# Scenario 1:
# if check is a Callable, is_triggered result must be equal to the boolean output of check
# 1. create function to check that returns true is is instance of module_A, otherwise return False
# 2. create trigger that checks using above function
# 1. create function to check that returns true if module has attribute `attr_1`, otherwise return False
# 2. create trigger that checks the above function
# 3. create a subclass of module_A and ensure is_triggered returns True
# 4. create a module_B and ensure is_triggered returns False
def check_module(module):
if isinstance(module, module_A):
if hasattr(module, "attr_1"):
return True
return False

assert ModelPatcherTrigger(check=check_module).is_triggered(
subclass_A(),
ModClassA(),
) is True

assert ModelPatcherTrigger(check=check_module).is_triggered(
module_B(),
ModClassB(),
) is False

# Scenario 2:
# Ensure return True, if is not an instance of ModelPatcherTrigger.check
# 1. create trigger that checks for module_A
# 2. create a subclass of module_A and check is_triggered returns True
# 3. create a module_B and check is_triggered returns False
assert ModelPatcherTrigger(check=module_A).is_triggered(
subclass_A(),
# 1. create trigger that checks for ModClassA
# 2. create an instance of ModClassA and check is_triggered returns True
# 3. create a subclass instance of ModClassA and check is_triggered returns True
# 4. create an instance of ModClassB and check is_triggered returns False
assert ModelPatcherTrigger(check=ModClassA).is_triggered(
ModClassA(),
) is True

assert ModelPatcherTrigger(check=ModClassA).is_triggered(
ModSubClassA(),
) is True

# Ensure returns False, if is not an instance of ModelPatcherTrigger.check
assert ModelPatcherTrigger(check=module_A).is_triggered(
module_B(),
assert ModelPatcherTrigger(check=ModClassA).is_triggered(
ModClassB(),
) is False

# Scenario 3:
# Static check to ensure additional constraint is checked
# 1. create an instance of module_B as model
# 2. register 2 submodules that inherit from module_B, submodule_1 and submodule_2
# 2. create a trigger that checks for an instance of module_B and `submodule_1` module name
# 3. for each module in model, ensure returns true if trigger detects module,
# 1. create an instance of ModClassA as model
# 2A. register 2 submodules instances of ModClassB, Submodule_1 and SubModule2
# 3. create a trigger that checks for an instance of module_B and `submodule_1` module name
# 4. for each module in model, ensure returns true if trigger detects module,
# otherwise it should return false

# Create model
model = module_A()
model = ModClassA()
# register submodules
submodule_A = create_dummy_module(
"submodule_1",
parent_class=module_B,
)
submodule_B = create_dummy_module(
"submodule_2",
parent_class=module_B,
)
model.add_module("submodule_1", submodule_A())
model.add_module("submodule_2", submodule_B())
model.add_module("submodule_1", ModClassB())
model.add_module("submodule_2", ModClassB())
# create trigger with search criteria
trigger = ModelPatcherTrigger(check=module_B, module_name="submodule_1")
trigger = ModelPatcherTrigger(check=ModClassB, module_name="submodule_1")
# iterate through modules in model
for name, module in model.named_modules():
if name == "submodule_1":
# assert that is_triggered returns true when module is found
# assert that is_triggered returns true when module is found
assert trigger.is_triggered(module, name) is True
else:
# assert that is_triggered otherwise returns false
# assert that is_triggered otherwise returns false
assert trigger.is_triggered(module, name) is False


# Each test instance has
# - target_module,
# - tuple of trigger check arguments
Expand All @@ -212,21 +210,21 @@ def check_module(module):
# 4. Otherwise, ensure that the combined_trigger returns the expected result on the target module
@pytest.mark.parametrize(
"target_module,trigger_checks,logic,expected_result", [
[MODULE_SUB_A(), (returns_true, MODULE_B), "OR", True], # True False
[MODULE_SUB_A(), (MODULE_B, returns_false), "OR", False], # False False
[MODULE_SUB_A(), (MODULE_A, returns_true), "OR", True], # True True
[MODULE_SUB_A(), (returns_true, MODULE_B), "AND", False], # True False
[MODULE_SUB_A(), (MODULE_B, returns_false), "AND", False], # False False
[MODULE_SUB_A(), (MODULE_A, returns_true), "AND", True], # True True
[MOD_SUBCLS_A(), (returns_true, MOD_CLS_B), "OR", True],
[MOD_SUBCLS_A(), (MOD_CLS_B, returns_false), "OR", False],
[MOD_SUBCLS_A(), (MOD_CLS_A, returns_true), "OR", True],
[MOD_CLS_B(), (returns_false, MOD_CLS_A), "AND", False],
[MOD_CLS_B(), (MOD_CLS_B, returns_false), "AND", False],
[MOD_CLS_B(), (MOD_CLS_B, returns_true), "AND", True],
[
MODULE_SUB_A(), (MODULE_B, MODULE_A), "NOR",
MOD_SUBCLS_A(), (MOD_CLS_B, MOD_CLS_A), "NOR",
(AssertionError, "Only `AND`, `OR` logic implemented for combining triggers")
],
])
def test_correct_output_combine_mp_triggers(target_module, trigger_checks, logic, expected_result):
def test_combine_mp_triggers_produces_correct_output(target_module, trigger_checks, logic, expected_result):
triggers = [ModelPatcherTrigger(check=check) for check in trigger_checks]

# if expected_result is a tuple of (Exception, Exception_message)
# if expected_result is a tuple of (Exception, Exception_message)
if isinstance(expected_result, tuple):
with pytest.raises(
expected_result[0],
Expand Down

0 comments on commit 94e217e

Please sign in to comment.