Skip to content

Commit

Permalink
modelpatcher unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
achew010 committed Jul 25, 2024
1 parent 9438aba commit 8c825d9
Show file tree
Hide file tree
Showing 8 changed files with 309 additions and 13 deletions.
40 changes: 36 additions & 4 deletions plugins/framework/src/fms_acceleration/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import importlib
import inspect
import warnings

# Third Party
import pandas as pd
Expand Down Expand Up @@ -268,16 +269,27 @@ def register(rule: ModelPatcherRule):

@staticmethod
def did_rule_trigger(module: torch.nn.Module, module_name: str):
    """Return the (name, rule) of the first registered rule triggered on `module`.

    The first rule whose trigger fires becomes the active rule; any later
    simple-forward rule that also triggers on the same module is ignored
    with a warning, so a module is forward-patched at most once.

    Args:
        module: the candidate module being checked against all rules.
        module_name: dotted name of `module` within the parent model.

    Returns:
        Tuple of (rule name, rule), or (None, None) if nothing triggered.
    """
    active_rule_name, active_rule = None, None
    for name, rule in ModelPatcher.rules.items():

        # rules without a trigger are matched elsewhere (e.g. import/reload)
        if rule.trigger is None:
            continue

        if rule.trigger.is_triggered(module, module_name):
            if active_rule is None:
                # no active rule yet, so assign the current rule as active
                active_rule_name = name
                active_rule = rule
            elif rule.forward is not None:
                # an active rule already exists: subsequent compatible simple
                # forward rules are ignored. forward_builders are handled the
                # same way once decomposed into simple forward rules.
                warnings.warn(
                    f"rule {rule.rule_id} is ignored on {module_name} "
                    f"as an earlier rule {active_rule.rule_id} has been applied"
                )

    return active_rule_name, active_rule

@staticmethod
def _import_and_reload(model: torch.nn.Module):
Expand Down Expand Up @@ -326,7 +338,18 @@ def _import_and_reload(model: torch.nn.Module):
elif _target.startswith(module_path):
_no_reload.append(rule)

assert len(_with_reload) <= 1, "can only have at most one rule with reload"
# If there are multiple reload targets, ensure that their paths do not
# conflict, as reloading the same module tree twice might reset patches
if len(_with_reload) > 1:
    # sort by ascending target-path length so a prefix path precedes
    # any longer path it could conflict with
    _with_reload = sorted(
        _with_reload, key=lambda _rule: len(_rule.import_and_maybe_reload[2])
    )
    # compare each rule only against the rules *after* it; the previous
    # inner loop over `_with_reload[1:]` also compared a rule against
    # itself, making the assertion fire for any two rules, even ones
    # with completely disjoint reload paths
    for _idx, rule_s in enumerate(_with_reload):
        for rule_l in _with_reload[_idx + 1:]:
            # if target path in rule_s is a prefix of rule_l, raise an error
            _, _, _path_s = rule_s.import_and_maybe_reload
            _, _, _path_l = rule_l.import_and_maybe_reload
            assert not _path_l.startswith(_path_s), (
                f"Attempting to reload same path `{_path_s}` "
                f"multiple times in {rule_s.rule_id} and {rule_l.rule_id}"
            )

# handle those with reload first
for rule in _with_reload + _no_reload:
Expand Down Expand Up @@ -444,6 +467,15 @@ def _patch_forwards(
def patch(model: torch.nn.Module, **kwargs):
# NOTE: for a set of rules, this patch function should be called
# only once. We do not have any checks for this at the moment

# 1. Iterate over all ModelPatcher rules
# 2. For import_and_maybe_reload rules, an assertion
#    is thrown if one rule's reload target path is a prefix of
#    another's, since reloading the same path twice could reset
#    earlier patches
# 3. For _patch_forwards, if multiple rules trigger on the same
#    module, only the first is applied; the rest are ignored with
#    a warning so they cannot patch over a previous forward patch

try:
ModelPatcher._import_and_reload(model.get_base_model())
except AttributeError:
Expand Down
6 changes: 2 additions & 4 deletions plugins/framework/src/fms_acceleration/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,5 @@ def instantiate_model_patcher():
from fms_acceleration.model_patcher import ModelPatcher
old_registrations = ModelPatcher.rules
ModelPatcher.rules = {}
try:
yield
finally:
ModelPatcher.rules = old_registrations
yield
ModelPatcher.rules = old_registrations
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .module4_1 import mod_4_function
from .module5 import Module5Class, mod_5_function
import torch

class Module4Class(torch.nn.Module):
    """Test fixture: a torch module whose `attribute` composes a Module5Class."""

    def __init__(self) -> None:
        super().__init__()
        # composed instance; model-patcher tests replace/inspect this attribute
        self.attribute = Module5Class()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def mod_4_function():
    """Fixture function; reports itself as unpatched until a patch replaces it."""
    marker = "unpatched_mod_function"
    return marker
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .module5_1 import Module5Class, mod_5_function
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import torch

class Module5Class(torch.nn.Module):
    """Test fixture: a minimal torch module with no parameters of its own."""

    def __init__(self, *args, **kwargs) -> None:
        # forward any construction arguments straight to torch.nn.Module
        super().__init__(*args, **kwargs)

def mod_5_function():
    """Fixture function; reports itself as unpatched until a patch replaces it."""
    result = "unpatched_mod_function"
    return result
12 changes: 7 additions & 5 deletions plugins/framework/tests/test_model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@
)

from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures
from .model_patcher_fixtures import module1
from .model_patcher_fixtures import module2
from fms_acceleration.utils.test_utils import instantiate_model_patcher
from .model_patcher_fixtures import module1, module2, module4

MOD_CLS_A = create_module_class("MOD_CLS_A")
MOD_SUBCLS_A = create_module_class("MOD_SUBCLS_A", parent_class=MOD_CLS_A)
Expand Down Expand Up @@ -279,6 +277,12 @@ def test_patch_target_module_replaces_module_or_function_correctly():
- attribute: mod_1_function
- module2:
- Module2Class:
- module4:
- Module4Class:
- attribute: mod_1_function
"""

PatchedModuleClass = create_module_class(
Expand Down Expand Up @@ -378,5 +382,3 @@ def patched_mod_function():
"tests.model_patcher_fixtures.module1.module3.module3_1",
)
assert module1.module3.Module3Class().attribute() == "patched_mod_function"


245 changes: 245 additions & 0 deletions plugins/framework/tests/test_model_patcher2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# Third Party
import pytest # pylint: disable=(import-error
import torch

# First Party
from fms_acceleration.model_patcher import (
ModelPatcher,
ModelPatcherRule,
ModelPatcherTrigger,
patch_target_module,
ModelPatcherTriggerType,
ModelPatcherHistory,
combine_functions,
combine_triggers,
)

from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures
from .model_patcher_fixtures import module1, module2, module4
from fms_acceleration.utils.test_utils import instantiate_model_patcher

from .test_model_patcher import DUMMY_RULE_ID

# Test that a simple forward rule patches a submodule's forward method
def test_simple_forward_rule_with_mp_replaces_old_forward(): # pylint: disable=redefined-outer-name
    """
    Verify that registering a simple forward rule and patching the model
    replaces the triggered submodule's forward with the rule's forward.

    model_patcher_fixtures:
    - module1:
        - module1_1:
            - Module2Class:
                - attribute: Module2Class
            - mod_1_function
        - module3:
            - module3_1
                - Module3Class:
                    - attribute: mod_1_function
    - module2:
        - Module2Class:
    - module4:
        - Module4Class(torch.nn.Module):
            - attribute: mod_1_function
    """

    def patched_forward_function(X):
        # NOTE(review): X is presumably the patched module instance (`self`)
        # when installed as forward -- confirm against ModelPatcher internals
        return "patched_forward_function"

    # 1. Create an instance of Module4Class as model
    # 2. Add a submodule to Module4Class
    # 3. Create and register rule to patch forward of submodule class
    # 4. Patch model
    # 5. Ensure that model's submodule forward is replaced
    with isolate_test_module_fixtures():
        with instantiate_model_patcher():
            model = module4.Module4Class()
            SubModule1 = create_module_class(
                "SubModule1",
                namespaces={"forward": lambda self: "unpatched_forward_function"}
            )
            model.add_module("submodule_1", SubModule1())
            rule = ModelPatcherRule(
                rule_id=DUMMY_RULE_ID,
                trigger=ModelPatcherTrigger(check=SubModule1),
                forward=patched_forward_function,
            )
            ModelPatcher.register(rule)
            ModelPatcher.patch(model)

            assert model.submodule_1.forward() == "patched_forward_function"

def test_import_and_maybe_reload_rule_with_mp_replaces_old_attribute():
    """Verify an import_and_maybe_reload rule swaps a class and the reload propagates."""
    # 1. Register rule replacing tests.model_patcher_fixtures.module4.module5.Module5Class
    #    with PatchedModuleClass; reload target is tests.model_patcher_fixtures.module4
    # 2. Patch module4.Module4Class with ModelPatcher
    # 3. check patched module exists in module4.Module4Class.attribute
    PatchedModuleClass = create_module_class(
        "PatchedModClass",
    )


    with isolate_test_module_fixtures():
        with instantiate_model_patcher():
            model = module4.Module4Class()
            ModelPatcher.register(
                ModelPatcherRule(
                    rule_id=DUMMY_RULE_ID,
                    import_and_maybe_reload=(
                        "tests.model_patcher_fixtures.module4.module5.Module5Class",
                        PatchedModuleClass,
                        "tests.model_patcher_fixtures.module4",
                    ),
                )
            )
            ModelPatcher.patch(model)
            # the reload of module4 re-executes its import of Module5Class,
            # so a freshly constructed Module4Class composes the patched class
            assert isinstance(module4.Module4Class().attribute, PatchedModuleClass)

# TODO forward builder test


def test_mp_throws_error_with_multiple_reloads_on_same_target():
    """
    Simulate a case where two rules attempt to reload on the same target prefix.
    Example:
    - Rule 1 target path 1: x.y.z
    - Rule 2 target path 2: x.y
    Reloading x.y might reverse the patch applied by Rule 1, so this
    situation needs to be caught.

    model_patcher_fixtures:
    - module1:
        - module1_1:
            - Module2Class:
                - attribute: Module2Class
            - mod_1_function
        - module3:
            - module3_1
                - Module3Class:
                    - attribute: mod_1_function
    - module2:
        - Module2Class:
    - module4:
        - Module4Class(torch.nn.Module):
            - attribute: mod_1_function
        - module4_1
            - mod_4_function
        - module5:
            - module5_1
                - Module5Class
                - module_5_function
    """

    PatchedModuleClass = create_module_class(
        "PatchedModuleClass",
    )

    def patched_mod_function():
        return "patched_mod_function"

    # Demonstrate that the 2nd patch overwrites the 1st patch if the reload
    # module paths are the same
    with isolate_test_module_fixtures():
        # 1st patch on a function
        patch_target_module(
            "tests.model_patcher_fixtures.module4.module5.module5_1.mod_5_function",
            patched_mod_function,
            "tests.model_patcher_fixtures.module4.module5",
        )

        assert module4.module5.mod_5_function() == "patched_mod_function"

        # 2nd patch on a class that has a target path that reloads module5 as well
        patch_target_module(
            "tests.model_patcher_fixtures.module4.module5.module5_1.Module5Class",
            PatchedModuleClass,
            "tests.model_patcher_fixtures.module4.module5"
        )

        assert isinstance(module4.module5.Module5Class(), PatchedModuleClass)
        # the reload of module4.module5 above restored mod_5_function, undoing
        # the 1st patch -- this is exactly the hazard the assertion guards against
        assert module4.module5.mod_5_function() == "unpatched_mod_function"

    # Ensure that an assertion is raised if target paths share the same root path
    with pytest.raises(
        AssertionError,
    ):
        with isolate_test_module_fixtures():
            with instantiate_model_patcher():
                # 1. Initialize a model with module path tests.model_patcher_fixtures.module4
                model = module4.Module4Class()

                # 2. Simulate patching a function in module4.module5.module5_1
                ModelPatcher.register(
                    ModelPatcherRule(
                        rule_id=f"{DUMMY_RULE_ID}.2",
                        import_and_maybe_reload=(
                            "tests.model_patcher_fixtures.module4.module5.module5_1.mod_5_function",
                            patched_mod_function,
                            "tests.model_patcher_fixtures.module4.module5.module5_1",
                        ),
                    )
                )

                # 3. Simulate patching a class in module4.module5.module5_1
                ModelPatcher.register(
                    ModelPatcherRule(
                        rule_id=f"{DUMMY_RULE_ID}.1",
                        import_and_maybe_reload=(
                            "tests.model_patcher_fixtures.module4.module5.module5_1.Module5Class",
                            PatchedModuleClass,
                            "tests.model_patcher_fixtures.module4",
                        ),
                    )
                )

                # while ModelPatcher is patching different objects, repeated
                # reloads on the same path are risky: since module4 is a parent
                # of module5, reloading module4 again might affect the previous
                # patch. To prevent this an exception is thrown if the shorter
                # target path is a prefix of the longer target path
                ModelPatcher.patch(model)


def test_mp_throws_warning_with_multiple_patches():
    """
    Ensure that for each module only one forward patch is applied to it.

    When multiple forward patch rules trigger on the same module, only the
    1st forward patch rule is applied; the others are ignored and a warning
    is raised.

    A list of new rules generated by `forward_builder` is handled the same
    way, since it decomposes into multiple single forward patch rules
    downstream.
    """
    with pytest.warns(
        UserWarning,
    ):
        with isolate_test_module_fixtures():
            with instantiate_model_patcher():
                # 1. Create a model
                # 2. Create a submodule to patch on
                # 3. Create 1st rule to patch submodule forward function
                # 4. Create 2nd rule to patch submodule forward function again
                # 5. Throws warning that any subsequent forward patches after
                #    the 1st patch are ignored

                model = module4.Module4Class()
                SubModule1 = create_module_class(
                    "SubModule1",
                    namespaces={"forward": lambda self: "unpatched_forward_function"}
                )
                model.add_module("submodule_1", SubModule1())

                ModelPatcher.register(
                    ModelPatcherRule(
                        rule_id=DUMMY_RULE_ID+".1",
                        trigger=ModelPatcherTrigger(check=SubModule1),
                        forward=lambda self: "patched_forward_function",
                    )
                )
                ModelPatcher.register(
                    ModelPatcherRule(
                        rule_id=DUMMY_RULE_ID+".2",
                        trigger=ModelPatcherTrigger(check=SubModule1),
                        forward=lambda self: "patched_forward_function_2",
                    )
                )
                ModelPatcher.patch(model)

# TODO test on forward builder cases

0 comments on commit 8c825d9

Please sign in to comment.