Skip to content

Commit

Permalink
some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianlim committed Jul 24, 2024
1 parent 2683d9e commit 373e341
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 68 deletions.
42 changes: 25 additions & 17 deletions plugins/framework/tests/model_patcher_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,45 @@
from contextlib import contextmanager
from typing import Dict, Any, Type

ROOT = 'tests/model_patcher_fixtures'
PATHS = []
for root, dirs, files in os.walk(ROOT):
ROOT = 'tests.model_patcher_fixtures'
MODULE_PATHS = []
for root, dirs, files in os.walk(ROOT.replace('.', os.path.sep)):
for f in files:
filename, ext = os.path.splitext(f)
if ext != ".py":
continue
if filename != '__init__':
PATHS.append(os.path.join(root, f.replace(".py", "")))
p = os.path.join(root, filename)
else:
PATHS.append(root)
p = root

MODULE_PATHS.append(p.replace(os.path.sep, '.'))

@contextmanager
def manipulate_test_module_fixures():
def isolate_test_module_fixtures():
old_mod = {
k: sys.modules[k.replace("/", ".")] for k in PATHS
k: sys.modules[k] for k in MODULE_PATHS
}
try:
yield
finally:
# sort keys of descending length, to load children 1st
sorted_keys = sorted(old_mod.keys(), key=len, reverse=True)
for key in sorted_keys:
# Unclear why but needs a reload, to be investigated later
importlib.reload(old_mod[key])
yield

# Reload only reloads the speicified module, but makes not attempt to reload
# the imports of that module.
# - i.e., This moeans that if and import had been changed
# then the reload will take the changed import.
# - i.e., This also means that the individuals must be reloaded seperatedly
# for a complete reset.
#
# Therefore, we need to reload ALL Modules in opposite tree order, meaning that
# the children must be reloaded before their parent

for key in sorted(old_mod.keys(), key=len, reverse=True):
# Unclear why but needs a reload, to be investigated later
importlib.reload(old_mod[key])


def create_module_class(
class_name: str,
namespaces: Dict[str, Any] = {},
parent_class: Type = torch.nn.Module
):
cls = type(class_name, (parent_class,), namespaces)
return cls
return type(class_name, (parent_class,), namespaces)
105 changes: 54 additions & 51 deletions plugins/framework/tests/test_model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
combine_triggers,
)

from .model_patcher_test_utils import create_module_class, manipulate_test_module_fixures
from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures
from .model_patcher_fixtures import module1
from .model_patcher_fixtures import module2
from fms_acceleration.utils.test_utils import instantiate_model_patcher
Expand Down Expand Up @@ -102,25 +102,6 @@ def test_mp_trigger_constructs_with_check_and_trigger_type_args():
type=ModelPatcherTriggerType.callable,
)

# def test_mp_trigger_constructs_with_all_specified_args():
# "Test construction of trigger with check, type and module_name arguments"
# # check that trigger constructs
# ModelPatcherTrigger(
# check=MODULE_A,
# type=ModelPatcherTriggerType.module,
# module_name = MODULE_A
# )
# # raises error if module_name is incorrect type
# with pytest.raises(
# AssertionError,
# match = "module_name has to be type `str`"
# ):
# ModelPatcherTrigger(
# check=torch.nn.Module,
# type=ModelPatcherTriggerType.module,
# module_name = int
# )

def test_mp_trigger_correctly_triggers():
"Test for correctnness of trigger behaviour"

Expand Down Expand Up @@ -280,22 +261,6 @@ def test_mp_rule_raises_error_when_arguments_incorrectly_configured():
forward_builder_args=[]
)

# Test that rule construction fails if a valid import_and_maybe_reload
# has a reload target in the patch of object to be patched
# with pytest.raises(
# ValueError,
# match="import_and_maybe_reload specified has argument 3 in the same path " \
# "as argument 1. The path to reload has to be different from object to be patched."
# ):
# ModelPatcherRule(
# rule_id=DUMMY_RULE_ID,
# import_and_maybe_reload=(
# "tests.model_patcher_test_utils.MOD_CLS_TO_PATCH",
# torch.nn.Module,
# "tests",
# )
# )

def test_patch_target_module_replaces_module_or_function_correctly():
"""
Test patching of standalone file functions
Expand Down Expand Up @@ -326,50 +291,88 @@ def patched_mod_function():
# S1 - module1_1 has function mod_1_function
# 1. Replace module1_1.mod_1_function with new function
# 2. Ensure patch_target_module replaces with a new function
with manipulate_test_module_fixures():
with isolate_test_module_fixtures():
patch_target_module(
"tests.model_patcher_fixtures.module1.module1_1.mod_1_function",
patched_mod_function,
"tests.model_patcher_fixtures.module1",
)
assert module1.mod_1_function() == "patched_mod_function"

# test patches are reset outside the context manager
# # test patches are reset outside the context manager
assert module1.mod_1_function() == "unpatched_mod_function"

# S2 - module1_1.Module1Class has an attribute module2.Module2Class
# 1. Replace Module2Class with new class and reload module1_1
# 2. Ensure patch_target_module replaces the attribute with a new attr class

# tests.model_patcher_fixtures.module1 won't work as child module1_1 has been cached
with manipulate_test_module_fixures():
with isolate_test_module_fixtures():
patch_target_module(
"tests.model_patcher_fixtures.module2.Module2Class",
PatchedModuleClass,
"tests.model_patcher_fixtures.module1.module1_1"
)
assert isinstance(module1.Module1Class().attribute, PatchedModuleClass)

# S3 - module1.module3.module3_1 is a submodule of module1
# check the the fixture isolation works
assert not isinstance(module1.Module1Class().attribute, PatchedModuleClass)

# S3.1 - module1.module3.module3_1 is a submodule of module1
# 1. Replace module1.module3.module3_1.Module3Class with a new class
# 2. Show that reloading module1 does not replace the class
with manipulate_test_module_fixures():
# 2. No target reploading
# - this test shows that a replacement only affects the EXACT PATH that was patched
with isolate_test_module_fixtures():
patch_target_module(
"tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class",
PatchedModuleClass,
"tests.model_patcher_fixtures.module1",
"tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class",
PatchedModuleClass,
)

# - this is the exact module path that was patched, so it will reflect the patched class
assert isinstance(module1.module3.module3_1.Module3Class(), PatchedModuleClass)

# - this is the top-level module path, and shows that upper level paths will be
# be affected
assert not isinstance(module1.module3.Module3Class(), PatchedModuleClass)

# S3.2 - module1.module3.module3_1 is a submodule of module1
# 1. Replace module1.module3.module3_1.Module3Class with a new class
# 2. reload the top-level module path module1
with isolate_test_module_fixtures():
patch_target_module(
"tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class",
PatchedModuleClass,
"tests.model_patcher_fixtures.module1",
)

# - the reload of the top level module path module1, will replace module1.module3
# with the original version
assert not isinstance(module1.module3.module3_1.Module3Class(), PatchedModuleClass)
assert not isinstance(module1.module3.Module3Class(), PatchedModuleClass)

# S3.3 - module1.module3 is a submodule of module1
# 1. Replace module1.module3.module3_1.Module3Class with a new class
# 2. reload the top-level module path module1
with isolate_test_module_fixtures():
patch_target_module(
"tests.model_patcher_fixtures.module1.module3.Module3Class",
PatchedModuleClass,
"tests.model_patcher_fixtures.module1",
)

# - the reload of the top level module path module1, will replace module1.module3
# with the original version
assert not isinstance(module1.module3.module3_1.Module3Class(), PatchedModuleClass)
assert not isinstance(module1.module3.Module3Class(), PatchedModuleClass)

# S4 - module1.module3 submodule has a dependency on parent module1.mod_1_function
# S4 - module1.module3 submodule has a dependency on
# module1.module1_1.mod_1_function
# 1. Replace the module1.module1_1.mod_1_function with a new function
# 2. Ensure the function is replaced after reloading module1.module3 submodule
with manipulate_test_module_fixures():
# 2. Ensure the target reloading of module1.module3 picks up the patched function
with isolate_test_module_fixtures():
patch_target_module(
"tests.model_patcher_fixtures.module1.module1_1.mod_1_function",
patched_mod_function,
"tests.model_patcher_fixtures.module1.module3.module3_1",
)
assert module1.module3.Module3Class().attribute() == patched_mod_function()
assert module1.module3.Module3Class().attribute() == "patched_mod_function"


0 comments on commit 373e341

Please sign in to comment.