Skip to content

Commit

Permalink
adding MP Rule unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
achew010 committed Jul 23, 2024
1 parent 94e217e commit a31bf6e
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 4 deletions.
14 changes: 14 additions & 0 deletions plugins/framework/src/fms_acceleration/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def patch_target_module(
to_patch: str,
replace_with: Any,
target_module: str = None,
force_target_module_check: bool = True,
):
to_patch = to_patch.split(".")
assert len(to_patch) > 1, "must have an object to patch"
Expand All @@ -42,6 +43,11 @@ def patch_target_module(
setattr(source, obj_name_to_patch, replace_with)

if target_module is not None:
# if target module is a parent package of to_patch,
# it will reload the old object over the patch
if force_target_module_check:
assert target_module not in to_patch, \
"argument target_module cannot have same root path as to_patch"
# reload and this should get the patched object
target_module = importlib.import_module(target_module)
importlib.reload(target_module)
Expand Down Expand Up @@ -196,6 +202,14 @@ def __post_init__(self):
"forward_builder."
)

# if self.import_and_maybe_reload is not None and self.import_and_maybe_reload[2] in self.import_and_maybe_reload[0]:
# raise ValueError(
# f"Rule '{self.rule_id}' import_and_maybe_reload specified has argument 3 in the same path "
# "as argument 1. The path to reload has to be different from object to be patched."
# )




# helpful to keep a history of all patching that has been done
@dataclass
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .module3 import Module3Class
from .module1_1 import Module1Class, mod_1_function
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from ..module2 import Module2Class

class Module1Class:
def __init__(self) -> None:
self.attribute = Module2Class()

def mod_1_function():
return 1
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .module3_1 import Module3Class
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ..module1_1 import mod_1_function

class Module3Class:
def __init__(self) -> None:
self.attribute = mod_1_function
3 changes: 3 additions & 0 deletions plugins/framework/tests/model_patcher_fixtures/module2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class Module2Class:
def __init__(self) -> None:
pass
130 changes: 126 additions & 4 deletions plugins/framework/tests/test_model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
combine_functions,
combine_triggers,
)
from .model_patcher_test_utils import (
create_module_class,
)

from .model_patcher_test_utils import create_module_class
from .model_patcher_fixtures import module1
from .model_patcher_fixtures import module2
from fms_acceleration.utils.test_utils import instantiate_model_patcher

MOD_CLS_A = create_module_class("MOD_CLS_A")
Expand Down Expand Up @@ -178,7 +179,7 @@ def check_module(module):
# Scenario 3:
# Static check to ensure additional constraint is checked
# 1. create an instance of ModClassA as model
# 2A. register 2 submodules instances of ModClassB, Submodule_1 and SubModule2
# 2. register 2 submodules instances of ModClassB, Submodule_1 and SubModule_2
# 3. create a trigger that checks for an instance of module_B and `submodule_1` module name
# 4. for each module in model, ensure returns true if trigger detects module,
# otherwise it should return false
Expand Down Expand Up @@ -240,3 +241,124 @@ def test_combine_mp_triggers_produces_correct_output(target_module, trigger_chec
logic=logic,
).is_triggered(target_module) is expected_result


def test_mp_rule_raises_error_when_arguments_incorrectly_configured():
"Ensure MP rule is throws appropriate error when wrong argument combinations are passed"
# Test mp rule construction raises with multiple arguments
with pytest.raises(
ValueError,
match="must only have only one of forward, " \
"foward builder, or import_and_maybe_reload, specified."
):
ModelPatcherRule(
rule_id=DUMMY_RULE_ID,
forward=lambda self, X: X,
import_and_maybe_reload=(),
forward_builder=lambda self, X: X,
)

# Test mp rule construction raises with trigger and import_and_reload
with pytest.raises(
ValueError,
match="has import_and_maybe_reload specified, " \
"and trigger must be None."
):
ModelPatcherRule(
rule_id=DUMMY_RULE_ID,
trigger=ModelPatcherTrigger(check=torch.nn.Module),
import_and_maybe_reload=(),
)

# Test that rule construction raises if forward_builder_args are provided without a forward_builder
with pytest.raises(
ValueError,
match="has forward_builder_args but no " \
"forward_builder."
):
ModelPatcherRule(
rule_id=DUMMY_RULE_ID,
forward_builder_args=[]
)

# Test that rule construction fails if a valid import_and_maybe_reload
# has a reload target in the patch of object to be patched
# with pytest.raises(
# ValueError,
# match="import_and_maybe_reload specified has argument 3 in the same path " \
# "as argument 1. The path to reload has to be different from object to be patched."
# ):
# ModelPatcherRule(
# rule_id=DUMMY_RULE_ID,
# import_and_maybe_reload=(
# "tests.model_patcher_test_utils.MOD_CLS_TO_PATCH",
# torch.nn.Module,
# "tests",
# )
# )

def test_patch_target_module_replaces_module_or_function_correctly():
"""
Test patching of standalone file functions
Fixture Structure:
model_patcher_fixtures
__init__.py
module1
__init__.py
module1_1.py
module3
__init__.py
module3_1.py
module2.py
"""

PatchedModuleClass = create_module_class(
"PatchedModClass",
)

def patched_mod_function():
return "patched_mod_function"

# S1 - module1_1 has function mod_1_function
# 1. Replace module1_1.mod_1_function with new function
# 2. Ensure patch_target_module replaces with a new function
patch_target_module(
"tests.model_patcher_fixtures.module1.module1_1.mod_1_function",
patched_mod_function,
"tests.model_patcher_fixtures.module1",
force_target_module_check=False,
)
assert module1.mod_1_function() == patched_mod_function()

# S2 - module1_1.Module1Class has an attribute module2.Module2Class
# 1. Replace Module2Class with new class and reload module1_1
# 2. Ensure patch_target_module replaces the attribute with a new attr class
patch_target_module(
"tests.model_patcher_fixtures.module2.Module2Class",
PatchedModuleClass,
"tests.model_patcher_fixtures.module1.module1_1"
)
assert isinstance(module1.Module1Class().attribute, PatchedModuleClass)

# S3 - module1.module3.module3_1 is a submodule of module1
# 1. Replace module1.module3.module3_1.Module3Class with a new class
# 2. Show that reloading module1 does not replace the class
patch_target_module(
"tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class",
PatchedModuleClass,
"tests.model_patcher_fixtures.module1",
force_target_module_check=False,
)
assert not isinstance(module1.module3.Module3Class(), PatchedModuleClass)

# S4 - module1.module3 submodule has a dependency on parent module1.mod_1_function
# 1. Replace the module1.module1_1.mod_1_function with a new function
# 2. Ensure the function is replaced after reloading module1.module3 submodule
patch_target_module(
"tests.model_patcher_fixtures.module1.module1_1.mod_1_function",
patched_mod_function,
"tests.model_patcher_fixtures.module1.module3.module3_1",
)
assert module1.module3.Module3Class().attribute() == patched_mod_function()


0 comments on commit a31bf6e

Please sign in to comment.