Refactored Model Patcher Class #55

Merged: 28 commits, Jul 29, 2024

Changes from 1 commit

Commits (28)
7941ed7
set main to track current plugin versions
achew010 Jul 17, 2024
4b871d0
move model_patcher to framework
achew010 Jul 17, 2024
3bf9a55
replace local patching with model_patcher
achew010 Jul 18, 2024
815b0c8
add additional unit tests
achew010 Jul 18, 2024
7efbfed
remove redundant patch function
achew010 Jul 18, 2024
33258ba
shifted patch summary logging to framework plugin and patch id renames
achew010 Jul 18, 2024
af7009c
modified unit tests from PR comments
achew010 Jul 20, 2024
6b6fca9
incremental refactor of unit tests
achew010 Jul 22, 2024
252a73c
changes to mp trigger unit tests
achew010 Jul 23, 2024
94e217e
additional changes to trigger unit tests
achew010 Jul 23, 2024
a31bf6e
adding MP Rule unit tests
achew010 Jul 23, 2024
2683d9e
add context manager to isolate patching unit tests
achew010 Jul 24, 2024
748595c
some fixes
fabianlim Jul 24, 2024
9438aba
clarified comments
fabianlim Jul 25, 2024
8c825d9
modelpatcher unit tests
achew010 Jul 24, 2024
df95ece
added forward_builder fn unit test
achew010 Jul 25, 2024
e653b80
lint changes
achew010 Jul 25, 2024
e6f2284
more lint changes
achew010 Jul 25, 2024
736e706
file renaming and added license headers on new files
achew010 Jul 26, 2024
7c302ba
added guard to patch model only if model exist in framework plugin ca…
achew010 Jul 26, 2024
cd253b3
replaced buggy partial wrapping on ModelPatcher.patch and set tox env…
achew010 Jul 27, 2024
1d498e0
additional linting
achew010 Jul 28, 2024
a4f8800
shifted patch trigger to main framework class
achew010 Jul 29, 2024
ac31192
additional modifications to foak patch rules
achew010 Jul 29, 2024
8895cad
linting
achew010 Jul 29, 2024
f6848a7
additional changes from comments
achew010 Jul 29, 2024
5e535b2
fixes to mp unit test
achew010 Jul 29, 2024
c204c86
updated with new benchmark results
achew010 Jul 29, 2024
@@ -52,7 +52,7 @@ def build_patch_to_view_tensor_to_parameter_for_fsdp_gptq(
         torch_dtype=torch.int32, # patch it back to
     )

-def load_fsdp_gptq_patch(target_module, torch_dtype):
+def register_tensors_as_parameters_patch_rule(target_module, torch_dtype):
     # Register patch
     ModelPatcher.register(
         ModelPatcherRule(
@@ -195,8 +195,11 @@ def model_loader(self, model_name: str, **kwargs):
             and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
         ):
             # register FSDP patch
-            from .autogptq_utils import load_fsdp_gptq_patch
-            load_fsdp_gptq_patch(target_module = QuantLinear, torch_dtype = torch_dtype)
+            from .autogptq_utils import register_tensors_as_parameters_patch_rule
+            register_tensors_as_parameters_patch_rule(
+                target_module=QuantLinear,
+                torch_dtype=torch_dtype,
+            )

         # replace
         AutoModelForCausalLM.from_config = _old_from_config
12 changes: 6 additions & 6 deletions plugins/framework/src/fms_acceleration/model_patcher.py
@@ -161,11 +161,11 @@ class ModelPatcherRule:
     ] = None

     def __post_init__(self):
-        if (
-            self.forward is not None
-            and self.forward_builder is not None
-            and self.import_and_maybe_reload is not None
-        ):
+        if sum([
+            self.forward is not None,
+            self.forward_builder is not None,
+            self.import_and_maybe_reload is not None,
+        ])>1:
             raise ValueError(
                 f"Rule '{self.rule_id}' must only have only one of forward, "
                 "foward builder, or import_and_maybe_reload, specified."

@@ -305,7 +305,7 @@ def _import_and_reload(model: torch.nn.Module):
         elif _target.startswith(module_path):
             _no_reload.append(rule)

-    assert len(_with_reload) <= 1, "cannot have have at most one rule with reload"
+    assert len(_with_reload) <= 1, "can only have at most one rule with reload"

     # handle those with reload first
     for rule in _with_reload + _no_reload:
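The __post_init__ change above fixes the validation logic: the old and-chain only raised when all three of forward, forward_builder, and import_and_maybe_reload were supplied, whereas the new sum-based check raises as soon as more than one is set. A minimal, self-contained sketch of the same idiom, using a hypothetical dataclass rather than the real ModelPatcherRule:

from dataclasses import dataclass
from typing import Callable, Optional, Tuple

# Hypothetical stand-in for ModelPatcherRule, used only to illustrate the check.
@dataclass
class ExclusiveOptionsRule:
    rule_id: str
    forward: Optional[Callable] = None
    forward_builder: Optional[Callable] = None
    import_and_maybe_reload: Optional[Tuple] = None

    def __post_init__(self):
        # raise if more than one patching mechanism is supplied
        if sum([
            self.forward is not None,
            self.forward_builder is not None,
            self.import_and_maybe_reload is not None,
        ]) > 1:
            raise ValueError(
                f"Rule '{self.rule_id}' must have only one of forward, "
                "forward_builder, or import_and_maybe_reload specified."
            )

# supplying one option is fine; supplying two now raises
ExclusiveOptionsRule("ok", forward=lambda module: module)
try:
    ExclusiveOptionsRule(
        "bad", forward=lambda module: module, forward_builder=lambda module: module
    )
except ValueError as err:
    print(err)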
19 changes: 9 additions & 10 deletions plugins/framework/src/fms_acceleration/utils/test_utils.py
@@ -21,7 +21,6 @@
 from typing import Any, Callable, Dict, List, Set, Tuple, Type

 # Third Party
-from torch.nn import CrossEntropyLoss
 import torch
 import yaml

@@ -182,12 +181,12 @@ def dummy_custom_loader(self, model_name, **kwargs):
     "dummy custom loader returning dummy model"
     return create_noop_model_with_archs(archs=["DummyModel"]) #

-
-class DummyModule(torch.nn.Module):
-    def __init__(self, hidden_size, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.linear = torch.nn.Linear(hidden_size, hidden_size)
-        self.loss_fn = CrossEntropyLoss()
-
-    def forward(self, X):
-        return self.linear(X)
+@contextmanager
+def instantiate_model_patcher():
+    from fms_acceleration.model_patcher import ModelPatcher
+    old_registrations = ModelPatcher.rules
+    ModelPatcher.rules = {}
+    try:
+        yield
+    finally:
+        ModelPatcher.rules = old_registrations
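A hedged usage sketch of the new context manager: rules registered inside the with block go into a fresh, empty ModelPatcher.rules dict, and the previous registrations are restored on exit, so patching unit tests cannot leak state into each other. The import path follows the file location above; the test body is illustrative.

from fms_acceleration.model_patcher import ModelPatcher
from fms_acceleration.utils.test_utils import instantiate_model_patcher

def test_rule_registration_is_isolated():
    previous_rules = ModelPatcher.rules
    with instantiate_model_patcher():
        # the context manager swapped in an empty registry
        assert ModelPatcher.rules == {}
        # a test would call ModelPatcher.register(<some ModelPatcherRule>) here
    # on exit the original registry object is restored
    assert ModelPatcher.rules is previous_rules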
18 changes: 18 additions & 0 deletions plugins/framework/tests/model_patcher_test_utils.py
@@ -0,0 +1,18 @@
+import torch
+
+UNPATCHED_RESPONSE = 0
+PATCHED_RESPONSE = 1
+
+class DummyAttribute(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, *args, **kwargs):
+        return UNPATCHED_RESPONSE
+
+class PatchedAttribute(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, *args, **kwargs):
+        return PATCHED_RESPONSE
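A hedged sketch of how these two fixtures could drive a patching test: a parent module holds a DummyAttribute, a patch swaps it for a PatchedAttribute, and the test watches the forward response flip from UNPATCHED_RESPONSE to PATCHED_RESPONSE. The parent module and the manual swap are illustrative; the real unit tests would perform the swap through ModelPatcher rules.

import torch

from model_patcher_test_utils import (
    DummyAttribute,
    PatchedAttribute,
    PATCHED_RESPONSE,
    UNPATCHED_RESPONSE,
)

class ParentModule(torch.nn.Module):
    # illustrative module exposing an attribute that a patch rule would target
    def __init__(self):
        super().__init__()
        self.attr = DummyAttribute()

    def forward(self, *args, **kwargs):
        return self.attr(*args, **kwargs)

def test_patched_attribute_changes_response():
    model = ParentModule()
    assert model() == UNPATCHED_RESPONSE
    # simulate what an attribute/forward patch would do
    model.attr = PatchedAttribute()
    assert model() == PATCHED_RESPONSE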