Skip to content

Commit

Permalink
add context manager to isolate patching unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
achew010 committed Jul 24, 2024
1 parent a31bf6e commit 2683d9e
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 45 deletions.
6 changes: 0 additions & 6 deletions plugins/framework/src/fms_acceleration/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def patch_target_module(
to_patch: str,
replace_with: Any,
target_module: str = None,
force_target_module_check: bool = True,
):
to_patch = to_patch.split(".")
assert len(to_patch) > 1, "must have an object to patch"
Expand All @@ -43,11 +42,6 @@ def patch_target_module(
setattr(source, obj_name_to_patch, replace_with)

if target_module is not None:
# if target module is a parent package of to_patch,
# it will reload the old object over the patch
if force_target_module_check:
assert target_module not in to_patch, \
"argument target_module cannot have same root path as to_patch"
# reload and this should get the patched object
target_module = importlib.import_module(target_module)
importlib.reload(target_module)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ def __init__(self) -> None:
self.attribute = Module2Class()

def mod_1_function():
return 1
return "unpatched_mod_function"
33 changes: 32 additions & 1 deletion plugins/framework/tests/model_patcher_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,41 @@
import torch
import os
import sys
import importlib
from contextlib import contextmanager
from typing import Dict, Any, Type

ROOT = 'tests/model_patcher_fixtures'
PATHS = []
for root, dirs, files in os.walk(ROOT):
for f in files:
filename, ext = os.path.splitext(f)
if ext != ".py":
continue
if filename != '__init__':
PATHS.append(os.path.join(root, f.replace(".py", "")))
else:
PATHS.append(root)

@contextmanager
def manipulate_test_module_fixures():
old_mod = {
k: sys.modules[k.replace("/", ".")] for k in PATHS
}
try:
yield
finally:
# sort keys of descending length, to load children 1st
sorted_keys = sorted(old_mod.keys(), key=len, reverse=True)
for key in sorted_keys:
# Unclear why but needs a reload, to be investigated later
importlib.reload(old_mod[key])


def create_module_class(
class_name: str,
namespaces: Dict[str, Any] = {},
parent_class: Type = torch.nn.Module
):
cls = type(class_name, (parent_class,), namespaces)
return cls
return cls
85 changes: 48 additions & 37 deletions plugins/framework/tests/test_model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
combine_triggers,
)

from .model_patcher_test_utils import create_module_class
from .model_patcher_test_utils import create_module_class, manipulate_test_module_fixures
from .model_patcher_fixtures import module1
from .model_patcher_fixtures import module2
from fms_acceleration.utils.test_utils import instantiate_model_patcher
Expand Down Expand Up @@ -300,16 +300,20 @@ def test_patch_target_module_replaces_module_or_function_correctly():
"""
Test patching of standalone file functions
Fixture Structure:
model_patcher_fixtures
__init__.py
module1
__init__.py
module1_1.py
module3
__init__.py
module3_1.py
module2.py
Fixtures Class Structure
model_patcher_fixtures:
- module1:
- module1_1:
- Module2Class:
- attribute: Module2Class
- mod_1_function
- module3:
- module3_1
- Module3Class:
- attribute: mod_1_function
- module2:
- Module2Class:
"""

PatchedModuleClass = create_module_class(
Expand All @@ -322,43 +326,50 @@ def patched_mod_function():
# S1 - module1_1 has function mod_1_function
# 1. Replace module1_1.mod_1_function with new function
# 2. Ensure patch_target_module replaces with a new function
patch_target_module(
"tests.model_patcher_fixtures.module1.module1_1.mod_1_function",
patched_mod_function,
"tests.model_patcher_fixtures.module1",
force_target_module_check=False,
)
assert module1.mod_1_function() == patched_mod_function()
with manipulate_test_module_fixures():
patch_target_module(
"tests.model_patcher_fixtures.module1.module1_1.mod_1_function",
patched_mod_function,
"tests.model_patcher_fixtures.module1",
)
assert module1.mod_1_function() == "patched_mod_function"

# test patches are reset outside the context manager
assert module1.mod_1_function() == "unpatched_mod_function"

# S2 - module1_1.Module1Class has an attribute module2.Module2Class
# 1. Replace Module2Class with new class and reload module1_1
# 2. Ensure patch_target_module replaces the attribute with a new attr class
patch_target_module(
"tests.model_patcher_fixtures.module2.Module2Class",
PatchedModuleClass,
"tests.model_patcher_fixtures.module1.module1_1"
)
assert isinstance(module1.Module1Class().attribute, PatchedModuleClass)

# tests.model_patcher_fixtures.module1 won't work as child module1_1 has been cached
with manipulate_test_module_fixures():
patch_target_module(
"tests.model_patcher_fixtures.module2.Module2Class",
PatchedModuleClass,
"tests.model_patcher_fixtures.module1.module1_1"
)
assert isinstance(module1.Module1Class().attribute, PatchedModuleClass)

# S3 - module1.module3.module3_1 is a submodule of module1
# 1. Replace module1.module3.module3_1.Module3Class with a new class
# 2. Show that reloading module1 does not replace the class
patch_target_module(
"tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class",
PatchedModuleClass,
"tests.model_patcher_fixtures.module1",
force_target_module_check=False,
)
assert not isinstance(module1.module3.Module3Class(), PatchedModuleClass)
with manipulate_test_module_fixures():
patch_target_module(
"tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class",
PatchedModuleClass,
"tests.model_patcher_fixtures.module1",
)
assert not isinstance(module1.module3.Module3Class(), PatchedModuleClass)

# S4 - module1.module3 submodule has a dependency on parent module1.mod_1_function
# 1. Replace the module1.module1_1.mod_1_function with a new function
# 2. Ensure the function is replaced after reloading module1.module3 submodule
patch_target_module(
"tests.model_patcher_fixtures.module1.module1_1.mod_1_function",
patched_mod_function,
"tests.model_patcher_fixtures.module1.module3.module3_1",
)
assert module1.module3.Module3Class().attribute() == patched_mod_function()
with manipulate_test_module_fixures():
patch_target_module(
"tests.model_patcher_fixtures.module1.module1_1.mod_1_function",
patched_mod_function,
"tests.model_patcher_fixtures.module1.module3.module3_1",
)
assert module1.module3.Module3Class().attribute() == patched_mod_function()


0 comments on commit 2683d9e

Please sign in to comment.