-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
163 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
2 changes: 2 additions & 0 deletions
2
plugins/framework/tests/model_patcher_fixtures/module1/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .module3 import Module3Class | ||
from .module1_1 import Module1Class, mod_1_function |
8 changes: 8 additions & 0 deletions
8
plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from ..module2 import Module2Class | ||
|
||
class Module1Class: | ||
def __init__(self) -> None: | ||
self.attribute = Module2Class() | ||
|
||
def mod_1_function(): | ||
return 1 |
1 change: 1 addition & 0 deletions
1
plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .module3_1 import Module3Class |
5 changes: 5 additions & 0 deletions
5
plugins/framework/tests/model_patcher_fixtures/module1/module3/module3_1.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from ..module1_1 import mod_1_function | ||
|
||
class Module3Class: | ||
def __init__(self) -> None: | ||
self.attribute = mod_1_function |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
class Module2Class: | ||
def __init__(self) -> None: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,9 +30,10 @@ | |
combine_functions, | ||
combine_triggers, | ||
) | ||
from .model_patcher_test_utils import ( | ||
create_module_class, | ||
) | ||
|
||
from .model_patcher_test_utils import create_module_class | ||
from .model_patcher_fixtures import module1 | ||
from .model_patcher_fixtures import module2 | ||
from fms_acceleration.utils.test_utils import instantiate_model_patcher | ||
|
||
MOD_CLS_A = create_module_class("MOD_CLS_A") | ||
|
@@ -178,7 +179,7 @@ def check_module(module): | |
# Scenario 3: | ||
# Static check to ensure additional constraint is checked | ||
# 1. create an instance of ModClassA as model | ||
# 2A. register 2 submodules instances of ModClassB, Submodule_1 and SubModule2 | ||
# 2. register 2 submodules instances of ModClassB, Submodule_1 and SubModule_2 | ||
# 3. create a trigger that checks for an instance of module_B and `submodule_1` module name | ||
# 4. for each module in model, ensure returns true if trigger detects module, | ||
# otherwise it should return false | ||
|
@@ -240,3 +241,128 @@ def test_combine_mp_triggers_produces_correct_output(target_module, trigger_chec | |
logic=logic, | ||
).is_triggered(target_module) is expected_result | ||
|
||
|
||
def test_mp_rule_raises_error_when_arguments_incorrectly_configured(): | ||
"Ensure MP rule is throws appropriate error when wrong argument combinations are passed" | ||
# Test mp rule construction raises with multiple arguments | ||
with pytest.raises( | ||
ValueError, | ||
match="must only have only one of forward, " \ | ||
"foward builder, or import_and_maybe_reload, specified." | ||
): | ||
ModelPatcherRule( | ||
rule_id=DUMMY_RULE_ID, | ||
forward=lambda self, X: X, | ||
import_and_maybe_reload=(), | ||
forward_builder=lambda self, X: X, | ||
) | ||
|
||
# Test mp rule construction raises with trigger and import_and_reload | ||
with pytest.raises( | ||
ValueError, | ||
match="has import_and_maybe_reload specified, " \ | ||
"and trigger must be None." | ||
): | ||
ModelPatcherRule( | ||
rule_id=DUMMY_RULE_ID, | ||
trigger=ModelPatcherTrigger(check=torch.nn.Module), | ||
import_and_maybe_reload=(), | ||
) | ||
|
||
# Test that rule construction raises if forward_builder_args are provided without a forward_builder | ||
with pytest.raises( | ||
ValueError, | ||
match="has forward_builder_args but no " \ | ||
"forward_builder." | ||
): | ||
ModelPatcherRule( | ||
rule_id=DUMMY_RULE_ID, | ||
forward_builder_args=[] | ||
) | ||
|
||
# Test that rule construction fails if a valid import_and_maybe_reload | ||
# has a reload target in the patch of object to be patched | ||
# with pytest.raises( | ||
# ValueError, | ||
# match="import_and_maybe_reload specified has argument 3 in the same path " \ | ||
# "as argument 1. The path to reload has to be different from object to be patched." | ||
# ): | ||
# ModelPatcherRule( | ||
# rule_id=DUMMY_RULE_ID, | ||
# import_and_maybe_reload=( | ||
# "tests.model_patcher_test_utils.MOD_CLS_TO_PATCH", | ||
# torch.nn.Module, | ||
# "tests", | ||
# ) | ||
# ) | ||
|
||
def test_patch_target_module_replaces_module_or_function_correctly(): | ||
""" | ||
Test patching of standalone file functions | ||
Fixture Structure: | ||
model_patcher_fixtures | ||
This comment has been minimized.
Sorry, something went wrong.
fabianlim
Contributor
|
||
__init__.py | ||
module1 | ||
__init__.py | ||
module1_1.py | ||
module3 | ||
__init__.py | ||
module3_1.py | ||
module2.py | ||
""" | ||
|
||
# 1. Create a patched module, ModClassPatched | ||
# 2. Call patch_target_module on modules2 in model_patcher_fixtures | ||
# 3. check that modules1 has been reloaded with patched module | ||
|
||
PatchedModClass = create_module_class( | ||
"PatchedModClass", | ||
) | ||
|
||
def patched_mod_function(): | ||
return "patched_mod_function" | ||
|
||
# S1 - module1_1 has function mod_1_function | ||
# 1. Replace module1_1.mod_1_function with new function | ||
# 2. Ensure patch_target_module replaces with a new function | ||
patch_target_module( | ||
"tests.model_patcher_fixtures.module1.module1_1.mod_1_function", | ||
This comment has been minimized.
Sorry, something went wrong. |
||
patched_mod_function, | ||
"tests.model_patcher_fixtures.module1", | ||
force_target_module_check=False, | ||
) | ||
assert module1.mod_1_function() == patched_mod_function() | ||
|
||
# S2 - module1_1.Module1Class has an attribute module2.Module2Class | ||
# 1. Replace Module2Class with new class and reload module1_1 | ||
# 2. Ensure patch_target_module replaces the attribute with a new attr class | ||
patch_target_module( | ||
"tests.model_patcher_fixtures.module2.Module2Class", | ||
PatchedModClass, | ||
"tests.model_patcher_fixtures.module1.module1_1" | ||
This comment has been minimized.
Sorry, something went wrong.
fabianlim
Contributor
|
||
) | ||
assert isinstance(module1.Module1Class().attribute, PatchedModClass) | ||
|
||
# S3 - module1.module3.module3_1 is a submodule of module1 | ||
# 1. Replace module1.module3.module3_1.Module3Class with a new class | ||
# 2. Show that reloading module1 does not replace the class | ||
patch_target_module( | ||
"tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class", | ||
PatchedModClass, | ||
"tests.model_patcher_fixtures.module1", | ||
force_target_module_check=False, | ||
) | ||
assert not isinstance(module1.module3.Module3Class(), PatchedModClass) | ||
|
||
# S4 - module1.module3 submodule has a dependency on parent module1.mod_1_function | ||
# 1. Replace the module1.module1_1.mod_1_function with a new function | ||
# 2. Ensure the function is replaced after reloading module1.module3 submodule | ||
This comment has been minimized.
Sorry, something went wrong.
fabianlim
Contributor
|
||
patch_target_module( | ||
"tests.model_patcher_fixtures.module1.module1_1.mod_1_function", | ||
patched_mod_function, | ||
"tests.model_patcher_fixtures.module1.module3.module3_1", | ||
) | ||
assert module1.module3.Module3Class().attribute() == patched_mod_function() | ||
|
||
|
lets not bother to check this here, lets have it checked at the global model patcher.