Refactored Model Patcher Class #55

Merged: 28 commits, Jul 29, 2024

Changes from 17 commits

Commits (28)
7941ed7 set main to track current plugin versions (achew010, Jul 17, 2024)
4b871d0 move model_patcher to framework (achew010, Jul 17, 2024)
3bf9a55 replace local patching with model_patcher (achew010, Jul 18, 2024)
815b0c8 add additional unit tests (achew010, Jul 18, 2024)
7efbfed remove redundant patch function (achew010, Jul 18, 2024)
33258ba shifted patch summary logging to framework plugin and patch id renames (achew010, Jul 18, 2024)
af7009c modified unit tests from PR comments (achew010, Jul 20, 2024)
6b6fca9 incremental refactor of unit tests (achew010, Jul 22, 2024)
252a73c changes to mp trigger unit tests (achew010, Jul 23, 2024)
94e217e additional changes to trigger unit tests (achew010, Jul 23, 2024)
a31bf6e adding MP Rule unit tests (achew010, Jul 23, 2024)
2683d9e add context manager to isolate patching unit tests (achew010, Jul 24, 2024)
748595c some fixes (fabianlim, Jul 24, 2024)
9438aba clarified comments (fabianlim, Jul 25, 2024)
8c825d9 modelpatcher unit tests (achew010, Jul 24, 2024)
df95ece added forward_builder fn unit test (achew010, Jul 25, 2024)
e653b80 lint changes (achew010, Jul 25, 2024)
e6f2284 more lint changes (achew010, Jul 25, 2024)
736e706 file renaming and added license headers on new files (achew010, Jul 26, 2024)
7c302ba added guard to patch model only if model exist in framework plugin ca… (achew010, Jul 26, 2024)
cd253b3 replaced buggy partial wrapping on ModelPatcher.patch and set tox env… (achew010, Jul 27, 2024)
1d498e0 additional linting (achew010, Jul 28, 2024)
a4f8800 shifted patch trigger to main framework class (achew010, Jul 29, 2024)
ac31192 additional modifications to foak patch rules (achew010, Jul 29, 2024)
8895cad linting (achew010, Jul 29, 2024)
f6848a7 additional changes from comments (achew010, Jul 29, 2024)
5e535b2 fixes to mp unit test (achew010, Jul 29, 2024)
c204c86 updated with new benchmark results (achew010, Jul 29, 2024)
2 changes: 1 addition & 1 deletion plugins/accelerated-peft/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "fms-acceleration-peft"
-version = '0.0.1'
+version = '0.1.0.1.dev'
 description = "FMS Acceleration for PeFT"
 authors = [
     {name = "Fabian Lim", email = "[email protected]"},
plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py
@@ -24,35 +24,45 @@
 from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ
 import torch
 
+from fms_acceleration.model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger
+from functools import partial
 
 # these parameters are to be patched for triton v2
 # consider making a map if patching more kernels
 PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"]
 
 
-# This function may be moved after merging
-# https://github.com/foundation-model-stack/fms-acceleration/pull/25
-def _patch_target_module(
-    to_patch: str,
-    replace_with: Any,
-    target_module: str = None,
+def build_patch_to_view_tensor_to_parameter_for_fsdp_gptq(
+    module,
+    torch_dtype,
 ):
-    to_patch = to_patch.split(".")
-    assert len(to_patch) > 1, "must have an object to patch"
-
-    to_patch, obj_name_to_patch = to_patch[:-1], to_patch[-1]
-    to_patch = ".".join(to_patch)
-    source = importlib.import_module(to_patch)
-    original_obj = getattr(source, obj_name_to_patch)
-    setattr(source, obj_name_to_patch, replace_with)
-
-    if target_module is not None:
-        # reload and this should get the patched object
-        target_module = importlib.import_module(target_module)
-        importlib.reload(target_module)
-
-        # replace it
-        setattr(source, obj_name_to_patch, original_obj)
+    # convert all patched attributes to Parameters of torch_dtype
+    # so FSDP can shard them
+    for attr_name in PATCH_FOR_FSDP_TRITON_V2:
+        attr = getattr(module, attr_name)
+        attr = torch.nn.Parameter(
+            attr.view(torch_dtype), requires_grad=False
+        )
+        setattr(module, attr_name, attr)
+
+    # this patches the forward to convert them back to original
+    # type (i.e. int32) before the function call into the kernels
+    return patch_forward_to_view_attributes_before_call(
+        module.forward,
+        attribute_names=PATCH_FOR_FSDP_TRITON_V2,
+        torch_dtype=torch.int32,  # patch it back to
+    )
+
+def register_tensors_as_parameters_patch_rule(target_module, torch_dtype):
+    # Register patch
+    ModelPatcher.register(
+        ModelPatcherRule(
+            rule_id="autogptq_patch_tensors_as_float_parameters",
+            trigger=ModelPatcherTrigger(check=target_module),
+            forward_builder = build_patch_to_view_tensor_to_parameter_for_fsdp_gptq,
+            forward_builder_args=["torch_dtype"],
+        )
+    )
+    ModelPatcher.patch = partial(ModelPatcher.patch, torch_dtype=torch_dtype)
 
 def make_sure_no_tensor_in_meta_device(
     model,
@@ -124,7 +134,6 @@ def create_new_module_peft(
     # if module cannot be found, return None which results in a raise in the call-stack
     return new_module
 
-
 # consider to move this somewhere more general
 def patch_forward_to_view_attributes_before_call(
     old_forward: Callable,
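The two helpers added above encode one trick: FSDP only shards `nn.Parameter`s, so the packed int32 GPTQ buffers listed in `PATCH_FOR_FSDP_TRITON_V2` are re-viewed as the training dtype and wrapped as non-trainable parameters, while the patched forward views them back to int32 just before the kernel runs. Below is a minimal standalone sketch of that view round-trip; the tensor name and shapes are illustrative, not the plugin code.

```python
import torch

# pretend this is a packed int32 GPTQ buffer such as `qweight`
qweight_int32 = torch.randint(0, 2**31 - 1, (4, 4), dtype=torch.int32)

# re-view the same memory as float16 and wrap it as a non-trainable Parameter,
# which is what makes FSDP willing to shard it
# (needs a torch version that supports dtype views across element sizes)
qweight_param = torch.nn.Parameter(
    qweight_int32.view(torch.float16),  # (4, 4) int32 -> (4, 8) float16, no copy
    requires_grad=False,
)

# before the kernel call, view it back to int32: same storage, identical bits
restored = qweight_param.data.view(torch.int32)
assert restored.shape == (4, 4)
assert torch.equal(restored, qweight_int32)
```

Because `.view(dtype)` reinterprets the same storage, no copy is made and the bit pattern survives the round trip, which is what lets the triton kernels keep working on FSDP-sharded weights.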
plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -24,6 +24,7 @@

 # Third Party
 from fms_acceleration import AccelerationPlugin
+from fms_acceleration.model_patcher import patch_target_module
 from peft import LoraConfig, prepare_model_for_kbit_training
 from peft.tuners.lora.model import LoraModel
 from transformers import AutoModelForCausalLM, TrainingArguments
@@ -81,11 +82,6 @@ def model_loader(self, model_name: str, **kwargs):
             from .gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import (  # pylint: disable=import-outside-toplevel,import-error
                 QuantLinear,
             )
-            # Local
-            from .autogptq_utils import (  # pylint: disable=import-outside-toplevel
-                PATCH_FOR_FSDP_TRITON_V2,
-                patch_forward_to_view_attributes_before_call,
-            )
 
             # Currently we allow only a quantized checkpoint to be loaded, we do not
             # implement the quantization process here.
@@ -143,14 +139,11 @@ def model_loader(self, model_name: str, **kwargs):
             kwargs["low_cpu_mem_usage"] = True
             if self.use_external_lib:
                 # Local
-                from .autogptq_utils import (  # pylint: disable=import-outside-toplevel
-                    _patch_target_module,
-                    make_sure_no_tensor_in_meta_device,
-                )
+                from .autogptq_utils import make_sure_no_tensor_in_meta_device  # pylint: disable=import-outside-toplevel
 
                 # We patch `make_sure_no_tensor_in_meta_device`
                 # from autogptq to avoid errors on models without bias
-                _patch_target_module(
+                patch_target_module(
                     to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
                     replace_with=make_sure_no_tensor_in_meta_device,
                     target_module="auto_gptq.modeling._base",
@@ -201,31 +194,15 @@
                 world_size > 1
                 and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
             ):
+                # register FSDP patch
+                from .autogptq_utils import register_tensors_as_parameters_patch_rule
+                register_tensors_as_parameters_patch_rule(
+                    target_module=QuantLinear,
+                    torch_dtype=torch_dtype,
+                )
 
-                # patch all the QuantLinear base layers
-                for mod in model.modules():
-                    if isinstance(mod, QuantLinear):
-
-                        # convert all patched attributes to Parameters of torch_dtype
-                        # so FSDP can shard them
-                        for attr_name in PATCH_FOR_FSDP_TRITON_V2:
-                            attr = getattr(mod, attr_name)
-                            attr = torch.nn.Parameter(
-                                attr.view(torch_dtype), requires_grad=False
-                            )
-                            setattr(mod, attr_name, attr)
-
-                        # this patches the forward to convert them back to original
-                        # type (i.e. int32) before the function call into the kernels
-                        _forward = patch_forward_to_view_attributes_before_call(
-                            mod.forward,
-                            attribute_names=PATCH_FOR_FSDP_TRITON_V2,
-                            torch_dtype=torch.int32,  # patch it back to
-                        )
-                        mod.forward = MethodType(_forward, mod)
-
-        # replace
-        AutoModelForCausalLM.from_config = _old_from_config
+        # replace
+        AutoModelForCausalLM.from_config = _old_from_config
 
         # AutoGPTQ does not set the torch_dtype of the model carefully
         model.config.torch_dtype = torch_dtype
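The plugin now calls the framework-level `patch_target_module` instead of its local `_patch_target_module`. Judging from the removed helper, the call boils down to importing the module that owns the dotted target, rebinding that attribute to the replacement, and optionally reloading a dependent module so it picks up the change. A simplified standalone sketch of the rebinding step follows; the helper name here is made up for the sketch, and the real `patch_target_module` may differ in details such as the `target_module` reload.

```python
import importlib
from typing import Any

def rebind_dotted_target(to_patch: str, replace_with: Any):
    # "pkg.mod.obj" -> import "pkg.mod" and rebind its attribute "obj"
    module_path, _, obj_name = to_patch.rpartition(".")
    assert module_path, "must have an object to patch"
    module = importlib.import_module(module_path)
    setattr(module, obj_name, replace_with)

# illustration only: stub out os.path.exists for this process, then restore it
import os.path
_original_exists = os.path.exists
rebind_dotted_target("os.path.exists", lambda p: True)
assert os.path.exists("surely/not/a/real/path")
os.path.exists = _original_exists
```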
2 changes: 1 addition & 1 deletion plugins/framework/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "fms-acceleration"
-version = '0.1.1.dev'
+version = '0.1.2.dev'
 description = "FMS Acceleration Plugin Framework"
 authors = [
     {name = "Fabian Lim", email = "[email protected]"},
32 changes: 31 additions & 1 deletion plugins/framework/src/fms_acceleration/framework_plugin.py
@@ -14,16 +14,38 @@

 # Standard
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple, Callable
 import importlib
 import sys
 
 # Third Party
 from accelerate import Accelerator
 from peft import LoraConfig
+from transformers.utils import logging
 from transformers import TrainingArguments
 import torch
 
+# want to use the transformers logger, but a bit of pain
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+logger.setLevel(logging._get_default_logging_level())
+logger.addHandler(logging._default_handler)
+
+def log_patch_summary(
+    logging_func: Callable = None,
+):
+    if logging_func is None:
+        logging_func = print
+
+    # this is a guarded import, because the model rule registration
+    # does not need to be loaded unless patch_model is required
+    # Local
+    from fms_acceleration.model_patcher import (  # pylint: disable=import-outside-toplevel
+        patch_model_summary,
+    )
+
+    for line in patch_model_summary().split("\n"):
+        logging_func(line)
+
 
 @dataclass
 class PluginRegistration:
@@ -146,6 +168,14 @@ def augmentation(
     def get_callbacks_and_ready_for_train(
         self, model: torch.nn.Module = None, accelerator: Accelerator = None
     ):
+        # Finally apply all registered patches to the model
+        from .model_patcher import ModelPatcher  # pylint: disable=import-outside-toplevel
+        ModelPatcher.patch(model)
+
+        # if patching is done, print patch summary to logger
+        if len(ModelPatcher.history)>0:
+            log_patch_summary(logging_func=logger.info)
+
         return []
 
     def _check_config_and_maybe_check_values(
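Together with the AutoGPTQ changes above, the intended flow is: a plugin registers `ModelPatcherRule`s while loading or augmenting the model, and the base plugin class in this file applies them once in `get_callbacks_and_ready_for_train`, then logs the patch summary. Below is a hedged end-to-end sketch that uses only names appearing in this PR; the toy rule, builder, and model are illustrative, and the `forward_builder` contract (take the module, return an unbound forward) is assumed from the AutoGPTQ rule above.

```python
import torch
from fms_acceleration.model_patcher import (
    ModelPatcher,
    ModelPatcherRule,
    ModelPatcherTrigger,
    patch_model_summary,
)

def passthrough_forward_builder(module):
    # assumed contract (from the autogptq rule above): receive the module,
    # return an unbound forward that ModelPatcher installs on it
    original_forward = module.forward
    def forward(self, *args, **kwargs):
        return original_forward(*args, **kwargs)  # no-op wrap, for illustration
    return forward

# 1. a plugin registers its rule (normally inside model_loader / augmentation)
ModelPatcher.register(
    ModelPatcherRule(
        rule_id="toy_passthrough_on_linear",
        trigger=ModelPatcherTrigger(check=torch.nn.Linear),
        forward_builder=passthrough_forward_builder,
    )
)

# 2. the framework applies all registered rules right before training
model = torch.nn.Sequential(torch.nn.Linear(4, 4))
ModelPatcher.patch(model)

# 3. and, if anything was patched, logs the patch summary
if len(ModelPatcher.history) > 0:
    for line in patch_model_summary().split("\n"):
        print(line)
```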