additional modifications to foak patch rules
achew010 committed Jul 29, 2024
1 parent a4f8800 commit aff75fd
Showing 5 changed files with 22 additions and 50 deletions.
Changed file 1 of 5
@@ -23,7 +23,6 @@
from transformers import TrainingArguments
import torch
import torch.distributed as dist
from fms_acceleration_foak.models import register_foak_model_patch_rules

# consider moving this somewhere else later
def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin):
@@ -59,6 +58,16 @@ def _all_reduce_hook(grad):
        if not B.weight.is_cuda:
            set_module_tensor_to_device(B, "weight", "cuda")

def register_foak_model_patch_rules(base_type):
    from fms_acceleration.model_patcher import ModelPatcher
    from .models import llama, mistral, mixtral
    rules = [
        *llama.get_mp_rules(base_type),
        *mistral.get_mp_rules(base_type),
        *mixtral.get_mp_rules(base_type),
    ]
    for _rule in rules:
        ModelPatcher.register(_rule)

class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin):

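For orientation, a minimal hypothetical sketch of how the now-inlined helper could be combined with the model patcher. This is not part of the commit: the real call site inside FastQuantizedPeftAccelerationPlugin is outside the visible hunks, the "auto_gptq" value is only an assumed example, and ModelPatcher.patch being the entry point that applies registered rules is an assumption.

from fms_acceleration.model_patcher import ModelPatcher

def patch_model_with_foak_rules(model, base_type="auto_gptq"):
    # hypothetical wrapper; the plugin derives base_type from its own config
    register_foak_model_patch_rules(base_type)
    # assumed API: ModelPatcher.patch applies every registered rule to `model`
    ModelPatcher.patch(model)
    return model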
Changed file 2 of 5
@@ -11,24 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Local
from fms_acceleration.model_patcher import ModelPatcher
import importlib

PATCHES = [".models.llama", ".models.mistral", ".models.mixtral"]
PLUGIN_PREFIX = "fms_acceleration_foak"

# TODO: remove the need for the prefix
def register_foak_model_patch_rules(base_type):
    for postfix in PATCHES:
        # define the patch module path to import
        # if it exists, import the module
        patch_path = f"{PLUGIN_PREFIX}{postfix}"
        if importlib.util.find_spec(patch_path):
            m = importlib.import_module(patch_path)
            # get all model patcher rules from the module
            # register every rule in the module
            rules = m.get_mp_rules(base_type)
            for _rule in rules:
                ModelPatcher.register(_rule)
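For reference, the deleted helper guards each dynamic import with importlib.util.find_spec. A self-contained sketch of that pattern, with arbitrary module names that are not the plugin's:

import importlib
import importlib.util

def import_if_available(module_path):
    # find_spec returns None when no importer can locate the module,
    # so import_module only runs for modules that actually exist
    if importlib.util.find_spec(module_path) is None:
        return None
    return importlib.import_module(module_path)

print(import_if_available("json") is not None)            # True: stdlib module
print(import_if_available("no_such_module_xyz") is None)  # True: not importable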
Changed file 3 of 5
@@ -34,14 +34,14 @@
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops

def get_mp_rules(base_type):
def get_mp_rules(base_type: str):
"""
Function to access all patch rules in this module.
If it is a forward_builder rule with `base_type` in
its forward builder argument, wrap the forward_builder
function as a partial function with the base_type argument
"""
LLAMA_MP_RULES = [
return [
# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
# - check on the tensors required for fast_rms_layernorm
@@ -77,11 +77,13 @@ def get_mp_rules(base_type):
                    build_lora_fused_ops,
                    submodule_names=["q_proj", "k_proj", "v_proj"],
                    fused_op=KEY_QKV,
                    base_type=base_type,
                ),
                partial(
                    build_lora_fused_ops,
                    submodule_names=["o_proj"],
                    fused_op=KEY_O,
                    base_type=base_type,
                ),
                logic="APPEND",
            ),
@@ -99,6 +101,7 @@ def get_mp_rules(base_type):
                build_lora_fused_ops,
                submodule_names=["up_proj", "down_proj", "gate_proj"],
                fused_op=KEY_MLP,
                base_type=base_type,
            ),
        ),
        # TODO: have a generic version of this rule
@@ -124,11 +127,3 @@ def get_mp_rules(base_type):
            ),
        )
    ]

    for rule in LLAMA_MP_RULES:
        if rule.forward_builder is not None:
            rule.forward_builder = partial(
                rule.forward_builder,
                base_type=base_type,
            )
    return LLAMA_MP_RULES
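The mistral and mixtral files below receive the same treatment: base_type is now bound into each build_lora_fused_ops partial when the rule is declared, replacing the loop that re-wrapped every rule's forward_builder afterwards. A self-contained sketch of why the two forms resolve to the same call, using a generic stand-in function rather than the plugin's API:

from functools import partial

def build_ops(submodule_names, fused_op, base_type):
    # stand-in for build_lora_fused_ops: simply record how it was called
    return (tuple(submodule_names), fused_op, base_type)

# old style: bind everything except base_type, then re-wrap afterwards
builder = partial(build_ops, submodule_names=["q_proj"], fused_op="qkv")
builder = partial(builder, base_type="auto_gptq")

# new style: bind base_type together with the other keyword arguments
builder_direct = partial(
    build_ops, submodule_names=["q_proj"], fused_op="qkv", base_type="auto_gptq"
)

assert builder() == builder_direct()  # both produce the same result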
Changed file 4 of 5
@@ -43,7 +43,7 @@ def get_mp_rules(base_type):
    function as a partial function with the base_type argument
    """

    MISTRAL_MP_RULES = [
    return [
        # - do regex on RMSNorm class name
        # - check on the tensors required for fast_rms_layernorm
        ModelPatcherRule(
@@ -75,11 +75,13 @@ def get_mp_rules(base_type):
                    build_lora_fused_ops,
                    submodule_names=["q_proj", "k_proj", "v_proj"],
                    fused_op=KEY_QKV,
                    base_type=base_type,
                ),
                partial(
                    build_lora_fused_ops,
                    submodule_names=["o_proj"],
                    fused_op=KEY_O,
                    base_type=base_type,
                ),
                logic="APPEND",
            ),
@@ -97,6 +99,7 @@ def get_mp_rules(base_type):
                build_lora_fused_ops,
                submodule_names=["up_proj", "down_proj", "gate_proj"],
                fused_op=KEY_MLP,
                base_type=base_type,
            ),
        ),
        ModelPatcherRule(
@@ -116,11 +119,3 @@ def get_mp_rules(base_type):
            ),
        )
    ]

    for rule in MISTRAL_MP_RULES:
        if rule.forward_builder is not None:
            rule.forward_builder = partial(
                rule.forward_builder,
                base_type=base_type,
            )
    return MISTRAL_MP_RULES
Changed file 5 of 5
@@ -44,7 +44,7 @@ def get_mp_rules(base_type):

    # - do regex on RMSNorm class name
    # - check on the tensors required for fast_rms_layernorm
    MIXTRAL_MP_RULES = [
    return [
        ModelPatcherRule(
            rule_id="mixtral-rms",
            trigger=ModelPatcherTrigger(check=MixtralRMSNorm),
@@ -74,11 +74,13 @@ def get_mp_rules(base_type):
                    build_lora_fused_ops,
                    submodule_names=["q_proj", "k_proj", "v_proj"],
                    fused_op=KEY_QKV,
                    base_type=base_type,
                ),
                partial(
                    build_lora_fused_ops,
                    submodule_names=["o_proj"],
                    fused_op=KEY_O,
                    base_type=base_type,
                ),
                logic="APPEND",
            ),
@@ -100,11 +102,3 @@ def get_mp_rules(base_type):
            ),
        )
    ]

    for rule in MIXTRAL_MP_RULES:
        if rule.forward_builder is not None:
            rule.forward_builder = partial(
                rule.forward_builder,
                base_type=base_type,
            )
    return MIXTRAL_MP_RULES
