Commit

shifted patch trigger to main framework class
achew010 committed Jul 29, 2024
1 parent 1d498e0 · commit a4f8800
Showing 6 changed files with 247 additions and 295 deletions.
@@ -60,7 +60,6 @@ def register_tensors_as_parameters_patch_rule(target_module, torch_dtype):
forward_builder = partial(
build_patch_to_view_tensor_to_parameter_for_fsdp_gptq, torch_dtype=torch_dtype
),
forward_builder_args=["torch_dtype"],
)
)

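Note on the hunk above: the separate `forward_builder_args` list is dropped because `torch_dtype` is now bound directly into the builder with `functools.partial`. Below is a minimal sketch of that pattern; the rule id is illustrative, the surrounding `ModelPatcher.register(...)` call is inferred from the nested closing parentheses, and `build_patch_to_view_tensor_to_parameter_for_fsdp_gptq`, `target_module`, and `torch_dtype` come from parts of the file not shown in this excerpt.

```python
# Sketch only: rule_id is illustrative; the builder import is elided in the excerpt above.
from functools import partial

from fms_acceleration.model_patcher import (
    ModelPatcher,
    ModelPatcherRule,
    ModelPatcherTrigger,
)

def register_tensors_as_parameters_patch_rule(target_module, torch_dtype):
    ModelPatcher.register(
        ModelPatcherRule(
            rule_id="autogptq-view-tensors-as-parameters",        # illustrative id
            trigger=ModelPatcherTrigger(check=target_module),     # fire on the quantized module
            # torch_dtype is bound up front, so no forward_builder_args entry is needed
            forward_builder=partial(
                build_patch_to_view_tensor_to_parameter_for_fsdp_gptq,
                torch_dtype=torch_dtype,
            ),
        )
    )
```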
25 changes: 25 additions & 0 deletions plugins/framework/src/fms_acceleration/framework.py
@@ -38,6 +38,22 @@
logger.setLevel(logging._get_default_logging_level())
logger.addHandler(logging._default_handler)

def log_patch_summary(
logging_func: Callable = None,
):
if logging_func is None:
logging_func = print

# this is a guarded import, because the model rule registration
# does not need to be loaded unless patch_model is required
# Local
from .model_patcher import ( # pylint: disable=import-outside-toplevel
patch_model_summary,
)

for line in patch_model_summary().split("\n"):
logging_func(line)


def check_plugin_packages(plugin: AccelerationPlugin):
if plugin.require_packages is None:
@@ -215,6 +231,15 @@ def get_callbacks_and_ready_for_train(
logging_func=logger.info,
)

from .model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel
if model is not None:
# Finally apply all registered patches to the model
ModelPatcher.patch(model)

# if patching is done, print patch summary to logger
if len(ModelPatcher.history) > 0:
log_patch_summary(logging_func=logger.info)

cbks = []
for _, plugin in self.active_plugins:
cbks.extend(plugin.get_callbacks_and_ready_for_train(model, accelerator))
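For orientation, the two framework.py hunks above amount to the following flow: after the plugins have registered their rules, the framework applies them once and logs a summary only if at least one patch landed. This is a condensed sketch; in the commit the logic sits inside `get_callbacks_and_ready_for_train`, not a standalone helper.

```python
# Condensed sketch of the relocated patch trigger; apply_registered_patches is a
# hypothetical helper name, the real code runs inside the framework class.
from fms_acceleration.model_patcher import ModelPatcher, patch_model_summary

def apply_registered_patches(model, logging_func=print):
    if model is not None:
        # apply every rule the active plugins registered during model load / augmentation
        ModelPatcher.patch(model)
    # only emit the summary if something was actually patched
    if len(ModelPatcher.history) > 0:
        for line in patch_model_summary().split("\n"):
            logging_func(line)
```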
32 changes: 0 additions & 32 deletions plugins/framework/src/fms_acceleration/framework_plugin.py
@@ -21,32 +21,9 @@
# Third Party
from accelerate import Accelerator
from peft import LoraConfig
from transformers.utils import logging
from transformers import TrainingArguments
import torch

# want to use the transformers logger, but a bit of pain
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
logger.setLevel(logging._get_default_logging_level())
logger.addHandler(logging._default_handler)

def log_patch_summary(
logging_func: Callable = None,
):
if logging_func is None:
logging_func = print

# this is a guarded import, because the model rule registration
# does not need to be loaded unless patch_model is required
# Local
from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel
patch_model_summary,
)

for line in patch_model_summary().split("\n"):
logging_func(line)


@dataclass
class PluginRegistration:
plugin: "AccelerationPlugin"
@@ -168,15 +145,6 @@ def augmentation(
def get_callbacks_and_ready_for_train(
self, model: torch.nn.Module = None, accelerator: Accelerator = None
):
from .model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel
if model is not None:
# Finally apply all registered patches to the model
ModelPatcher.patch(model)

# if patching is done, print patch summary to logger
if len(ModelPatcher.history) > 0:
log_patch_summary(logging_func=logger.info)

return []

def _check_config_and_maybe_check_values(
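With the removal above, a plugin no longer patches the model itself: it only registers rules, and the base `get_callbacks_and_ready_for_train` simply returns an empty callback list. A hypothetical minimal plugin under that contract is sketched below; the class name, rule id, and forward function are illustrative and not part of this commit.

```python
# Hypothetical minimal plugin after this commit: it only *registers* patch rules;
# applying them is now the framework's job.
import torch
from fms_acceleration.framework_plugin import AccelerationPlugin
from fms_acceleration.model_patcher import (
    ModelPatcher,
    ModelPatcherRule,
    ModelPatcherTrigger,
)

def example_forward(self, hidden_states):
    # stand-in replacement forward; a real plugin would supply a fused kernel here
    return hidden_states

class ExamplePlugin(AccelerationPlugin):
    def augmentation(self, model, train_args, modifiable_args):
        ModelPatcher.register(
            ModelPatcherRule(
                rule_id="example-rule",
                trigger=ModelPatcherTrigger(check=torch.nn.LayerNorm),  # illustrative trigger
                forward=example_forward,
            )
        )
        return model, modifiable_args

    # no get_callbacks_and_ready_for_train override needed: the base class version
    # now just returns [], and the framework calls ModelPatcher.patch(model) itself.
```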
185 changes: 85 additions & 100 deletions plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
@@ -34,114 +34,99 @@
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops

# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
# - check on the tensors required for fast_rms_layernorm
RULE_LLAMA_RMS = ModelPatcherRule(
rule_id="llama-rms",
trigger=ModelPatcherTrigger(check=LlamaRMSNorm),
forward=fast_rms_layernorm,
)

# TODO: have a generic version of this rule
# - do regex on Attention class name
# - have a set of qkv / o module names and check on that
RULE_LLAMA_QKVO = ModelPatcherRule(
rule_id="llama-qkvo",
trigger=combine_triggers(
ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=LlamaAttention,
submodule_names=["q_proj", "k_proj", "v_proj"],
)
),
ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=LlamaAttention,
submodule_names=["o_proj"],
)
),
logic="OR",
),
forward_builder=combine_functions(
partial(
build_lora_fused_ops,
submodule_names=["q_proj", "k_proj", "v_proj"],
fused_op=KEY_QKV,
),
partial(
build_lora_fused_ops,
submodule_names=["o_proj"],
fused_op=KEY_O,
),
logic="APPEND",
),
forward_builder_args=["base_type"],
)

RULE_LLAMA_MLP = ModelPatcherRule(
rule_id="llama-mlp",
trigger=ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=LlamaMLP,
submodule_names=["up_proj", "down_proj", "gate_proj"],
)
),
forward_builder=partial(
build_lora_fused_ops,
submodule_names=["up_proj", "down_proj", "gate_proj"],
fused_op=KEY_MLP,
),
forward_builder_args=["base_type"],
)

# TODO: have a generic version of this rule
# - get the module_name and reload on that
RULE_LLAMA_CE = ModelPatcherRule(
rule_id="llama-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.llama.modeling_llama",
),
)

# TODO: have a generic version of this rule
# - get the module name
# - check if "apply_rotary_pos_emb" exists
# - patch
RULE_LLAMA_ROPE = ModelPatcherRule(
rule_id="llama-rope",
import_and_maybe_reload=(
"transformers.models.llama.modeling_llama.apply_rotary_pos_emb",
fast_rope_embedding,
None,
),
)

LLAMA_MP_RULES = [
RULE_LLAMA_RMS,
RULE_LLAMA_QKVO,
RULE_LLAMA_MLP,
RULE_LLAMA_CE,
RULE_LLAMA_ROPE,
]

def get_mp_rules(base_type):
"""
Function to access all patch rules in this module.
If it is a forward_builder rule with `base_type` in
its forward builder argument, wrap the forward_builder
function as a partial function with the base_type argument
"""
LLAMA_MP_RULES = [
# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
# - check on the tensors required for fast_rms_layernorm
ModelPatcherRule(
rule_id="llama-rms",
trigger=ModelPatcherTrigger(check=LlamaRMSNorm),
forward=fast_rms_layernorm,
),
# TODO: have a generic version of this rule
# - do regex on Attention class name
# - have a set of qkv / o module names and check on that
ModelPatcherRule(
rule_id="llama-qkvo",
trigger=combine_triggers(
ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=LlamaAttention,
submodule_names=["q_proj", "k_proj", "v_proj"],
)
),
ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=LlamaAttention,
submodule_names=["o_proj"],
)
),
logic="OR",
),
forward_builder=combine_functions(
partial(
build_lora_fused_ops,
submodule_names=["q_proj", "k_proj", "v_proj"],
fused_op=KEY_QKV,
),
partial(
build_lora_fused_ops,
submodule_names=["o_proj"],
fused_op=KEY_O,
),
logic="APPEND",
),
),
ModelPatcherRule(
rule_id="llama-mlp",
trigger=ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=LlamaMLP,
submodule_names=["up_proj", "down_proj", "gate_proj"],
)
),
forward_builder=partial(
build_lora_fused_ops,
submodule_names=["up_proj", "down_proj", "gate_proj"],
fused_op=KEY_MLP,
),
),
# TODO: have a generic version of this rule
# - get the module_name and reload on that
ModelPatcherRule(
rule_id="llama-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.llama.modeling_llama",
),
),
# TODO: have a generic version of this rule
# - get the module name
# - check if "apply_rotary_pos_emb" exists
# - patch
ModelPatcherRule(
rule_id="llama-rope",
import_and_maybe_reload=(
"transformers.models.llama.modeling_llama.apply_rotary_pos_emb",
fast_rope_embedding,
None,
),
)
]

for rule in LLAMA_MP_RULES:
if (
rule.forward_builder is not None
and "base_type" in rule.forward_builder_args
):
if rule.forward_builder is not None:
rule.forward_builder = partial(
rule.forward_builder,
base_type=base_type,
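Usage-wise, the accessor above is expected to be called by the FOAK plugin, which then registers the returned rules. A hedged sketch, assuming `get_mp_rules` returns the wrapped rule list (its return statement falls outside this excerpt) and using an illustrative `base_type` value:

```python
# Assumed caller; "auto_gptq" is only an example base_type.
from fms_acceleration.model_patcher import ModelPatcher
from fms_acceleration_foak.models.llama import get_mp_rules

for rule in get_mp_rules(base_type="auto_gptq"):
    # forward_builder rules already carry base_type via functools.partial,
    # so registration needs no extra arguments
    ModelPatcher.register(rule)
```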
