From a4f8800e6a78762a854aa6164541ecad8875d1f2 Mon Sep 17 00:00:00 2001 From: 1000850000 user Date: Mon, 29 Jul 2024 02:51:23 +0000 Subject: [PATCH] shifted patch trigger to main framework class --- .../fms_acceleration_peft/autogptq_utils.py | 1 - .../src/fms_acceleration/framework.py | 25 +++ .../src/fms_acceleration/framework_plugin.py | 32 --- .../src/fms_acceleration_foak/models/llama.py | 185 ++++++++---------- .../fms_acceleration_foak/models/mistral.py | 166 +++++++--------- .../fms_acceleration_foak/models/mixtral.py | 133 ++++++------- 6 files changed, 247 insertions(+), 295 deletions(-) diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py index c11640ac..a62d0543 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py @@ -60,7 +60,6 @@ def register_tensors_as_parameters_patch_rule(target_module, torch_dtype): forward_builder = partial( build_patch_to_view_tensor_to_parameter_for_fsdp_gptq, torch_dtype=torch_dtype ), - forward_builder_args=["torch_dtype"], ) ) diff --git a/plugins/framework/src/fms_acceleration/framework.py b/plugins/framework/src/fms_acceleration/framework.py index 62735296..999b0415 100644 --- a/plugins/framework/src/fms_acceleration/framework.py +++ b/plugins/framework/src/fms_acceleration/framework.py @@ -38,6 +38,22 @@ logger.setLevel(logging._get_default_logging_level()) logger.addHandler(logging._default_handler) +def log_patch_summary( + logging_func: Callable = None, +): + if logging_func is None: + logging_func = print + + # this is a guarded import, because the model rule registration + # does not need to be loaded unless patch_model is required + # Local + from .model_patcher import ( # pylint: disable=import-outside-toplevel + patch_model_summary, + ) + + for line in patch_model_summary().split("\n"): + logging_func(line) + def check_plugin_packages(plugin: AccelerationPlugin): if plugin.require_packages is None: @@ -215,6 +231,15 @@ def get_callbacks_and_ready_for_train( logging_func=logger.info, ) + from .model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel + if model is not None: + # Finally apply all registered patches to the model + ModelPatcher.patch(model) + + # if patching is done, print patch summary to logger + if len(ModelPatcher.history) > 0: + log_patch_summary(logging_func=logger.info) + cbks = [] for _, plugin in self.active_plugins: cbks.extend(plugin.get_callbacks_and_ready_for_train(model, accelerator)) diff --git a/plugins/framework/src/fms_acceleration/framework_plugin.py b/plugins/framework/src/fms_acceleration/framework_plugin.py index 49b522ef..7bab757c 100644 --- a/plugins/framework/src/fms_acceleration/framework_plugin.py +++ b/plugins/framework/src/fms_acceleration/framework_plugin.py @@ -21,32 +21,9 @@ # Third Party from accelerate import Accelerator from peft import LoraConfig -from transformers.utils import logging from transformers import TrainingArguments import torch -# want to use the transformers logger, but a bit of pain -logger = logging.get_logger(__name__) # pylint: disable=invalid-name -logger.setLevel(logging._get_default_logging_level()) -logger.addHandler(logging._default_handler) - -def log_patch_summary( - logging_func: Callable = None, -): - if logging_func is None: - logging_func = print - - # this is a guarded import, because the model rule registration - # does not need to be loaded 
unless patch_model is required - # Local - from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel - patch_model_summary, - ) - - for line in patch_model_summary().split("\n"): - logging_func(line) - - @dataclass class PluginRegistration: plugin: "AccelerationPlugin" @@ -168,15 +145,6 @@ def augmentation( def get_callbacks_and_ready_for_train( self, model: torch.nn.Module = None, accelerator: Accelerator = None ): - from .model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel - if model is not None: - # Finally apply all registered patches to the model - ModelPatcher.patch(model) - - # if patching is done, print patch summary to logger - if len(ModelPatcher.history) > 0: - log_patch_summary(logging_func=logger.info) - return [] def _check_config_and_maybe_check_values( diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py index 233dd4de..165f81e5 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py @@ -34,102 +34,6 @@ from ..kernels.unsloth.rope_embedding import fast_rope_embedding from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops -# TODO: have a generic version of this rule -# - do regex on RMSNorm class name -# - check on the tensors required for fast_rms_layernorm -RULE_LLAMA_RMS = ModelPatcherRule( - rule_id="llama-rms", - trigger=ModelPatcherTrigger(check=LlamaRMSNorm), - forward=fast_rms_layernorm, -) - -# TODO: have a generic version of this rule -# - do regex on Attention class name -# - have a set of qkv / o module names and check on that -RULE_LLAMA_QKVO = ModelPatcherRule( - rule_id="llama-qkvo", - trigger=combine_triggers( - ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=LlamaAttention, - submodule_names=["q_proj", "k_proj", "v_proj"], - ) - ), - ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=LlamaAttention, - submodule_names=["o_proj"], - ) - ), - logic="OR", - ), - forward_builder=combine_functions( - partial( - build_lora_fused_ops, - submodule_names=["q_proj", "k_proj", "v_proj"], - fused_op=KEY_QKV, - ), - partial( - build_lora_fused_ops, - submodule_names=["o_proj"], - fused_op=KEY_O, - ), - logic="APPEND", - ), - forward_builder_args=["base_type"], -) - -RULE_LLAMA_MLP = ModelPatcherRule( - rule_id="llama-mlp", - trigger=ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=LlamaMLP, - submodule_names=["up_proj", "down_proj", "gate_proj"], - ) - ), - forward_builder=partial( - build_lora_fused_ops, - submodule_names=["up_proj", "down_proj", "gate_proj"], - fused_op=KEY_MLP, - ), - forward_builder_args=["base_type"], -) - -# TODO: have a generic version of this rule -# - get the module_name and reload on that -RULE_LLAMA_CE = ModelPatcherRule( - rule_id="llama-cross-ent", - import_and_maybe_reload=( - "torch.nn.CrossEntropyLoss", - FastCrossEntropyLoss, - "transformers.models.llama.modeling_llama", - ), -) - -# TODO: have a generic version of this rule -# - get the module name -# - check if "apply_rotary_pos_emb" exists -# - patch -RULE_LLAMA_ROPE = ModelPatcherRule( - rule_id="llama-rope", - import_and_maybe_reload=( - "transformers.models.llama.modeling_llama.apply_rotary_pos_emb", - fast_rope_embedding, - None, - ), -) - -LLAMA_MP_RULES = [ - RULE_LLAMA_RMS, - RULE_LLAMA_QKVO, - RULE_LLAMA_MLP, - RULE_LLAMA_CE, - 
RULE_LLAMA_ROPE, -] - def get_mp_rules(base_type): """ Function to access all patch rules in this module. @@ -137,11 +41,92 @@ def get_mp_rules(base_type): its forward builder argument, wrap the forward_builder function as a partial function with the base_type argument """ + LLAMA_MP_RULES = [ + # TODO: have a generic version of this rule + # - do regex on RMSNorm class name + # - check on the tensors required for fast_rms_layernorm + ModelPatcherRule( + rule_id="llama-rms", + trigger=ModelPatcherTrigger(check=LlamaRMSNorm), + forward=fast_rms_layernorm, + ), + # TODO: have a generic version of this rule + # - do regex on Attention class name + # - have a set of qkv / o module names and check on that + ModelPatcherRule( + rule_id="llama-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=LlamaAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=LlamaAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", + ), + forward_builder=combine_functions( + partial( + build_lora_fused_ops, + submodule_names=["q_proj", "k_proj", "v_proj"], + fused_op=KEY_QKV, + ), + partial( + build_lora_fused_ops, + submodule_names=["o_proj"], + fused_op=KEY_O, + ), + logic="APPEND", + ), + ), + ModelPatcherRule( + rule_id="llama-mlp", + trigger=ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=LlamaMLP, + submodule_names=["up_proj", "down_proj", "gate_proj"], + ) + ), + forward_builder=partial( + build_lora_fused_ops, + submodule_names=["up_proj", "down_proj", "gate_proj"], + fused_op=KEY_MLP, + ), + ), + # TODO: have a generic version of this rule + # - get the module_name and reload on that + ModelPatcherRule( + rule_id="llama-cross-ent", + import_and_maybe_reload=( + "torch.nn.CrossEntropyLoss", + FastCrossEntropyLoss, + "transformers.models.llama.modeling_llama", + ), + ), + # TODO: have a generic version of this rule + # - get the module name + # - check if "apply_rotary_pos_emb" exists + # - patch + ModelPatcherRule( + rule_id="llama-rope", + import_and_maybe_reload=( + "transformers.models.llama.modeling_llama.apply_rotary_pos_emb", + fast_rope_embedding, + None, + ), + ) + ] + for rule in LLAMA_MP_RULES: - if ( - rule.forward_builder is not None - and "base_type" in rule.forward_builder_args - ): + if rule.forward_builder is not None: rule.forward_builder = partial( rule.forward_builder, base_type=base_type, diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py index 300eaffd..a1c9ead9 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py @@ -35,92 +35,6 @@ from ..kernels.unsloth.rope_embedding import fast_rope_embedding from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops -# - do regex on RMSNorm class name -# - check on the tensors required for fast_rms_layernorm -RULE_MISTRAL_RMS = ModelPatcherRule( - rule_id="mistral-rms", - trigger=ModelPatcherTrigger(check=MistralRMSNorm), - forward=fast_rms_layernorm, -) - -RULE_MISTRAL_QKVO = ModelPatcherRule( - rule_id="mistral-qkvo", - trigger=combine_triggers( - ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=MistralAttention, - submodule_names=["q_proj", "k_proj", "v_proj"], - ) - ), - ModelPatcherTrigger( - check=partial( 
- trigger_fused_ops, - attn_cls=MistralAttention, - submodule_names=["o_proj"], - ) - ), - logic="OR", - ), - forward_builder=combine_functions( - partial( - build_lora_fused_ops, - submodule_names=["q_proj", "k_proj", "v_proj"], - fused_op=KEY_QKV, - ), - partial( - build_lora_fused_ops, - submodule_names=["o_proj"], - fused_op=KEY_O, - ), - logic="APPEND", - ), - forward_builder_args=["base_type"], -) - -RULE_MISTRAL_MLP = ModelPatcherRule( - rule_id="mistral-mlp", - trigger=ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=MistralMLP, - submodule_names=["up_proj", "down_proj", "gate_proj"], - ) - ), - forward_builder=partial( - build_lora_fused_ops, - submodule_names=["up_proj", "down_proj", "gate_proj"], - fused_op=KEY_MLP, - ), - forward_builder_args=["base_type"], -) - -RULE_MISTRAL_CE = ModelPatcherRule( - rule_id="mistral-cross-ent", - import_and_maybe_reload=( - "torch.nn.CrossEntropyLoss", - FastCrossEntropyLoss, - "transformers.models.mistral.modeling_mistral", - ), -) - -RULE_MISTRAL_ROPE = ModelPatcherRule( - rule_id="mistral-rope", - import_and_maybe_reload=( - "transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb", - fast_rope_embedding, - None, - ), -) - -MISTRAL_MP_RULES = [ - RULE_MISTRAL_RMS, - RULE_MISTRAL_QKVO, - RULE_MISTRAL_MLP, - RULE_MISTRAL_CE, - RULE_MISTRAL_ROPE, -] - def get_mp_rules(base_type): """ Function to access all patch rules in this module. @@ -128,11 +42,83 @@ def get_mp_rules(base_type): its forward builder argument, wrap the forward_builder function as a partial function with the base_type argument """ + + MISTRAL_MP_RULES = [ + # - do regex on RMSNorm class name + # - check on the tensors required for fast_rms_layernorm + ModelPatcherRule( + rule_id="mistral-rms", + trigger=ModelPatcherTrigger(check=MistralRMSNorm), + forward=fast_rms_layernorm, + ), + ModelPatcherRule( + rule_id="mistral-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MistralAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MistralAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", + ), + forward_builder=combine_functions( + partial( + build_lora_fused_ops, + submodule_names=["q_proj", "k_proj", "v_proj"], + fused_op=KEY_QKV, + ), + partial( + build_lora_fused_ops, + submodule_names=["o_proj"], + fused_op=KEY_O, + ), + logic="APPEND", + ), + ), + ModelPatcherRule( + rule_id="mistral-mlp", + trigger=ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MistralMLP, + submodule_names=["up_proj", "down_proj", "gate_proj"], + ) + ), + forward_builder=partial( + build_lora_fused_ops, + submodule_names=["up_proj", "down_proj", "gate_proj"], + fused_op=KEY_MLP, + ), + ), + ModelPatcherRule( + rule_id="mistral-cross-ent", + import_and_maybe_reload=( + "torch.nn.CrossEntropyLoss", + FastCrossEntropyLoss, + "transformers.models.mistral.modeling_mistral", + ), + ), + ModelPatcherRule( + rule_id="mistral-rope", + import_and_maybe_reload=( + "transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb", + fast_rope_embedding, + None, + ), + ) + ] + for rule in MISTRAL_MP_RULES: - if ( - rule.forward_builder is not None - and "base_type" in rule.forward_builder_args - ): + if rule.forward_builder is not None: rule.forward_builder = partial( rule.forward_builder, base_type=base_type, diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py 
b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py index 7e2feca9..a6bcdcbf 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py @@ -34,74 +34,6 @@ from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops -# - do regex on RMSNorm class name -# - check on the tensors required for fast_rms_layernorm -RULE_MIXTRAL_RMS = ModelPatcherRule( - rule_id="mixtral-rms", - trigger=ModelPatcherTrigger(check=MixtralRMSNorm), - forward=fast_rms_layernorm, -) - -RULE_MIXTRAL_QKVO = ModelPatcherRule( - rule_id="mixtral-qkvo", - trigger=combine_triggers( - ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=MixtralAttention, - submodule_names=["q_proj", "k_proj", "v_proj"], - ) - ), - ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=MixtralAttention, - submodule_names=["o_proj"], - ) - ), - logic="OR", - ), - forward_builder=combine_functions( - partial( - build_lora_fused_ops, - submodule_names=["q_proj", "k_proj", "v_proj"], - fused_op=KEY_QKV, - ), - partial( - build_lora_fused_ops, - submodule_names=["o_proj"], - fused_op=KEY_O, - ), - logic="APPEND", - ), - forward_builder_args=["base_type"], -) - -RULE_MIXTRAL_CE = ModelPatcherRule( - rule_id="mixtral-cross-ent", - import_and_maybe_reload=( - "torch.nn.CrossEntropyLoss", - FastCrossEntropyLoss, - "transformers.models.mixtral.modeling_mixtral", - ), -) - -RULE_MIXTRAL_ROPE = ModelPatcherRule( - rule_id="mixtral-rope", - import_and_maybe_reload=( - "transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb", - fast_rope_embedding, - None, - ), -) - -MIXTRAL_MP_RULES = [ - RULE_MIXTRAL_RMS, - RULE_MIXTRAL_QKVO, - RULE_MIXTRAL_CE, - RULE_MIXTRAL_ROPE, -] - def get_mp_rules(base_type): """ Function to access all patch rules in this module. 
@@ -109,11 +41,68 @@ def get_mp_rules(base_type): its forward builder argument, wrap the forward_builder function as a partial function with the base_type argument """ + + # - do regex on RMSNorm class name + # - check on the tensors required for fast_rms_layernorm + MIXTRAL_MP_RULES = [ + ModelPatcherRule( + rule_id="mixtral-rms", + trigger=ModelPatcherTrigger(check=MixtralRMSNorm), + forward=fast_rms_layernorm, + ), + ModelPatcherRule( + rule_id="mixtral-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MixtralAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MixtralAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", + ), + forward_builder=combine_functions( + partial( + build_lora_fused_ops, + submodule_names=["q_proj", "k_proj", "v_proj"], + fused_op=KEY_QKV, + ), + partial( + build_lora_fused_ops, + submodule_names=["o_proj"], + fused_op=KEY_O, + ), + logic="APPEND", + ), + ), + ModelPatcherRule( + rule_id="mixtral-cross-ent", + import_and_maybe_reload=( + "torch.nn.CrossEntropyLoss", + FastCrossEntropyLoss, + "transformers.models.mixtral.modeling_mixtral", + ), + ), + ModelPatcherRule( + rule_id="mixtral-rope", + import_and_maybe_reload=( + "transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb", + fast_rope_embedding, + None, + ), + ) + ] + for rule in MIXTRAL_MP_RULES: - if ( - rule.forward_builder is not None - and "base_type" in rule.forward_builder_args - ): + if rule.forward_builder is not None: rule.forward_builder = partial( rule.forward_builder, base_type=base_type,
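
Note on the framework.py / framework_plugin.py hunks: the single call to ModelPatcher.patch (and the follow-up patch-summary logging) now lives in AccelerationFramework.get_callbacks_and_ready_for_train, so the model is patched exactly once after every active plugin has contributed its rules, instead of each plugin triggering the patch from its own get_callbacks_and_ready_for_train. Below is a minimal, self-contained sketch of that flow; the classes are stand-ins written for illustration, not the real fms_acceleration API.

    # Stand-in classes only; the real ModelPatcher, plugin base class and framework
    # live in fms_acceleration.model_patcher / framework_plugin / framework.
    from typing import List, Tuple


    class ModelPatcherStub:
        # The real ModelPatcher walks the model and applies every registered rule;
        # here we only record that a single patch pass happened.
        history: List[str] = []

        @classmethod
        def patch(cls, model) -> None:
            cls.history.append(f"patched {type(model).__name__}")


    class PluginStub:
        # After this patch, plugins only contribute callbacks; they no longer call
        # ModelPatcher.patch from get_callbacks_and_ready_for_train themselves.
        def get_callbacks_and_ready_for_train(self, model=None, accelerator=None):
            return []


    class FrameworkStub:
        def __init__(self, plugins):
            self.active_plugins: List[Tuple[str, PluginStub]] = [
                (type(p).__name__, p) for p in plugins
            ]

        def get_callbacks_and_ready_for_train(self, model=None, accelerator=None):
            # Single, framework-level patch trigger.
            if model is not None:
                ModelPatcherStub.patch(model)

            # If any patching happened, emit the patch summary (the real code routes
            # this through log_patch_summary(logging_func=logger.info)).
            if len(ModelPatcherStub.history) > 0:
                for line in ModelPatcherStub.history:
                    print(line)

            callbacks = []
            for _, plugin in self.active_plugins:
                callbacks.extend(
                    plugin.get_callbacks_and_ready_for_train(model, accelerator)
                )
            return callbacks


    if __name__ == "__main__":
        framework = FrameworkStub([PluginStub()])
        framework.get_callbacks_and_ready_for_train(model=object())

The sketch mirrors the structure of the new framework hunk: patching and summary logging happen once at the framework level, then plugins are asked only for their callbacks.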
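Note on the llama.py / mistral.py / mixtral.py changes: the module-level RULE_* constants are folded into get_mp_rules(base_type), and the forward_builder_args=["base_type"] field is dropped; instead, base_type is bound onto every rule that defines a forward_builder via functools.partial. The following is a small, self-contained sketch of that rebinding pattern; Rule and the builder function are hypothetical stand-ins for ModelPatcherRule and build_lora_fused_ops, written only to show the mechanism.

    # Illustrative stand-ins; the real types come from fms_acceleration.model_patcher
    # and fms_acceleration_foak.models.utils.
    from dataclasses import dataclass
    from functools import partial
    from typing import Callable, Optional


    @dataclass
    class Rule:  # hypothetical stand-in for ModelPatcherRule
        rule_id: str
        forward_builder: Optional[Callable] = None


    def build_fused_ops_stub(module, base_type=None, submodule_names=None):
        # Stand-in for build_lora_fused_ops: just report what it would build.
        return f"fused ops for {submodule_names} on a {base_type} base"


    def get_mp_rules(base_type: str):
        rules = [
            Rule(rule_id="demo-rms"),  # no forward_builder -> left untouched
            Rule(
                rule_id="demo-qkvo",
                forward_builder=partial(
                    build_fused_ops_stub,
                    submodule_names=["q_proj", "k_proj", "v_proj"],
                ),
            ),
        ]
        # Instead of declaring forward_builder_args=["base_type"] on each rule,
        # bind base_type onto every builder the module defines.
        for rule in rules:
            if rule.forward_builder is not None:
                rule.forward_builder = partial(
                    rule.forward_builder, base_type=base_type
                )
        return rules


    if __name__ == "__main__":
        for rule in get_mp_rules("auto_gptq"):
            if rule.forward_builder is not None:
                print(rule.rule_id, "->", rule.forward_builder(module=None))

Constructing the rules inside get_mp_rules also means each call returns fresh rule objects, so rebinding forward_builder does not mutate shared module-level state the way the old RULE_* constants would have.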