diff --git a/plugins/accelerated-peft/pyproject.toml b/plugins/accelerated-peft/pyproject.toml index 35789df0..9e418287 100644 --- a/plugins/accelerated-peft/pyproject.toml +++ b/plugins/accelerated-peft/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fms-acceleration-peft" -version = '0.0.1' +version = '0.1.0.1.dev' description = "FMS Acceleration for PeFT" authors = [ {name = "Fabian Lim", email = "flim@sg.ibm.com"}, diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py index 913a6b7e..a62d0543 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py @@ -16,43 +16,52 @@ # https://spdx.dev/learn/handling-license-info/ # Standard -from typing import Any, Callable, List -import importlib +from typing import Callable, List # Third Party from peft import LoraConfig from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ import torch +from fms_acceleration.model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger +from functools import partial + # these parameters are to be patched for triton v2 # consider making a map if patching more kernels PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"] - -# This function may be moved after merging -# https://github.com/foundation-model-stack/fms-acceleration/pull/25 -def _patch_target_module( - to_patch: str, - replace_with: Any, - target_module: str = None, +def build_patch_to_view_tensor_to_parameter_for_fsdp_gptq( + module, + torch_dtype, ): - to_patch = to_patch.split(".") - assert len(to_patch) > 1, "must have an object to patch" - - to_patch, obj_name_to_patch = to_patch[:-1], to_patch[-1] - to_patch = ".".join(to_patch) - source = importlib.import_module(to_patch) - original_obj = getattr(source, obj_name_to_patch) - setattr(source, obj_name_to_patch, replace_with) - - if target_module is not None: - # reload and this should get the patched object - target_module = importlib.import_module(target_module) - importlib.reload(target_module) - - # replace it - setattr(source, obj_name_to_patch, original_obj) + # convert all patched attributes to Parameters of torch_dtype + # so FSDP can shard them + for attr_name in PATCH_FOR_FSDP_TRITON_V2: + attr = getattr(module, attr_name) + attr = torch.nn.Parameter( + attr.view(torch_dtype), requires_grad=False + ) + setattr(module, attr_name, attr) + + # this patches the forward to convert them back to original + # type (i.e. 
int32) before the function call into the kernels + return patch_forward_to_view_attributes_before_call( + module.forward, + attribute_names=PATCH_FOR_FSDP_TRITON_V2, + torch_dtype=torch.int32, # patch it back to + ) +def register_tensors_as_parameters_patch_rule(target_module, torch_dtype): + # Register patch + ModelPatcher.register( + ModelPatcherRule( + rule_id="autogptq_patch_tensors_as_float_parameters", + trigger=ModelPatcherTrigger(check=target_module), + forward_builder = partial( + build_patch_to_view_tensor_to_parameter_for_fsdp_gptq, torch_dtype=torch_dtype + ), + ) + ) def make_sure_no_tensor_in_meta_device( model, @@ -124,7 +133,6 @@ def create_new_module_peft( # if module cannot be found, return None which results in a raise in the call-stack return new_module - # consider to move this somewhere more general def patch_forward_to_view_attributes_before_call( old_forward: Callable, diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py index f6fc8474..8e95751d 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py @@ -24,6 +24,7 @@ # Third Party from fms_acceleration import AccelerationPlugin +from fms_acceleration.model_patcher import patch_target_module from peft import LoraConfig, prepare_model_for_kbit_training from peft.tuners.lora.model import LoraModel from transformers import AutoModelForCausalLM, TrainingArguments @@ -31,6 +32,8 @@ import torch import torch.distributed +# Local +from .autogptq_utils import register_tensors_as_parameters_patch_rule class AutoGPTQAccelerationPlugin(AccelerationPlugin): @@ -81,11 +84,6 @@ def model_loader(self, model_name: str, **kwargs): from .gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error QuantLinear, ) - # Local - from .autogptq_utils import ( # pylint: disable=import-outside-toplevel - PATCH_FOR_FSDP_TRITON_V2, - patch_forward_to_view_attributes_before_call, - ) # Currently we allow only a quantized checkpoint to be loaded, we do not # implement the quantization process here. 
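(Reviewer note, not part of the patch: the core idea behind `build_patch_to_view_tensor_to_parameter_for_fsdp_gptq` above is that FSDP only shards `torch.nn.Parameter`s, so the int32 quantization buffers are temporarily viewed as the training dtype and wrapped as non-trainable Parameters, then viewed back to int32 right before the kernel call. A minimal standalone sketch of that view round-trip, with made-up tensor names, is:

    import torch

    # stand-in for a quantized buffer such as `qweight` or `qzeros`
    qweight = torch.randint(0, 1000, (8, 4), dtype=torch.int32)

    # view the int32 storage as the training dtype and wrap it as a Parameter
    # so FSDP can shard it; no data is copied, only reinterpreted
    shardable = torch.nn.Parameter(qweight.view(torch.float16), requires_grad=False)

    # before entering the quantized kernel, view the same storage back to int32
    restored = shardable.data.view(torch.int32)
    assert torch.equal(restored, qweight)

This is exactly the pair of views the helper above performs on each module attribute listed in PATCH_FOR_FSDP_TRITON_V2.)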
@@ -143,14 +141,11 @@ def model_loader(self, model_name: str, **kwargs): kwargs["low_cpu_mem_usage"] = True if self.use_external_lib: # Local - from .autogptq_utils import ( # pylint: disable=import-outside-toplevel - _patch_target_module, - make_sure_no_tensor_in_meta_device, - ) + from .autogptq_utils import make_sure_no_tensor_in_meta_device # pylint: disable=import-outside-toplevel # We patch `make_sure_no_tensor_in_meta_device` # from autogptq to avoid errors on models without bias - _patch_target_module( + patch_target_module( to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device", replace_with=make_sure_no_tensor_in_meta_device, target_module="auto_gptq.modeling._base", @@ -201,31 +196,14 @@ def model_loader(self, model_name: str, **kwargs): world_size > 1 and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" ): + # register FSDP patch + register_tensors_as_parameters_patch_rule( + target_module=QuantLinear, + torch_dtype=torch_dtype, + ) - # patch all the QuantLinear base layers - for mod in model.modules(): - if isinstance(mod, QuantLinear): - - # convert all patched attributes to Parameters of torch_dtype - # so FSDP can shard them - for attr_name in PATCH_FOR_FSDP_TRITON_V2: - attr = getattr(mod, attr_name) - attr = torch.nn.Parameter( - attr.view(torch_dtype), requires_grad=False - ) - setattr(mod, attr_name, attr) - - # this patches the forward to convert them back to original - # type (i.e. int32) before the function call into the kernels - _forward = patch_forward_to_view_attributes_before_call( - mod.forward, - attribute_names=PATCH_FOR_FSDP_TRITON_V2, - torch_dtype=torch.int32, # patch it back to - ) - mod.forward = MethodType(_forward, mod) - - # replace - AutoModelForCausalLM.from_config = _old_from_config + # replace + AutoModelForCausalLM.from_config = _old_from_config # AutoGPTQ does not set the torch_dtype of the model carefully model.config.torch_dtype = torch_dtype diff --git a/plugins/framework/pyproject.toml b/plugins/framework/pyproject.toml index 457cc729..7baa7f1c 100644 --- a/plugins/framework/pyproject.toml +++ b/plugins/framework/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fms-acceleration" -version = '0.1.1.dev' +version = '0.1.2.dev' description = "FMS Acceleration Plugin Framework" authors = [ {name = "Fabian Lim", email = "flim@sg.ibm.com"}, @@ -27,6 +27,7 @@ dependencies = [ "transformers<4.40", "peft", "accelerate", + "pandas", ] [tool.hatch.build.targets.wheel] diff --git a/plugins/framework/src/fms_acceleration/framework.py b/plugins/framework/src/fms_acceleration/framework.py index 62735296..c8de939c 100644 --- a/plugins/framework/src/fms_acceleration/framework.py +++ b/plugins/framework/src/fms_acceleration/framework.py @@ -38,6 +38,22 @@ logger.setLevel(logging._get_default_logging_level()) logger.addHandler(logging._default_handler) +def log_patch_summary( + logging_func: Callable = None, +): + if logging_func is None: + logging_func = print + + # this is a guarded import, because the model rule registration + # does not need to be loaded unless patch_model is required + # Local + from .model_patcher import ( # pylint: disable=import-outside-toplevel + patch_model_summary, + ) + + for line in patch_model_summary().split("\n"): + logging_func(line) + def check_plugin_packages(plugin: AccelerationPlugin): if plugin.require_packages is None: @@ -207,6 +223,12 @@ def requires_agumentation(self): def get_callbacks_and_ready_for_train( self, model: torch.nn.Module = None, 
accelerator: Accelerator = None ): + + from .model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel + if model is not None: + # Finally apply all registered patches to the model + ModelPatcher.patch(model) + # show the initialized message if accelerator is not None and accelerator.is_main_process: log_initialization_message( @@ -215,6 +237,9 @@ def get_callbacks_and_ready_for_train( logging_func=logger.info, ) + # if patching is done, print patch summary to logger + log_patch_summary(logging_func=logger.info) + cbks = [] for _, plugin in self.active_plugins: cbks.extend(plugin.get_callbacks_and_ready_for_train(model, accelerator)) diff --git a/plugins/framework/src/fms_acceleration/framework_plugin.py b/plugins/framework/src/fms_acceleration/framework_plugin.py index fc4fcf9c..0db569c6 100644 --- a/plugins/framework/src/fms_acceleration/framework_plugin.py +++ b/plugins/framework/src/fms_acceleration/framework_plugin.py @@ -24,7 +24,6 @@ from transformers import TrainingArguments import torch - @dataclass class PluginRegistration: plugin: "AccelerationPlugin" diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py b/plugins/framework/src/fms_acceleration/model_patcher.py similarity index 81% rename from plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py rename to plugins/framework/src/fms_acceleration/model_patcher.py index 7f803330..5f0655fb 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py +++ b/plugins/framework/src/fms_acceleration/model_patcher.py @@ -19,6 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import importlib import inspect +import warnings # Third Party import pandas as pd @@ -27,7 +28,7 @@ # ------------------------ helpers ----------------------- -def _patch_target_module( +def patch_target_module( to_patch: str, replace_with: Any, target_module: str = None, @@ -74,7 +75,7 @@ class ModelPatcherTrigger: # the trigger operation check: Union[ - torch.nn.Module, # trigger on isinstance + Type[torch.nn.Module], # trigger on isinstance Callable[[torch.nn.Module], bool], # trigger on callable ] @@ -88,7 +89,7 @@ class ModelPatcherTrigger: def is_triggered( self, module: torch.nn.Module, - module_name: str, + module_name: str = None, ): "Check if trigger returns truthful." 
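(Reviewer note, not part of the patch: with `check` typed as `Union[Type[torch.nn.Module], Callable]` and `module_name` now defaulting to None, a trigger can be built from either a module class or a predicate and evaluated on a module alone. A rough usage sketch, mirroring the new tests:

    import torch
    from fms_acceleration.model_patcher import ModelPatcherTrigger

    # trigger on class: fires for any instance (or subclass instance) of torch.nn.Linear
    by_class = ModelPatcherTrigger(check=torch.nn.Linear)
    assert by_class.is_triggered(torch.nn.Linear(2, 2))

    # trigger on callable: fires when the predicate returns True for the module
    by_predicate = ModelPatcherTrigger(check=lambda module: hasattr(module, "weight"))
    assert by_predicate.is_triggered(torch.nn.Linear(2, 2))

)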
@@ -111,12 +112,25 @@ def is_triggered(
         return False
 
     def __post_init__(self):
-
-        if self.type is None:
-            if inspect.isclass(self.check) and issubclass(self.check, torch.nn.Module):
+        # if check is a module
+        if inspect.isclass(self.check) and issubclass(self.check, torch.nn.Module):
+            if self.type is None:
                 self.type = ModelPatcherTriggerType.module
             else:
+                # ensure check conforms with self.type
+                assert self.type == ModelPatcherTriggerType.module, \
+                    "type argument passed but `check` argument does not match type specified"
+        # if check is a callable
+        elif callable(self.check):
+            if self.type is None:
                 self.type = ModelPatcherTriggerType.callable
+            else:
+                # ensure check conforms with self.type
+                assert self.type == ModelPatcherTriggerType.callable, \
+                    "type argument passed but `check` argument does not match type specified"
+        else:
+            raise TypeError("check argument needs to be torch.nn.Module or Callable")
+
 
 # type for model forward
@@ -161,11 +175,11 @@ class ModelPatcherRule:
     ] = None
 
     def __post_init__(self):
-        if (
-            self.forward is not None
-            and self.forward_builder is not None
-            and self.import_and_maybe_reload is not None
-        ):
+        if sum([
+            self.forward is not None,
+            self.forward_builder is not None,
+            self.import_and_maybe_reload is not None,
+        ]) != 1:
             raise ValueError(
                 f"Rule '{self.rule_id}' must only have only one of forward, "
                 "foward builder, or import_and_maybe_reload, specified."
@@ -183,7 +197,6 @@ def __post_init__(self):
                 "forward_builder."
             )
 
-
 # helpful to keep a history of all patching that has been done
 @dataclass
 class ModelPatcherHistory:
@@ -247,6 +260,8 @@ def register(rule: ModelPatcherRule):
 
     @staticmethod
     def did_rule_trigger(module: torch.nn.Module, module_name: str):
+
+        active_rule_name, active_rule = None, None
         for name, rule in ModelPatcher.rules.items():
 
             # if there is no trigger
@@ -254,9 +269,19 @@ def did_rule_trigger(module: torch.nn.Module, module_name: str):
                 continue
 
             if rule.trigger.is_triggered(module, module_name):
-                return name, rule
-
-        return None, None
+                # if no rule is active yet, assign the current rule as active
+                if active_rule is None:
+                    active_rule_name = name
+                    active_rule = rule
+                # otherwise, if there is already an active rule, raise a warning
+                # that subsequent compatible forward rules will be ignored
+                # for simple forward patches.
forward_builder args are handled + # when they are decomposed into new simple forward rules + elif rule.forward is not None: + warnings.warn(f"rule {rule.rule_id} is ignored on {module_name} as an \ + earlier rule {active_rule.rule_id} has been applied") + + return active_rule_name, active_rule @staticmethod def _import_and_reload(model: torch.nn.Module): @@ -305,12 +330,28 @@ def _import_and_reload(model: torch.nn.Module): elif _target.startswith(module_path): _no_reload.append(rule) - assert len(_with_reload) <= 1, "cannot have have at most one rule with reload" + # If there are multiple reload targets, + # ensure that their paths do not conflict as reloading same module might reset patches + if len(_with_reload) > 1: + # sort ascending target path length + _with_reload = sorted( + _with_reload, + key=lambda _rule: len(_rule.import_and_maybe_reload[2]), + reverse=False + ) + for rule_s in _with_reload: + for rule_l in _with_reload[1:]: + # if target paths in rule s is a prefix of rule l, raise an error + _, _, _path_s = rule_s.import_and_maybe_reload + _, _, _path_l = rule_l.import_and_maybe_reload + assert not _path_l.startswith(_path_s), \ + f"Attempting to reload same path `{_path_s}` multiple times in \ + {rule_s.rule_id} and {rule_l.rule_id}" # handle those with reload first for rule in _with_reload + _no_reload: _target, _object, _reload = rule.import_and_maybe_reload - _patch_target_module(_target, _object, _reload) + patch_target_module(_target, _object, _reload) ModelPatcher.history.append( ModelPatcherHistory( instance=id(model), @@ -423,6 +464,7 @@ def _patch_forwards( def patch(model: torch.nn.Module, **kwargs): # NOTE: for a set of rules, this patch function should be called # only once. We do not have any checks for this at the moment + try: ModelPatcher._import_and_reload(model.get_base_model()) except AttributeError: @@ -460,7 +502,6 @@ def summary(raw: bool = False): # ------------------------ function ----------------------- - def patch_model(model: torch.nn.Module, **kwargs): ModelPatcher.patch(model, **kwargs) return model @@ -471,17 +512,26 @@ def patch_model_summary(): def combine_triggers(*triggers: ModelPatcherTrigger, logic: str = "OR"): - assert logic == "OR", "only OR logic implemented for combining triggers" + assert logic in ["AND", "OR"], "Only `AND`, `OR` logic implemented for combining triggers" # NOTE: this can be probably simplified def _or_logic(*args, **kwargs): for trig in triggers: - if trig.check(*args, **kwargs): + if trig.is_triggered(*args, **kwargs): return True return False - return ModelPatcherTrigger(check=_or_logic) + def _and_logic(*args, **kwargs): + for trig in triggers: + if not trig.is_triggered(*args, **kwargs): + return False + return True + + _logic = _or_logic + if logic == "AND": + _logic = _and_logic + return ModelPatcherTrigger(check=_logic) def combine_functions(*funcs: Callable, logic: str = "APPEND"): assert logic == "APPEND", "only APPEND logic implemented for combining functions" diff --git a/plugins/framework/src/fms_acceleration/utils/test_utils.py b/plugins/framework/src/fms_acceleration/utils/test_utils.py index 3cc4004f..929c61e3 100644 --- a/plugins/framework/src/fms_acceleration/utils/test_utils.py +++ b/plugins/framework/src/fms_acceleration/utils/test_utils.py @@ -180,3 +180,11 @@ def dummy_augmentation(self, model, train_args, modifiable_args): def dummy_custom_loader(self, model_name, **kwargs): "dummy custom loader returning dummy model" return create_noop_model_with_archs(archs=["DummyModel"]) # + 
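(Reviewer note, not part of the patch: the `combine_triggers` change earlier in this diff adds `AND` logic, so a composite trigger can now require all sub-triggers to fire instead of any one of them. A hedged usage sketch, mirroring the new parametrized tests:

    import torch
    from fms_acceleration.model_patcher import ModelPatcherTrigger, combine_triggers

    is_linear = ModelPatcherTrigger(check=torch.nn.Linear)
    has_bias = ModelPatcherTrigger(check=lambda m: getattr(m, "bias", None) is not None)

    either = combine_triggers(is_linear, has_bias, logic="OR")
    both = combine_triggers(is_linear, has_bias, logic="AND")

    assert either.is_triggered(torch.nn.Linear(2, 2, bias=False))    # is a Linear, so OR fires
    assert not both.is_triggered(torch.nn.Linear(2, 2, bias=False))  # no bias, so AND does not
    assert both.is_triggered(torch.nn.Linear(2, 2, bias=True))

)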
+@contextmanager +def instantiate_model_patcher(): + from fms_acceleration.model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel + old_registrations = ModelPatcher.rules + ModelPatcher.rules = {} + yield + ModelPatcher.rules = old_registrations diff --git a/plugins/framework/tests/model_patcher_fixtures/__init__.py b/plugins/framework/tests/model_patcher_fixtures/__init__.py new file mode 100644 index 00000000..a211ad5c --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/__init__.py @@ -0,0 +1,13 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py new file mode 100644 index 00000000..46e6ebb6 --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py @@ -0,0 +1,16 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .module3 import Module3Class +from .module1_1 import Module1Class, mod_1_function diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py b/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py new file mode 100644 index 00000000..07a5b86a --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py @@ -0,0 +1,23 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ..module2 import Module2Class + +class Module1Class: + def __init__(self) -> None: + self.attribute = Module2Class() + +def mod_1_function(): + return "unpatched_mod_function" + \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py new file mode 100644 index 00000000..eb882843 --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py @@ -0,0 +1,15 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .module3_1 import Module3Class diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/module3/module3_1.py b/plugins/framework/tests/model_patcher_fixtures/module1/module3/module3_1.py new file mode 100644 index 00000000..09108981 --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module1/module3/module3_1.py @@ -0,0 +1,19 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..module1_1 import mod_1_function + +class Module3Class: + def __init__(self) -> None: + self.attribute = mod_1_function diff --git a/plugins/framework/tests/model_patcher_fixtures/module2.py b/plugins/framework/tests/model_patcher_fixtures/module2.py new file mode 100644 index 00000000..d866ac51 --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module2.py @@ -0,0 +1,17 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +class Module2Class: + def __init__(self) -> None: + pass diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module4/__init__.py new file mode 100644 index 00000000..d1a1e40c --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module4/__init__.py @@ -0,0 +1,22 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .module4_1 import mod_4_function +from .module5 import Module5Class, mod_5_function +import torch + +class Module4Class(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.attribute = Module5Class() diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py b/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py new file mode 100644 index 00000000..3b302c1f --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py @@ -0,0 +1,17 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def mod_4_function(): + return "unpatched_mod_function" + \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py new file mode 100644 index 00000000..cf0bb8e2 --- /dev/null +++ b/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py @@ -0,0 +1,15 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from .module5_1 import Module5Class, mod_5_function
diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py b/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py
new file mode 100644
index 00000000..fbfa408b
--- /dev/null
+++ b/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py
@@ -0,0 +1,23 @@
+# Copyright The IBM Tuning Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+class Module5Class(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+def mod_5_function():
+    return "unpatched_mod_function"
+ 
\ No newline at end of file
diff --git a/plugins/framework/tests/model_patcher_test_utils.py b/plugins/framework/tests/model_patcher_test_utils.py
new file mode 100644
index 00000000..30ccb6cb
--- /dev/null
+++ b/plugins/framework/tests/model_patcher_test_utils.py
@@ -0,0 +1,65 @@
+# Copyright The IBM Tuning Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import os
+import sys
+import importlib
+from contextlib import contextmanager
+from typing import Dict, Any, Type
+
+ROOT = 'tests.model_patcher_fixtures'
+MODULE_PATHS = []
+for root, dirs, files in os.walk(ROOT.replace('.', os.path.sep)):
+    for f in files:
+        filename, ext = os.path.splitext(f)
+        if ext != ".py":
+            continue
+        if filename != '__init__':
+            p = os.path.join(root, filename)
+        else:
+            p = root
+
+        MODULE_PATHS.append(p.replace(os.path.sep, '.'))
+
+@contextmanager
+def isolate_test_module_fixtures():
+    old_mod = {
+        k: sys.modules[k] for k in MODULE_PATHS if k in sys.modules
+    }
+    yield
+
+    # Reload only reloads the specified module, but makes no attempt to reload
+    # the imports of that module.
+    # - i.e., this means that if an import had been changed,
+    #   then the reload will take the changed import.
+    # - i.e., this also means that the individual modules must be reloaded separately
+    #   for a complete reset.
+ # + # Therefore, we need to reload ALL Modules in opposite tree order, meaning that + # the children must be reloaded before their parent + + for key in sorted(old_mod.keys(), key=len, reverse=True): + # Unclear why but needs a reload, to be investigated later + importlib.reload(old_mod[key]) + + +def create_module_class( + class_name: str, + namespaces: Dict[str, Any] = None, + parent_class: Type = torch.nn.Module +): + if namespaces is None: + namespaces = {} + return type(class_name, (parent_class,), namespaces) diff --git a/plugins/framework/tests/test_model_patcher.py b/plugins/framework/tests/test_model_patcher.py new file mode 100644 index 00000000..13d01cf3 --- /dev/null +++ b/plugins/framework/tests/test_model_patcher.py @@ -0,0 +1,339 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Third Party +import pytest # pylint: disable=(import-error + +# First Party +from fms_acceleration.model_patcher import ( + ModelPatcher, + ModelPatcherRule, + ModelPatcherTrigger, + patch_target_module, +) + +from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures +from .model_patcher_fixtures import module4 +from fms_acceleration.utils.test_utils import instantiate_model_patcher +from .test_model_patcher_helpers import DUMMY_RULE_ID + +#Test patching of model attribute +def test_simple_forward_rule_with_mp_replaces_old_forward(): + """ + Ensure that a child submodule forward function + is patched with a new forward function + + model_patcher_fixtures: + - module1: + - module1_1: + - Module2Class: + - attribute: Module2Class + - mod_1_function + - module3: + - module3_1 + - Module3Class: + - attribute: mod_1_function + - module2: + - Module2Class: + + - module4: + - Module4Class(torch.nn.Module): + - attribute: Module5Class + - module4_1 + - mod_4_function + - module5: + - module5_1 + - Module5Class + - module_5_function + + """ + + def patched_forward_function(X): + return "patched_forward_function" + + # 1. Create an instance of Module4Class as model + # 2. Add a submodule to Module4Class + # 3. Create and register rule to patch forward of submodule class + # 4. Patch model + # 5. 
Ensure that model's submodule forward is replaced + with isolate_test_module_fixtures(): + with instantiate_model_patcher(): + model = module4.Module4Class() + SubModule1 = create_module_class( + "SubModule1", + namespaces={"forward": lambda self: "unpatched_forward_function"} + ) + model.add_module("submodule_1", SubModule1()) + rule = ModelPatcherRule( + rule_id=DUMMY_RULE_ID, + trigger=ModelPatcherTrigger(check=SubModule1), + forward=patched_forward_function, + ) + ModelPatcher.register(rule) + ModelPatcher.patch(model) + + assert model.submodule_1.forward() == "patched_forward_function" + +def test_import_and_maybe_reload_rule_with_mp_replaces_old_attribute(): + """ + Module4Class has an attribute from Module5Class, + ensure that patching Module5Class with a PatchedModuleClass, + replaces the old attribute in Module4Class + + Module4Class(torch.nn.Module): + - attribute: Module5Class + + """ + # 1. Register rule replacing module5.module5_1.Module5Class with a patched_mod_function + # reload_target is test.model_patcher.fixtures.module4 + # 2. Patch module4.Module4Class with ModelPatcher + # 3. check patched module exist in module4.Module4Class.attribute + PatchedModuleClass = create_module_class( + "PatchedModClass", + ) + + with isolate_test_module_fixtures(): + with instantiate_model_patcher(): + model = module4.Module4Class() + ModelPatcher.register( + ModelPatcherRule( + rule_id=DUMMY_RULE_ID, + import_and_maybe_reload=( + "tests.model_patcher_fixtures.module4.module5.Module5Class", + PatchedModuleClass, + "tests.model_patcher_fixtures.module4", + ), + ) + ) + ModelPatcher.patch(model) + assert isinstance(module4.Module4Class().attribute, PatchedModuleClass) + +def test_mp_throws_error_with_multiple_reloads_on_same_target(): + """ + Simulate a case where two rules attempt to reload on the same target prefix + + example: + - Rule 1 target path 1: x.y.z + - Rule 2 target path 2: x.y + + this MIGHT reverse the patch on Rule 1 and needs to be prevented + + model_patcher_fixtures: + - module1: + - module1_1: + - Module2Class: + - attribute: Module2Class + - mod_1_function + - module3: + - module3_1 + - Module3Class: + - attribute: mod_1_function + - module2: + - Module2Class: + + - module4: + - Module4Class(torch.nn.Module): + - attribute: Module5Class + - module4_1 + - mod_4_function + - module5: + - module5_1 + - Module5Class + - module_5_function + + """ + + PatchedModuleClass = create_module_class( + "PatchedModuleClass", + ) + + def patched_mod_function(): + return "patched_mod_function" + + # Demonstrate how the 2nd patch overwrites the 1st patch if the reload module paths are the same + with isolate_test_module_fixtures(): + # 1st patch on a function + patch_target_module( + "tests.model_patcher_fixtures.module4.module5.module5_1.mod_5_function", + patched_mod_function, + "tests.model_patcher_fixtures.module4.module5", + ) + + assert module4.module5.mod_5_function() == "patched_mod_function" + + # 2nd patch on a class that has a target path that reloads module5 as well + patch_target_module( + "tests.model_patcher_fixtures.module4.module5.module5_1.Module5Class", + PatchedModuleClass, + "tests.model_patcher_fixtures.module4.module5" + ) + + assert isinstance(module4.module5.Module5Class(), PatchedModuleClass) + assert module4.module5.mod_5_function() == "unpatched_mod_function" + + # Ensure that an assertion is raised if target paths + # are a prefixes of another longer target path + with pytest.raises( + AssertionError, + ): + with isolate_test_module_fixtures(): + with 
instantiate_model_patcher(): + # 1. Initialize a model with module path tests.model_patcher_fixtures.module4 + model = module4.Module4Class() + + # 2. Simulate patching a function in module4.module5.module5_1 + ModelPatcher.register( + ModelPatcherRule( + rule_id=f"{DUMMY_RULE_ID}.2", + import_and_maybe_reload=( + "tests.model_patcher_fixtures.module4.module5.module5_1.mod_5_function", + patched_mod_function, + "tests.model_patcher_fixtures.module4.module5.module5_1", + ), + ) + ) + + # 3. Simulate patching a class in module4.module5.module5_1 + ModelPatcher.register( + ModelPatcherRule( + rule_id=f"{DUMMY_RULE_ID}.1", + import_and_maybe_reload=( + "tests.model_patcher_fixtures.module4.module5.module5_1.Module5Class", + PatchedModuleClass, + "tests.model_patcher_fixtures.module4", + ), + ) + ) + + # while there are occasions repeated reloads along the same target path prefix work, + # it is risky and not guaranteed to work for all cases. + # To prevent the risk of any of the patches conflicting, + # we throw an exception if a shorter target path is a prefix of another + # longer target path + ModelPatcher.patch(model) + +def test_mp_throws_warning_with_multiple_patches(): + """ + Ensure for each module, only one forward patch is implemented on it. + The patch implementation checks if there are multiple forward patch rules + that are applied to the module, only the 1st forward patch rule is applied, + the others will be ignored and a warning will be raised + + In the case of a list of new rules generated by `forward_builder`, it will be + handled similarly since it decomposes to multiple single forward patch rules downstream. + """ + with pytest.warns( + UserWarning, + ): + with isolate_test_module_fixtures(): + with instantiate_model_patcher(): + # 1. Create a model + # 2. Create a submodule to patch on + # 3. Create 1st rule to patch submodule forward function + # 4. Create 2nd rule to patch submodule forward function again + # 5. Throws warning that any subsequent forward patches after + # the 1st patch is ignored + + model = module4.Module4Class() + SubModule1 = create_module_class( + "SubModule1", + namespaces={"forward": lambda self: "unpatched_forward_function"} + ) + model.add_module("submodule_1", SubModule1()) + + ModelPatcher.register( + ModelPatcherRule( + rule_id=DUMMY_RULE_ID+".1", + trigger=ModelPatcherTrigger(check=SubModule1), + forward=lambda self: "patched_forward_function", + ) + ) + ModelPatcher.register( + ModelPatcherRule( + rule_id=DUMMY_RULE_ID+".2", + trigger=ModelPatcherTrigger(check=SubModule1), + forward=lambda self: "patched_forward_function_2", + ) + ) + ModelPatcher.patch(model) + + +def test_forward_builder_rule_with_mp_replaces_old_forward(): + """ + Ensure that patching a model with a rule using forward_builder argument will + replace the children module forwards + """ + def is_module_type_B(module): + if hasattr(module, "B"): + return True + return False + + def is_module_type_C(module): + if hasattr(module, "C"): + return True + return False + + def patched_forward_function(X): + return "patched_forward_function" + + with isolate_test_module_fixtures(): + with instantiate_model_patcher(): + # 1. Create Model and 3 different child submodules + # 2. Create the forward builder function to produce a list of + # (trigger obj, patched forwards) for each child module in model + # 3. Create rule on model class to patch the submodules using a forward_builder function + # 4. 
Ensure all submodule forwards are patched + + SubModule1 = create_module_class( + "SubModule1", namespaces={"forward": lambda X: "unpatched_forward_function"} + ) + SubModule1A = create_module_class( + "SubModule1A", parent_class=SubModule1, namespaces={"A": "attributeA"} + ) + SubModule1B = create_module_class( + "SubModule1B", parent_class=SubModule1, namespaces={"B": "attributeB"} + ) + SubModule2 = create_module_class( + "SubModule2", + namespaces={"C": "attributeC", "forward": lambda X: "unpatched_forward_function"} + ) + + model = module4.module5.Module5Class() + model.add_module("submodule_1A", SubModule1A()) + model.add_module("submodule_1B", SubModule1B()) + model.add_module("submodule_2", SubModule2()) + + # Function to create different triggers for different submodules + def build_list_of_triggers( + module, + ): + return [ + (ModelPatcherTrigger(check=SubModule1A), patched_forward_function), + (ModelPatcherTrigger(check=is_module_type_B), patched_forward_function), + (ModelPatcherTrigger(check=is_module_type_C), patched_forward_function), + ] + + ModelPatcher.register( + ModelPatcherRule( + rule_id=DUMMY_RULE_ID, + trigger=ModelPatcherTrigger(check=module4.module5.Module5Class), + forward_builder=build_list_of_triggers, + ) + ) + + ModelPatcher.patch(model) + + for _, mod in model.named_children(): + if hasattr(mod, "forward"): + assert mod.forward() == "patched_forward_function" diff --git a/plugins/framework/tests/test_model_patcher_helpers.py b/plugins/framework/tests/test_model_patcher_helpers.py new file mode 100644 index 00000000..6713aced --- /dev/null +++ b/plugins/framework/tests/test_model_patcher_helpers.py @@ -0,0 +1,390 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Third Party +import pytest # pylint: disable=(import-error +import torch + +# First Party +from fms_acceleration.model_patcher import ( + ModelPatcherRule, + ModelPatcherTrigger, + patch_target_module, + ModelPatcherTriggerType, + combine_triggers, +) + +from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures +from .model_patcher_fixtures import module1 + +MOD_CLS_A = create_module_class("MOD_CLS_A") +MOD_SUBCLS_A = create_module_class("MOD_SUBCLS_A", parent_class=MOD_CLS_A) +MOD_CLS_B = create_module_class("MOD_CLS_B") + +def returns_false(*args, **kwargs): + "falsy function" + return False + +def returns_true(*args, **kwargs): + "truthy function" + return True + +DUMMY_RULE_ID = "test_patch" + +# | ------------------ Test ModelPatcherTrigger ----------------------- | + +def test_mp_trigger_constructs_with_check_arg_only(): + "Test construction of trigger with check argument" + # Test that error is raised when check is not of accepted type + with pytest.raises( + TypeError, + match = "check argument needs to be torch.nn.Module or Callable" + ): + ModelPatcherTrigger(check=None) + + # Test module trigger type is correctly inferred from check + trigger = ModelPatcherTrigger(check=torch.nn.Module) + assert trigger.type == ModelPatcherTriggerType.module + + # Test callable trigger type is correctly inferred from check + trigger = ModelPatcherTrigger(check=returns_true) + assert trigger.type == ModelPatcherTriggerType.callable + +def test_mp_trigger_constructs_with_check_and_trigger_type_args(): + "Test construction of trigger with check and type arguments" + # check that trigger constructs successfully as check conforms to specified type + ModelPatcherTrigger( + check=torch.nn.Module, + type=ModelPatcherTriggerType.module, + ) + + ModelPatcherTrigger( + check=returns_true, + type=ModelPatcherTriggerType.callable, + ) + + # Ensure an error is raised when check is callable but type is module + with pytest.raises( + AssertionError, + match = "type argument passed but `check` argument does not match type specified", + ): + ModelPatcherTrigger( + check=returns_true, + type=ModelPatcherTriggerType.module, + ) + + # Ensure an error is raised when check is module but type is callable + with pytest.raises( + AssertionError, + match = "type argument passed but `check` argument does not match type specified", + ): + ModelPatcherTrigger( + check=torch.nn.Module, + type=ModelPatcherTriggerType.callable, + ) + +def test_mp_trigger_correctly_triggers(): + "Test for correctnness of trigger behaviour" + + ModClassA = create_module_class( + "ModClassA", + namespaces={"attr_1": None}, + ) + + ModClassB = create_module_class( + "ModClassB", + ) + + ModSubClassA = create_module_class( + "ModSubClassA", + parent_class=ModClassA, + ) + + # Scenario 1: + # if check is a Callable, is_triggered result must be equal to the boolean output of check + # 1. create function to check that returns true if module has attribute `attr_1`, + # otherwise return False + # 2. create trigger that checks the above function + # 3. create a subclass of module_A and ensure is_triggered returns True + # 4. 
create a module_B and ensure is_triggered returns False
+    def check_module(module):
+        if hasattr(module, "attr_1"):
+            return True
+        return False
+
+    assert ModelPatcherTrigger(check=check_module).is_triggered(
+        ModClassA(),
+    ) is True
+
+    assert ModelPatcherTrigger(check=check_module).is_triggered(
+        ModClassB(),
+    ) is False
+
+    # Scenario 2:
+    # Ensure returns True, if the module is an instance of ModelPatcherTrigger.check
+    # 1. create trigger that checks for ModClassA
+    # 2. create an instance of ModClassA and check is_triggered returns True
+    # 3. create a subclass instance of ModClassA and check is_triggered returns True
+    # 4. create an instance of ModClassB and check is_triggered returns False
+    assert ModelPatcherTrigger(check=ModClassA).is_triggered(
+        ModClassA(),
+    ) is True
+
+    assert ModelPatcherTrigger(check=ModClassA).is_triggered(
+        ModSubClassA(),
+    ) is True
+
+    # Ensure returns False, if is not an instance of ModelPatcherTrigger.check
+    assert ModelPatcherTrigger(check=ModClassA).is_triggered(
+        ModClassB(),
+    ) is False
+
+    # Scenario 3:
+    # Static check to ensure additional constraint is checked
+    # 1. create an instance of ModClassA as model
+    # 2. register 2 submodule instances of ModClassB, submodule_1 and submodule_2
+    # 3. create a trigger that checks for an instance of module_B and `submodule_1` module name
+    # 4. for each module in model, ensure returns true if trigger detects module,
+    #    otherwise it should return false
+
+    # Create model
+    model = ModClassA()
+    # register submodules
+    model.add_module("submodule_1", ModClassB())
+    model.add_module("submodule_2", ModClassB())
+    # create trigger with search criteria
+    trigger = ModelPatcherTrigger(check=ModClassB, module_name="submodule_1")
+    # iterate through modules in model
+    for name, module in model.named_modules():
+        if name == "submodule_1":
+            # assert that is_triggered returns true when module is found
+            assert trigger.is_triggered(module, name) is True
+        else:
+            # assert that is_triggered otherwise returns false
+            assert trigger.is_triggered(module, name) is False
+
+# Each test instance has
+# - target_module,
+# - tuple of trigger check arguments
+# - a logic operator string
+# - expected result as either a boolean or an error tuple
+# 1. Instantiate list of triggers from tuple of trigger check arguments
+# 2. construct a combined trigger given list of triggers and logic
+# 3. if expected_result is a tuple, ensure an error is raised upon constructing the trigger
+# 4.
Otherwise, ensure that the combined_trigger returns the expected result on the target module +@pytest.mark.parametrize( + "target_module,trigger_checks,logic,expected_result", [ + [MOD_SUBCLS_A(), (returns_true, MOD_CLS_B), "OR", True], + [MOD_SUBCLS_A(), (MOD_CLS_B, returns_false), "OR", False], + [MOD_SUBCLS_A(), (MOD_CLS_A, returns_true), "OR", True], + [MOD_CLS_B(), (returns_false, MOD_CLS_A), "AND", False], + [MOD_CLS_B(), (MOD_CLS_B, returns_false), "AND", False], + [MOD_CLS_B(), (MOD_CLS_B, returns_true), "AND", True], + [ + MOD_SUBCLS_A(), (MOD_CLS_B, MOD_CLS_A), "NOR", + (AssertionError, "Only `AND`, `OR` logic implemented for combining triggers") + ], +]) +def test_combine_mp_triggers_produces_correct_output( + target_module, + trigger_checks, + logic, + expected_result +): + triggers = [ModelPatcherTrigger(check=check) for check in trigger_checks] + + # if expected_result is a tuple of (Exception, Exception_message) + if isinstance(expected_result, tuple): + with pytest.raises( + expected_result[0], + match=expected_result[1], + ): + combine_triggers( + *triggers, + logic=logic, + ) + else: # otherwise ensure is_triggered output returns the expected_result + assert combine_triggers( + *triggers, + logic=logic, + ).is_triggered(target_module) is expected_result + + +def test_mp_rule_raises_error_when_arguments_incorrectly_configured(): + "Ensure MP rule is throws appropriate error when wrong argument combinations are passed" + # Test mp rule construction raises with multiple arguments + with pytest.raises( + ValueError, + match="must only have only one of forward, " \ + "foward builder, or import_and_maybe_reload, specified." + ): + ModelPatcherRule( + rule_id=DUMMY_RULE_ID, + forward=lambda self, X: X, + import_and_maybe_reload=(), + forward_builder=lambda self, X: X, + ) + + # Test mp rule construction raises with trigger and import_and_reload + with pytest.raises( + ValueError, + match="has import_and_maybe_reload specified, " \ + "and trigger must be None." + ): + ModelPatcherRule( + rule_id=DUMMY_RULE_ID, + trigger=ModelPatcherTrigger(check=torch.nn.Module), + import_and_maybe_reload=(), + ) + + # Test that rule construction raises forward_builder_args are provided + # without a forward_builder, this can be the case when user passes in a + # forward instead of forward_builder + with pytest.raises( + ValueError, + match="has forward_builder_args but no " \ + "forward_builder." + ): + ModelPatcherRule( + rule_id=DUMMY_RULE_ID, + forward = lambda self, X: X, + forward_builder_args=[] + ) + +def test_patch_target_module_replaces_module_or_function_correctly(): + """ + Test patching of standalone file functions + + Fixtures Class Structure + + model_patcher_fixtures: + - module1: + - module1_1: + - Module2Class: + - attribute: Module2Class + - mod_1_function + - module3: + - module3_1 + - Module3Class: + - attribute: mod_1_function + - module2: + - Module2Class: + + - module4: + - Module4Class: + - attribute: mod_1_function + + + """ + + PatchedModuleClass = create_module_class( + "PatchedModClass", + ) + + def patched_mod_function(): + return "patched_mod_function" + + # S1 - module1_1 has function mod_1_function + # 1. Replace module1_1.mod_1_function with new function + # 2. 
Ensure patch_target_module replaces it with the new function
+    with isolate_test_module_fixtures():
+        patch_target_module(
+            "tests.model_patcher_fixtures.module1.module1_1.mod_1_function",
+            patched_mod_function,
+            "tests.model_patcher_fixtures.module1",
+        )
+        assert module1.mod_1_function() == "patched_mod_function"
+
+    # test patches are reset outside the context manager
+    assert module1.mod_1_function() == "unpatched_mod_function"
+
+    # S2 - module1_1.Module1Class has an attribute module2.Module2Class
+    # 1. Replace Module2Class with new class and reload module1_1
+    # 2. Ensure patch_target_module replaces the attribute with a new attr class
+    with isolate_test_module_fixtures():
+        patch_target_module(
+            "tests.model_patcher_fixtures.module2.Module2Class",
+            PatchedModuleClass,
+            "tests.model_patcher_fixtures.module1.module1_1"
+        )
+        assert isinstance(module1.Module1Class().attribute, PatchedModuleClass)
+
+    # check that the fixture isolation works
+    assert not isinstance(module1.Module1Class().attribute, PatchedModuleClass)
+
+    # S3.1 - module1.module3.module3_1 is a submodule of module1
+    # 1. Replace module1.module3.module3_1.Module3Class with a new class
+    # 2. No target reloading
+    #    - this test shows that a replacement only affects the EXACT PATH that was patched
+    with isolate_test_module_fixtures():
+        patch_target_module(
+            "tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class",
+            PatchedModuleClass,
+        )
+
+        # - this is the exact module path that was patched, so it will reflect the patched class
+        assert isinstance(module1.module3.module3_1.Module3Class(), PatchedModuleClass)
+
+        # - this is the top-level module path, and shows that upper-level paths
+        #   will not be affected
+        assert not isinstance(module1.module3.Module3Class(), PatchedModuleClass)
+
+    # S3.2 - module1.module3.module3_1 is a submodule of module1
+    # 1. Replace module1.module3.module3_1.Module3Class with a new class
+    # 2. reload the top-level module path module1
+    #    -> NOTE: in general, we should avoid targeting any parent paths
+    #       for reload
+    with isolate_test_module_fixtures():
+        patch_target_module(
+            "tests.model_patcher_fixtures.module1.module3.module3_1.Module3Class",
+            PatchedModuleClass,
+            "tests.model_patcher_fixtures.module1",
+        )
+
+        # - the reload of the top level module path module1, will NOT replace module1.module3
+        #   with the original version
+        # - reloading top-level paths is tricky due to caching of the modules
+        # - the reload of a top-level module does not cascade down to children modules.
+        assert not isinstance(module1.module3.module3_1.Module3Class(), PatchedModuleClass)
+
+    # S3.3 - module1.module3 is a submodule of module1
+    # 1. Replace module1.module3.Module3Class with a new class
+    # 2. reload the top-level module path module1
+    #    -> NOTE: in general, we should avoid targeting any parent paths
+    #       for reload
+    with isolate_test_module_fixtures():
+        patch_target_module(
+            "tests.model_patcher_fixtures.module1.module3.Module3Class",
+            PatchedModuleClass,
+            "tests.model_patcher_fixtures.module1",
+        )
+
+        # - the reload of the top level module path module1, will replace module1.module3
+        #   with the original version
+        assert not isinstance(module1.module3.module3_1.Module3Class(), PatchedModuleClass)
+
+    # S4 - module1.module3 submodule has a dependency on
+    #      module1.module1_1.mod_1_function
+    # 1. Replace the module1.module1_1.mod_1_function with a new function
+    # 2.
Ensure the target reloading of module1.module3 picks up the patched function + with isolate_test_module_fixtures(): + patch_target_module( + "tests.model_patcher_fixtures.module1.module1_1.mod_1_function", + patched_mod_function, + "tests.model_patcher_fixtures.module1.module3.module3_1", + ) + assert module1.module3.Module3Class().attribute() == "patched_mod_function" diff --git a/plugins/fused-ops-and-kernels/pyproject.toml b/plugins/fused-ops-and-kernels/pyproject.toml index 2b2aef78..1355634a 100644 --- a/plugins/fused-ops-and-kernels/pyproject.toml +++ b/plugins/fused-ops-and-kernels/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fms-acceleration-foak" -version = '0.0.1' +version = '0.1.1.dev' description = "FMS Acceleration using Fused Operations and Kernels" authors = [ {name = "Fabian Lim", email = "flim@sg.ibm.com"}, @@ -22,7 +22,7 @@ classifiers=[ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] -dependencies = ['pandas'] +dependencies = [] [tool.hatch.build.targets.wheel] only-include = ["src/fms_acceleration_foak"] diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py index 01a5b4b7..d2abd5b1 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py @@ -13,7 +13,7 @@ # limitations under the License. # Standard -from typing import Callable, Dict, Tuple +from typing import Dict, Tuple # Third Party from accelerate.utils import set_module_tensor_to_device @@ -21,33 +21,9 @@ from peft import LoraConfig from peft.tuners.lora.layer import LoraLayer from transformers import TrainingArguments -from transformers.utils import logging import torch import torch.distributed as dist -# want to use the transformers logger, but a bit of pain -logger = logging.get_logger(__name__) # pylint: disable=invalid-name -logger.setLevel(logging._get_default_logging_level()) -logger.addHandler(logging._default_handler) - - -def log_patch_summary( - logging_func: Callable = None, -): - if logging_func is None: - logging_func = print - - # this is a guarded import, because the model rule registration - # does not need to be loaded unless patch_model is required - # Local - from .models.model_patcher import ( # pylint: disable=import-outside-toplevel - patch_model_summary, - ) - - for line in patch_model_summary().split("\n"): - logging_func(line) - - # consider moving this somewhere else later def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin): """ @@ -82,6 +58,16 @@ def _all_reduce_hook(grad): if not B.weight.is_cuda: set_module_tensor_to_device(B, "weight", "cuda") +def register_foak_model_patch_rules(base_type): + from fms_acceleration.model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel + from .models import llama, mistral, mixtral # pylint: disable=import-outside-toplevel + rules = [ + *llama.get_mp_rules(base_type), + *mistral.get_mp_rules(base_type), + *mixtral.get_mp_rules(base_type), + ] + for _rule in rules: + ModelPatcher.register(_rule) class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin): @@ -135,26 +121,14 @@ def augmentation( model.dtype == torch.float16 and train_args.fp16 ), "need to run in fp16 mixed precision or load model in fp16" - # this is a guarded import, because the 
model rule registration - # does not need to be loaded unless patch_model is required - # Local - from .models.model_patcher import ( # pylint: disable=import-outside-toplevel - patch_model, - ) - - model = patch_model(model, base_type=self._base_layer) + # wrapper function to register foak patches + register_foak_model_patch_rules(base_type = self._base_layer) return model, modifiable_args def get_callbacks_and_ready_for_train( self, model: torch.nn.Module = None, accelerator=None ): - # if this is moved to framework, it can be handled as the same way as - # log_initialization_message - # log the patch summary - if accelerator is not None and accelerator.is_main_process: - log_patch_summary(logging_func=logger.info) - callbacks = [] if ( accelerator is not None diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py index ebd49924..38a9531e 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py @@ -11,14 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Local -from .model_patcher import ModelPatcher - -PATCHES = [".models.llama", ".models.mistral", ".models.mixtral"] -PLUGIN_PREFIX = "fms_acceleration_foak" - -# TODO: remove the need for the prefix -ModelPatcher.load_patches( - [f"{PLUGIN_PREFIX}{postfix}" for postfix in PATCHES], -) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py index 290d1217..a934fc1e 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py @@ -21,114 +21,109 @@ LlamaMLP, LlamaRMSNorm, ) - -# Local -from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss -from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm -from ..kernels.unsloth.rope_embedding import fast_rope_embedding -from .model_patcher import ( - ModelPatcher, +from fms_acceleration.model_patcher import ( ModelPatcherRule, ModelPatcherTrigger, combine_functions, combine_triggers, ) -from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops -# TODO: have a generic version of this rule -# - do regex on RMSNorm class name -# - check on the tensors required for fast_rms_layernorm -ModelPatcher.register( - ModelPatcherRule( - rule_id="llama-rms", - trigger=ModelPatcherTrigger(check=LlamaRMSNorm), - forward=fast_rms_layernorm, - ), -) +# Local +from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss +from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm +from ..kernels.unsloth.rope_embedding import fast_rope_embedding +from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops -# TODO: have a generic version of this rule -# - do regex on Attention class name -# - have a set of qkv / o module names and check on that -ModelPatcher.register( - ModelPatcherRule( - rule_id="llama-qkvo", - trigger=combine_triggers( - ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=LlamaAttention, +def get_mp_rules(base_type: str): + """ + Function to access all patch rules in this module. 
+ If it is a forward_builder rule with `base_type` in + its forward builder argument, wrap the forward_builder + function as a partial function with the base_type argument + """ + return [ + # TODO: have a generic version of this rule + # - do regex on RMSNorm class name + # - check on the tensors required for fast_rms_layernorm + ModelPatcherRule( + rule_id="llama-rms", + trigger=ModelPatcherTrigger(check=LlamaRMSNorm), + forward=fast_rms_layernorm, + ), + # TODO: have a generic version of this rule + # - do regex on Attention class name + # - have a set of qkv / o module names and check on that + ModelPatcherRule( + rule_id="llama-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=LlamaAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=LlamaAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", + ), + forward_builder=combine_functions( + partial( + build_lora_fused_ops, submodule_names=["q_proj", "k_proj", "v_proj"], - ) + fused_op=KEY_QKV, + base_type=base_type, + ), + partial( + build_lora_fused_ops, + submodule_names=["o_proj"], + fused_op=KEY_O, + base_type=base_type, + ), + logic="APPEND", ), - ModelPatcherTrigger( + ), + ModelPatcherRule( + rule_id="llama-mlp", + trigger=ModelPatcherTrigger( check=partial( trigger_fused_ops, - attn_cls=LlamaAttention, - submodule_names=["o_proj"], + attn_cls=LlamaMLP, + submodule_names=["up_proj", "down_proj", "gate_proj"], ) ), - logic="OR", - ), - forward_builder=combine_functions( - partial( + forward_builder=partial( build_lora_fused_ops, - submodule_names=["q_proj", "k_proj", "v_proj"], - fused_op=KEY_QKV, - ), - partial( - build_lora_fused_ops, - submodule_names=["o_proj"], - fused_op=KEY_O, - ), - logic="APPEND", - ), - forward_builder_args=["base_type"], - ) -) - -ModelPatcher.register( - ModelPatcherRule( - rule_id="llama-mlp", - trigger=ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=LlamaMLP, submodule_names=["up_proj", "down_proj", "gate_proj"], - ) - ), - forward_builder=partial( - build_lora_fused_ops, - submodule_names=["up_proj", "down_proj", "gate_proj"], - fused_op=KEY_MLP, - ), - forward_builder_args=["base_type"], - ) -) - -# TODO: have a generic version of this rule -# - get the module_name and reload on that -ModelPatcher.register( - ModelPatcherRule( - rule_id="llama-cross-ent", - import_and_maybe_reload=( - "torch.nn.CrossEntropyLoss", - FastCrossEntropyLoss, - "transformers.models.llama.modeling_llama", + fused_op=KEY_MLP, + base_type=base_type, + ), ), - ) -) - -# TODO: have a generic version of this rule -# - get the module name -# - check if "apply_rotary_pos_emb" exists -# - patch -ModelPatcher.register( - ModelPatcherRule( - rule_id="llama-rope", - import_and_maybe_reload=( - "transformers.models.llama.modeling_llama.apply_rotary_pos_emb", - fast_rope_embedding, - None, + # TODO: have a generic version of this rule + # - get the module_name and reload on that + ModelPatcherRule( + rule_id="llama-cross-ent", + import_and_maybe_reload=( + "torch.nn.CrossEntropyLoss", + FastCrossEntropyLoss, + "transformers.models.llama.modeling_llama", + ), ), - ) -) + # TODO: have a generic version of this rule + # - get the module name + # - check if "apply_rotary_pos_emb" exists + # - patch + ModelPatcherRule( + rule_id="llama-rope", + import_and_maybe_reload=( + "transformers.models.llama.modeling_llama.apply_rotary_pos_emb", + fast_rope_embedding, + None, + 
), + ) + ] diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py index 37809fd1..d090da5f 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py @@ -21,104 +21,101 @@ MistralMLP, MistralRMSNorm, ) - -# Local -from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss -from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm -from ..kernels.unsloth.rope_embedding import fast_rope_embedding -from .model_patcher import ( - ModelPatcher, +from fms_acceleration.model_patcher import ( ModelPatcherRule, ModelPatcherTrigger, combine_functions, combine_triggers, ) + + +# Local +from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss +from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm +from ..kernels.unsloth.rope_embedding import fast_rope_embedding from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops -# - do regex on RMSNorm class name -# - check on the tensors required for fast_rms_layernorm -ModelPatcher.register( - ModelPatcherRule( - rule_id="mistral-rms", - trigger=ModelPatcherTrigger(check=MistralRMSNorm), - forward=fast_rms_layernorm, - ), -) +def get_mp_rules(base_type): + """ + Function to access all patch rules in this module. + If it is a forward_builder rule with `base_type` in + its forward builder argument, wrap the forward_builder + function as a partial function with the base_type argument + """ -ModelPatcher.register( - ModelPatcherRule( - rule_id="mistral-qkvo", - trigger=combine_triggers( - ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=MistralAttention, + return [ + # - do regex on RMSNorm class name + # - check on the tensors required for fast_rms_layernorm + ModelPatcherRule( + rule_id="mistral-rms", + trigger=ModelPatcherTrigger(check=MistralRMSNorm), + forward=fast_rms_layernorm, + ), + ModelPatcherRule( + rule_id="mistral-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MistralAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MistralAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", + ), + forward_builder=combine_functions( + partial( + build_lora_fused_ops, submodule_names=["q_proj", "k_proj", "v_proj"], - ) + fused_op=KEY_QKV, + base_type=base_type, + ), + partial( + build_lora_fused_ops, + submodule_names=["o_proj"], + fused_op=KEY_O, + base_type=base_type, + ), + logic="APPEND", ), - ModelPatcherTrigger( + ), + ModelPatcherRule( + rule_id="mistral-mlp", + trigger=ModelPatcherTrigger( check=partial( trigger_fused_ops, - attn_cls=MistralAttention, - submodule_names=["o_proj"], + attn_cls=MistralMLP, + submodule_names=["up_proj", "down_proj", "gate_proj"], ) ), - logic="OR", - ), - forward_builder=combine_functions( - partial( + forward_builder=partial( build_lora_fused_ops, - submodule_names=["q_proj", "k_proj", "v_proj"], - fused_op=KEY_QKV, - ), - partial( - build_lora_fused_ops, - submodule_names=["o_proj"], - fused_op=KEY_O, - ), - logic="APPEND", - ), - forward_builder_args=["base_type"], - ) -) - -ModelPatcher.register( - ModelPatcherRule( - rule_id="mistral-mlp", - trigger=ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=MistralMLP, submodule_names=["up_proj", "down_proj", 
"gate_proj"], - ) - ), - forward_builder=partial( - build_lora_fused_ops, - submodule_names=["up_proj", "down_proj", "gate_proj"], - fused_op=KEY_MLP, - ), - forward_builder_args=["base_type"], - ) -) - -ModelPatcher.register( - ModelPatcherRule( - rule_id="mistral-cross-ent", - import_and_maybe_reload=( - "torch.nn.CrossEntropyLoss", - FastCrossEntropyLoss, - "transformers.models.mistral.modeling_mistral", + fused_op=KEY_MLP, + base_type=base_type, + ), ), - ) -) - -ModelPatcher.register( - ModelPatcherRule( - rule_id="mistral-rope", - import_and_maybe_reload=( - "transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb", - fast_rope_embedding, - None, + ModelPatcherRule( + rule_id="mistral-cross-ent", + import_and_maybe_reload=( + "torch.nn.CrossEntropyLoss", + FastCrossEntropyLoss, + "transformers.models.mistral.modeling_mistral", + ), ), - ) -) + ModelPatcherRule( + rule_id="mistral-rope", + import_and_maybe_reload=( + "transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb", + fast_rope_embedding, + None, + ), + ) + ] diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py index 1522ef8d..7c0c58ab 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py @@ -20,85 +20,85 @@ MixtralAttention, MixtralRMSNorm, ) - -# Local -from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss -from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm -from ..kernels.unsloth.rope_embedding import fast_rope_embedding -from .model_patcher import ( - ModelPatcher, +from fms_acceleration.model_patcher import ( ModelPatcherRule, ModelPatcherTrigger, combine_functions, combine_triggers, ) + +# Local +from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss +from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm +from ..kernels.unsloth.rope_embedding import fast_rope_embedding + from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops -# - do regex on RMSNorm class name -# - check on the tensors required for fast_rms_layernorm -ModelPatcher.register( - ModelPatcherRule( - rule_id="mixtral-rms", - trigger=ModelPatcherTrigger(check=MixtralRMSNorm), - forward=fast_rms_layernorm, - ), -) +def get_mp_rules(base_type): + """ + Function to access all patch rules in this module. 
+ If it is a forward_builder rule with `base_type` in + its forward builder argument, wrap the forward_builder + function as a partial function with the base_type argument + """ -ModelPatcher.register( - ModelPatcherRule( - rule_id="mixtral-qkvo", - trigger=combine_triggers( - ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=MixtralAttention, - submodule_names=["q_proj", "k_proj", "v_proj"], - ) + # - do regex on RMSNorm class name + # - check on the tensors required for fast_rms_layernorm + return [ + ModelPatcherRule( + rule_id="mixtral-rms", + trigger=ModelPatcherTrigger(check=MixtralRMSNorm), + forward=fast_rms_layernorm, + ), + ModelPatcherRule( + rule_id="mixtral-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MixtralAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MixtralAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", ), - ModelPatcherTrigger( - check=partial( - trigger_fused_ops, - attn_cls=MixtralAttention, + forward_builder=combine_functions( + partial( + build_lora_fused_ops, + submodule_names=["q_proj", "k_proj", "v_proj"], + fused_op=KEY_QKV, + base_type=base_type, + ), + partial( + build_lora_fused_ops, submodule_names=["o_proj"], - ) + fused_op=KEY_O, + base_type=base_type, + ), + logic="APPEND", ), - logic="OR", ), - forward_builder=combine_functions( - partial( - build_lora_fused_ops, - submodule_names=["q_proj", "k_proj", "v_proj"], - fused_op=KEY_QKV, + ModelPatcherRule( + rule_id="mixtral-cross-ent", + import_and_maybe_reload=( + "torch.nn.CrossEntropyLoss", + FastCrossEntropyLoss, + "transformers.models.mixtral.modeling_mixtral", ), - partial( - build_lora_fused_ops, - submodule_names=["o_proj"], - fused_op=KEY_O, - ), - logic="APPEND", - ), - forward_builder_args=["base_type"], - ) -) - -ModelPatcher.register( - ModelPatcherRule( - rule_id="mixtral-cross-ent", - import_and_maybe_reload=( - "torch.nn.CrossEntropyLoss", - FastCrossEntropyLoss, - "transformers.models.mixtral.modeling_mixtral", - ), - ) -) - -ModelPatcher.register( - ModelPatcherRule( - rule_id="mixtral-rope", - import_and_maybe_reload=( - "transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb", - fast_rope_embedding, - None, ), - ) -) + ModelPatcherRule( + rule_id="mixtral-rope", + import_and_maybe_reload=( + "transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb", + fast_rope_embedding, + None, + ), + ) + ] diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py index 10819fc0..9d624277 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py @@ -16,7 +16,7 @@ from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_mlp as fused_op_mlp_gptq from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_o_v2 as fused_op_o_gptq from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_qkv as fused_op_qkv_gptq -from .model_patcher import ModelPatcherTrigger +from fms_acceleration.model_patcher import ModelPatcherTrigger KEY_QKV = "qkv" KEY_O = "o" diff --git a/plugins/fused-ops-and-kernels/tests/test_fused_ops.py b/plugins/fused-ops-and-kernels/tests/test_fused_ops.py index ded61a2f..237a3a6f 100644 --- a/plugins/fused-ops-and-kernels/tests/test_fused_ops.py +++ 
b/plugins/fused-ops-and-kernels/tests/test_fused_ops.py @@ -11,7 +11,7 @@ import torch # First Party -from fms_acceleration_foak.models.model_patcher import patch_model +from fms_acceleration.model_patcher import patch_model BNB = "bitsandbytes" GPTQ = "auto_gptq" @@ -119,7 +119,7 @@ def model_inputs(seed: int = 42, device: torch.device = "cuda"): 0, 10000, # most models should have more than 10K (bs, seq_len), - dtype=torch.int, + dtype=torch.long, device=device, ), None, # dont pass in position ids for now diff --git a/scripts/benchmarks/compare_with_reference.py b/scripts/benchmarks/compare_with_reference.py index a580b8de..6a66cebd 100644 --- a/scripts/benchmarks/compare_with_reference.py +++ b/scripts/benchmarks/compare_with_reference.py @@ -111,9 +111,10 @@ def main( columns={"self": "new", "other": "ref"}, level=-1 ) diff = diff[diff.index.isin([outlier for outlier in outliers])] - outliers_df = outliers_df.set_index(indices).merge( - diff, left_index=True, right_index=True - ) + if not diff.empty: + outliers_df = outliers_df.set_index(indices).merge( + diff, left_index=True, right_index=True + ) outliers_df.to_csv(os.path.join(result_dir, OUTLIERS_FILENAME)) for chart, filename in charts: chart.figure.savefig(os.path.join(result_dir, filename)) diff --git a/scripts/benchmarks/refs/a100_80gb.csv b/scripts/benchmarks/refs/a100_80gb.csv index 6bb7714a..37cdcc6d 100644 --- a/scripts/benchmarks/refs/a100_80gb.csv +++ b/scripts/benchmarks/refs/a100_80gb.csv @@ -1,85 +1,85 @@ epoch,fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second -0.15,,none,2e-5,,,76671.0,72972297728.0,44005107200.0,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.9160769081115723,549.026,0.729,0.182,2984.194 -0.15,,none,2e-5,,,43744.0,36763146240.0,29521348608.0,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.8728336906433105,298.0786,1.342,0.335,2748.269 -0.29,,none,2e-5,,,79365.0,72972690944.0,44005500416.0,mistralai/Mistral-7B-v0.1,1,,8,,,float16,1.001595754623413,1066.0306,0.75,0.094,3073.833 -0.29,,none,2e-5,,,52883.0,36763342848.0,29521545216.0,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.9138528442382813,552.1771,1.449,0.181,2967.164 +0.15,,none,2e-5,,,76031.0,72435426816.0,43468236288.0,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.9239591407775879,539.6271,0.741,0.185,3036.17 +0.15,,none,2e-5,,,43610.0,36226242560.0,28984444928.0,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.878299913406372,297.9576,1.342,0.336,2749.384 +0.29,,none,2e-5,,,78727.0,72435820032.0,43468629504.0,mistralai/Mistral-7B-v0.1,1,,8,,,float16,1.008358039855957,1048.9483,0.763,0.095,3123.891 +0.29,,none,2e-5,,,52837.0,36226439168.0,28984641536.0,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.917950096130371,554.8154,1.442,0.18,2953.054 ,,none,2e-5,,,80969.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,float16,,,,, -,,none,2e-5,,,80925.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,float16,,,,, +,,none,2e-5,,,80921.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,float16,,,,, ,,none,2e-5,,,80969.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,8,,,float16,,,,, -,,none,2e-5,,,81003.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,float16,,,,, -,,none,2e-5,,,80987.0,,,NousResearch/Llama-2-70b-hf,1,,4,,,float16,,,,, 
-,,none,2e-5,,,80922.0,,,NousResearch/Llama-2-70b-hf,2,,2,,,float16,,,,, -,,none,2e-5,,,80987.0,,,NousResearch/Llama-2-70b-hf,1,,8,,,float16,,,,, -,,none,2e-5,,,80863.0,,,NousResearch/Llama-2-70b-hf,2,,4,,,float16,,,,, -0.15,,none,2e-4,16,0.1,28707.0,26109561344.0,15119705600.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8970945072174072,458.7158,0.872,0.218,3571.71 -0.15,,none,2e-4,16,0.1,17897.0,15458877440.0,7850448896.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8571704006195069,270.088,1.481,0.37,3033.086 -0.29,,none,2e-4,16,0.1,42171.0,37100825088.0,15120098816.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9901649284362793,913.5703,0.876,0.109,3586.807 -0.29,,none,2e-4,16,0.1,25659.0,22105014272.0,7850645504.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9025015163421631,482.2349,1.659,0.207,3397.514 -,,none,2e-4,16,0.1,80991.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.15,,none,2e-4,16,0.1,61532.0,57898183168.0,47311509504.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8681951332092285,551.3062,0.726,0.181,1485.926 -,,none,2e-4,16,0.1,80991.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.29,,none,2e-4,16,0.1,69436.0,65039245312.0,47311706112.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8880744457244873,924.9663,0.865,0.108,1771.308 -,,none,2e-4,16,0.1,80617.0,,,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -,,none,2e-4,16,0.1,80756.0,,,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,,,,, -,,none,2e-4,16,0.1,80617.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -,,none,2e-4,16,0.1,80851.0,,,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.15,True,baseline-peft-bnb,2e-4,16,0.1,25999.0,23228815360.0,5368450560.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8647766017913818,593.395,0.674,0.169,2761.062 -0.15,True,baseline-peft-bnb,2e-4,16,0.1,12818.0,10431547904.0,2781601792.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8678814029693603,284.7643,1.405,0.351,2876.765 -0.29,True,baseline-peft-bnb,2e-4,16,0.1,46121.0,41084491776.0,5368843776.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.868037691116333,1158.2474,0.691,0.086,2829.102 -0.29,True,baseline-peft-bnb,2e-4,16,0.1,20421.0,17446783488.0,2781798400.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8695751667022705,502.2826,1.593,0.199,3261.909 -0.15,True,baseline-peft-bnb,2e-4,16,0.1,47567.0,46825980416.0,25726455296.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8924774932861328,1171.0504,0.342,0.085,1399.086 -0.15,True,baseline-peft-bnb,2e-4,16,0.1,25163.0,22356893696.0,13273817088.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8943204975128174,568.962,0.703,0.176,1439.815 -0.29,True,baseline-peft-bnb,2e-4,16,0.1,69237.0,67906358784.0,25726848512.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8907253837585449,2126.1835,0.376,0.047,1541.165 
-0.29,True,baseline-peft-bnb,2e-4,16,0.1,32960.0,30165152256.0,13274013696.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.893255443572998,957.9628,0.835,0.104,1710.296 -,True,baseline-peft-bnb,2e-4,16,0.1,80123.0,,,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.14,True,baseline-peft-bnb,2e-4,16,0.1,52469.0,47591447040.0,19434999808.0,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.0088242053985597,1955.5844,0.205,0.051,418.903 -,True,baseline-peft-bnb,2e-4,16,0.1,80581.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -,True,baseline-peft-bnb,2e-4,16,0.1,80585.0,,,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.15,True,accelerated-peft-bnb,2e-4,16,0.1,18907.0,15860617728.0,4843499008.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8644689750671387,482.8812,0.828,0.207,3392.967 -0.15,True,accelerated-peft-bnb,2e-4,16,0.1,12783.0,10431547904.0,2781601792.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8698421669006348,284.1914,1.408,0.352,2882.564 -0.29,True,accelerated-peft-bnb,2e-4,16,0.1,33331.0,26851881472.0,4843892224.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8686403369903565,948.8322,0.843,0.105,3453.508 -0.29,True,accelerated-peft-bnb,2e-4,16,0.1,20523.0,17446783488.0,2781798400.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8683876323699952,504.0477,1.587,0.198,3250.486 -0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.1,17449.0,14173894656.0,4843499008.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8669318771362304,419.9549,0.952,0.238,3901.371 -0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.1,12699.0,10065463808.0,2727075840.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8524643421173096,225.0245,1.778,0.444,3640.493 -0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.1,28593.0,23504648192.0,4843892224.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8657933044433593,819.2575,0.976,0.122,3999.719 -0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.1,19860.0,16744106496.0,2727272448.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8490522384643555,420.7803,1.901,0.238,3893.719 -0.15,True,accelerated-peft-bnb,2e-4,16,0.1,37399.0,36828377600.0,25201503744.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8931312561035156,925.5545,0.432,0.108,1770.182 -0.15,True,accelerated-peft-bnb,2e-4,16,0.1,25216.0,22359233024.0,13273817088.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.892439432144165,570.8031,0.701,0.175,1435.171 -0.29,True,accelerated-peft-bnb,2e-4,16,0.1,49913.0,48447599616.0,25201896960.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8924949169158936,1720.4669,0.465,0.058,1904.599 -0.29,True,accelerated-peft-bnb,2e-4,16,0.1,33214.0,30167236096.0,13274013696.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8910456848144531,961.4325,0.832,0.104,1704.124 -0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.1,36039.0,36153218048.0,25201503744.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8932661628723144,855.0375,0.468,0.117,1916.173 
-0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.1,25513.0,22008699904.0,13219291136.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8599378490447998,511.3077,0.782,0.196,1602.166 -0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.1,46959.0,47096648192.0,25201896960.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8929532051086426,1595.4842,0.501,0.063,2053.797 -0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.1,33064.0,29497270272.0,13219487744.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8600027751922608,878.2625,0.911,0.114,1865.502 -0.14,True,accelerated-peft-bnb,2e-4,16,0.1,72701.0,69770819584.0,37347044864.0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0098001098632812,3656.7382,0.109,0.027,448.05 -0.14,True,accelerated-peft-bnb,2e-4,16,0.1,52469.0,47591447040.0,19434999808.0,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.0093148803710938,1952.1407,0.205,0.051,419.642 -,True,accelerated-peft-bnb,2e-4,16,0.1,79377.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -,True,accelerated-peft-bnb,2e-4,16,0.1,80837.0,,,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.14,True,accelerated-peft-bnb-foak,2e-4,16,0.1,71019.0,68424906752.0,37347044864.0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0100258159637452,3358.344,0.119,0.03,487.859 -0.14,True,accelerated-peft-bnb-foak,2e-4,16,0.1,51461.0,46787975680.0,19172855808.0,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9600833988189698,1747.5665,0.229,0.057,468.766 -,True,accelerated-peft-bnb-foak,2e-4,16,0.1,80945.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.28,True,accelerated-peft-bnb-foak,2e-4,16,0.1,80967.0,72795019776.0,19173052416.0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9324470138549805,3384.1355,0.236,0.03,484.141 -0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,19429.0,15890927104.0,4873808384.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9843631744384765,481.226,0.831,0.208,3404.637 -0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,12860.0,10079847936.0,2798148608.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9855545139312745,282.415,1.416,0.354,2900.695 -0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,33223.0,26882190848.0,4874201600.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9518059539794922,944.0475,0.847,0.106,3471.012 -0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,20472.0,16725984768.0,2798345216.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9537856483459473,497.4081,1.608,0.201,3293.875 -0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,17193.0,13632576512.0,4873808384.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9757871055603027,413.121,0.968,0.242,3965.908 -0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,12703.0,9780872704.0,2743622656.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9560029792785645,221.2793,1.808,0.452,3702.109 -0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,28977.0,22392753152.0,4874201600.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj 
o_proj,float16,0.9514095497131347,805.8956,0.993,0.124,4066.035 -0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,19800.0,16157525504.0,2743819264.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9230777645111083,415.3379,1.926,0.241,3944.74 -0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,36387.0,35528691200.0,24511572480.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8990980052947998,885.7851,0.452,0.113,1849.659 -0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,23548.0,21067523584.0,12581313536.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8998581314086914,536.746,0.745,0.186,1526.234 -0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,48905.0,46519954944.0,24511965696.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8961446380615234,1669.0298,0.479,0.06,1963.296 -0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,30516.0,28187328512.0,12581510144.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8947424793243408,921.8778,0.868,0.108,1777.242 -0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,34731.0,34183981056.0,24511572480.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8992811870574952,814.994,0.491,0.123,2010.322 -0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,24177.0,20715718656.0,12526787584.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8655492782592773,475.158,0.842,0.21,1724.058 -0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,45901.0,43758690304.0,24511965696.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.895248155593872,1528.7913,0.523,0.065,2143.393 -0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,31452.0,27526991360.0,12526984192.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8628469562530517,835.9993,0.957,0.12,1959.81 -0.14,True,accelerated-peft-autogptq,2e-4,16,0.1,71181.0,67237753856.0,36290374144.0,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9894431018829346,3599.2898,0.111,0.028,455.201 -0.14,True,accelerated-peft-autogptq,2e-4,16,0.1,51115.0,45806148096.0,18387856384.0,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9900471115112305,1900.1037,0.211,0.053,431.134 -,True,accelerated-peft-autogptq,2e-4,16,0.1,79265.0,,,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.28,True,accelerated-peft-autogptq,2e-4,16,0.1,80813.0,71747131904.0,18388052992.0,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9895571708679199,3672.2631,0.218,0.027,446.155 -0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,69479.0,66160276480.0,36290374144.0,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9900266265869141,3283.8655,0.122,0.03,498.924 -0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,50518.0,45136894464.0,18125712384.0,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9589622497558594,1684.1824,0.238,0.059,486.408 -,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,80301.0,,,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -0.28,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,79950.0,70539958784.0,18125908992.0,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj 
o_proj,float16,0.9595681858062745,3305.9445,0.242,0.03,495.592 +,,none,2e-5,,,79851.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,float16,,,,, +,,none,2e-5,,,81049.0,,,NousResearch/Llama-2-70b-hf,1,,4,,,float16,,,,, +,,none,2e-5,,,80535.0,,,NousResearch/Llama-2-70b-hf,2,,2,,,float16,,,,, +,,none,2e-5,,,81049.0,,,NousResearch/Llama-2-70b-hf,1,,8,,,float16,,,,, +,,none,2e-5,,,80778.0,,,NousResearch/Llama-2-70b-hf,2,,4,,,float16,,,,, +0.15,,none,2e-4,16,0.1,28069.0,25654479360.0,14664623616.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8902829456329345,492.8103,0.812,0.203,3324.606 +0.15,,none,2e-4,16,0.1,17745.0,15245721600.0,7368103936.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8672024631500244,282.6828,1.415,0.354,2897.948 +0.29,,none,2e-4,16,0.1,41405.0,36645743104.0,14665016832.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9471714115142822,981.2132,0.815,0.102,3339.539 +0.29,,none,2e-4,16,0.1,25347.0,22161342464.0,7368300544.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8963674259185791,519.4995,1.54,0.192,3153.805 +,,none,2e-4,16,0.1,81015.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.15,,none,2e-4,16,0.1,61651.0,58190445568.0,47366035456.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8985625457763672,521.5924,0.767,0.192,1570.575 +,,none,2e-4,16,0.1,81015.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.29,,none,2e-4,16,0.1,69774.0,65584154624.0,47366232064.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9053801918029785,899.4995,0.889,0.111,1821.457 +,,none,2e-4,16,0.1,81043.0,,,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +,,none,2e-4,16,0.1,80885.0,,,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,,,,, +,,none,2e-4,16,0.1,81043.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +,,none,2e-4,16,0.1,80297.0,,,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.15,True,baseline-peft-bnb,2e-4,16,0.1,25359.0,21215549440.0,4831579648.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.868394603729248,582.2584,0.687,0.172,2813.871 +0.15,True,baseline-peft-bnb,2e-4,16,0.1,12012.0,9525447168.0,2244599808.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8697228622436524,293.0901,1.365,0.341,2795.045 +0.29,True,baseline-peft-bnb,2e-4,16,0.1,45481.0,37594830848.0,4831972864.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8708569717407226,1141.9093,0.701,0.088,2869.58 +0.29,True,baseline-peft-bnb,2e-4,16,0.1,19437.0,16171584000.0,2244796416.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8702435684204102,504.2852,1.586,0.198,3248.955 +0.15,True,baseline-peft-bnb,2e-4,16,0.1,44857.0,44196393984.0,25726455296.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8948281192779541,1108.8622,0.361,0.09,1477.551 +0.15,True,baseline-peft-bnb,2e-4,16,0.1,24006.0,21761152512.0,13273686016.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8967031955718994,533.5877,0.75,0.187,1535.268 
+0.29,True,baseline-peft-bnb,2e-4,16,0.1,63891.0,62284255232.0,25726848512.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8938827133178711,2008.7727,0.398,0.05,1631.245 +0.29,True,baseline-peft-bnb,2e-4,16,0.1,31110.0,28873790464.0,13273882624.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8923345756530762,904.7134,0.884,0.111,1810.96 +,True,baseline-peft-bnb,2e-4,16,0.1,79963.0,,,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.14,True,baseline-peft-bnb,2e-4,16,0.1,51535.0,46685150208.0,19266900480.0,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.0005579376220703,1958.8245,0.204,0.051,418.21 +,True,baseline-peft-bnb,2e-4,16,0.1,80417.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,baseline-peft-bnb,2e-4,16,0.1,80307.0,72626134016.0,19267097088.0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9999524974822998,3755.0656,0.213,0.027,436.317 +0.15,True,accelerated-peft-bnb,2e-4,16,0.1,18267.0,15323746816.0,4306628096.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8702622985839844,473.6415,0.845,0.211,3459.156 +0.15,True,accelerated-peft-bnb,2e-4,16,0.1,11974.0,9525447168.0,2244599808.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8716620254516602,290.781,1.376,0.344,2817.241 +0.29,True,accelerated-peft-bnb,2e-4,16,0.1,32691.0,26315010560.0,4307021312.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8700333976745606,930.8538,0.859,0.107,3520.209 +0.29,True,accelerated-peft-bnb,2e-4,16,0.1,19507.0,16171584000.0,2244796416.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8691346645355225,504.1747,1.587,0.198,3249.667 +0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.1,16809.0,13065396224.0,4306628096.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8685474967956544,410.2967,0.975,0.244,3993.208 +0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.1,11780.0,9309506048.0,2244599808.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8708373165130615,223.4526,1.79,0.448,3666.101 +0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.1,27953.0,21825572864.0,4307021312.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.871671667098999,802.3406,0.997,0.125,4084.051 +0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.1,18836.0,15686158848.0,2244796416.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8719861793518067,421.817,1.897,0.237,3884.148 +0.15,True,accelerated-peft-bnb,2e-4,16,0.1,37381.0,36218622464.0,25201503744.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8955930042266845,863.9677,0.463,0.116,1896.367 +0.15,True,accelerated-peft-bnb,2e-4,16,0.1,24002.5,21762409472.0,13273686016.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8963699913024903,534.0626,0.749,0.187,1533.903 +0.29,True,accelerated-peft-bnb,2e-4,16,0.1,49911.0,47209886208.0,25201896960.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.89283447265625,1612.732,0.496,0.062,2031.832 +0.29,True,accelerated-peft-bnb,2e-4,16,0.1,31178.0,28883566592.0,13273882624.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj 
o_proj,float16,0.891493616104126,905.1923,0.884,0.11,1810.002 +0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.1,35493.0,34864005632.0,25201503744.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8977092170715332,797.873,0.501,0.125,2053.46 +0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.1,24609.0,21479203840.0,13273686016.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.899329023361206,467.9373,0.855,0.214,1750.662 +0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.1,46045.0,44399605760.0,25201896960.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8906442070007324,1482.6429,0.54,0.067,2210.107 +0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.1,31782.0,28263236608.0,13273882624.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8938700771331787,819.8908,0.976,0.122,1998.315 +0.14,True,accelerated-peft-bnb,2e-4,16,0.1,71645.0,68126652928.0,37179273216.0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0004115581512452,3608.5496,0.111,0.028,454.033 +0.14,True,accelerated-peft-bnb,2e-4,16,0.1,51534.0,46685150208.0,19266900480.0,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.0003361415863037,1957.3459,0.204,0.051,418.526 +,True,accelerated-peft-bnb,2e-4,16,0.1,81013.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-bnb,2e-4,16,0.1,80576.0,72626134016.0,19267097088.0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0003034496307373,3754.632,0.213,0.027,436.368 +0.14,True,accelerated-peft-bnb-foak,2e-4,16,0.1,69963.0,67049175552.0,37179273216.0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0002764415740968,3310.3528,0.121,0.03,494.932 +0.14,True,accelerated-peft-bnb-foak,2e-4,16,0.1,51248.0,46407474176.0,19266900480.0,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.0008399486541748,1759.3679,0.227,0.057,465.622 +,True,accelerated-peft-bnb-foak,2e-4,16,0.1,80785.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-bnb-foak,2e-4,16,0.1,80907.0,71810538496.0,19267097088.0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0006698417663573,3397.5545,0.235,0.029,482.229 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,18789.0,15354056192.0,4336937472.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9604459190368653,472.8841,0.846,0.211,3464.696 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,12297.0,9542977024.0,2261277696.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9566273403167724,286.9325,1.394,0.349,2855.027 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,32583.0,26345319936.0,4337330688.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9381388664245606,927.1603,0.863,0.108,3534.232 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,19937.0,16189113856.0,2261474304.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9374161720275879,501.0475,1.597,0.2,3269.95 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,16553.0,13095705600.0,4336937472.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.971520586013794,403.7642,0.991,0.248,4057.814 
+0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,12341.0,9327035904.0,2261277696.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9715091705322265,220.2987,1.816,0.454,3718.587 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,28337.0,21855882240.0,4337330688.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9389788341522217,791.0793,1.011,0.126,4142.189 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,19370.0,15703688704.0,2261474304.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9499363136291504,414.8507,1.928,0.241,3949.372 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,36439.0,35528691200.0,24511572480.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8999805450439453,824.8819,0.485,0.121,1986.224 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,23508.0,21071283712.0,12581313536.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8981617355346679,498.8269,0.802,0.2,1642.253 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,48969.0,46519954944.0,24511965696.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8971894550323486,1569.1833,0.51,0.064,2088.22 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,30660.0,28189791744.0,12581510144.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8969889640808105,869.4187,0.92,0.115,1884.478 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,34745.0,34214612480.0,24511572480.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8998163032531739,755.773,0.529,0.132,2167.847 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,24143.0,20788983296.0,12581313536.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9024192810058593,433.2446,0.923,0.231,1890.849 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,46713.0,43776172032.0,24511965696.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8984066200256348,1432.2052,0.559,0.07,2287.94 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,31123.0,27569485312.0,12581510144.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8986644458770752,780.5165,1.025,0.128,2099.123 +0.14,True,accelerated-peft-autogptq,2e-4,16,0.1,70529.0,67069982208.0,36122602496.0,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9913517284393311,3559.5185,0.112,0.028,460.287 +0.14,True,accelerated-peft-autogptq,2e-4,16,0.1,50471.0,45638376448.0,18220084736.0,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9913260459899902,1905.2123,0.21,0.052,429.978 +,True,accelerated-peft-autogptq,2e-4,16,0.1,79895.0,,,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-autogptq,2e-4,16,0.1,80755.0,71579360256.0,18220281344.0,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9910284423828125,3686.3588,0.217,0.027,444.449 +0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,69339.0,65992504832.0,36122602496.0,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.991469144821167,3234.048,0.124,0.031,506.61 +0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,50733.0,45360700416.0,18220084736.0,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj 
v_proj o_proj,float16,0.9918032264709473,1691.5951,0.236,0.059,484.277 +,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,80161.0,,,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,80316.0,70763764736.0,18220281344.0,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9914980411529541,3325.3628,0.241,0.03,492.698 diff --git a/scripts/benchmarks/refs/requirements.txt b/scripts/benchmarks/refs/requirements.txt index 7bfa45a4..06abe58f 100644 --- a/scripts/benchmarks/refs/requirements.txt +++ b/scripts/benchmarks/refs/requirements.txt @@ -1,29 +1,34 @@ -accelerate==0.32.1 +accelerate==0.33.0 aiohttp==3.9.5 aiosignal==1.3.1 async-timeout==4.0.3 attrs==23.2.0 -bitsandbytes==0.43.1 +bitsandbytes==0.43.2 certifi==2024.7.4 charset-normalizer==3.3.2 +contourpy==1.2.1 +cycler==0.12.1 datasets==2.20.0 dill==0.3.8 docstring_parser==0.16 einops==0.8.0 filelock==3.15.4 fire==0.6.0 -flash-attn==2.5.9.post1 --e git+https://github.com/achew010/fms-acceleration.git@33bf943ed4e19db7941ca8f852666a51160fb2de#egg=fms_acceleration&subdirectory=plugins/framework --e git+https://github.com/achew010/fms-acceleration.git@33bf943ed4e19db7941ca8f852666a51160fb2de#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels --e git+https://github.com/achew010/fms-acceleration.git@33bf943ed4e19db7941ca8f852666a51160fb2de#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft -fms-hf-tuning @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@85f32cb15019217ccc22156233f15d280d3f4690 +flash-attn==2.6.3 +-e git+https://github.com/achew010/fms-acceleration.git@74319eb4f6ef5d946573be0e7e851d97ba16b823#egg=fms_acceleration&subdirectory=plugins/framework +-e git+https://github.com/achew010/fms-acceleration.git@74319eb4f6ef5d946573be0e7e851d97ba16b823#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels +-e git+https://github.com/achew010/fms-acceleration.git@74319eb4f6ef5d946573be0e7e851d97ba16b823#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft +fms-hf-tuning @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@7dfd4e71a0ded17ab65654925e18bf9a1d76b0fc +fonttools==4.53.1 frozenlist==1.4.1 fsspec==2024.5.0 -huggingface-hub==0.23.4 +huggingface-hub==0.24.2 idna==3.7 Jinja2==3.1.4 +kiwisolver==1.4.5 markdown-it-py==3.0.0 MarkupSafe==2.1.5 +matplotlib==3.9.1 mdurl==0.1.2 mpmath==1.3.0 multidict==6.0.5 @@ -34,26 +39,28 @@ nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==8.9.2.26 +nvidia-cudnn-cu12==9.1.0.70 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 -nvidia-nccl-cu12==2.19.3 +nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.5.82 nvidia-nvtx-cu12==12.1.105 packaging==24.1 pandas==2.2.2 -peft==0.11.1 +peft==0.12.0 +pillow==10.4.0 protobuf==5.27.2 psutil==6.0.0 -pyarrow==16.1.0 +pyarrow==17.0.0 pyarrow-hotfix==0.6 Pygments==2.18.0 +pyparsing==3.1.2 python-dateutil==2.9.0.post0 pytz==2024.1 PyYAML==6.0.1 -regex==2024.5.15 +regex==2024.7.24 requests==2.32.3 rich==13.7.1 safetensors==0.4.3 @@ -61,14 +68,14 @@ sentencepiece==0.2.0 shtab==1.7.1 simpleeval==0.9.13 six==1.16.0 -sympy==1.13.0 +sympy==1.13.1 termcolor==2.4.0 threadpoolctl==3.5.0 -tokenizers==0.15.2 -torch==2.2.2 +tokenizers==0.19.1 +torch==2.4.0 tqdm==4.66.4 -transformers==4.39.3 -triton==2.2.0 +transformers==4.43.3 
+triton==3.0.0
 trl==0.9.6
 typing_extensions==4.12.2
 tyro==0.8.5
diff --git a/tox.ini b/tox.ini
index a51edcb0..83bee5ea 100644
--- a/tox.ini
+++ b/tox.ini
@@ -21,9 +21,16 @@ allowlist_externals = bash
 description = run benchmarks
 skip_install = true
 deps =
+    matplotlib # this is for plotting benchmark comparisons in compare_with_reference.py
     packaging # this is required for flash-attn dep as fms_hf_tuning did not specify
     -e {toxinidir}/plugins/framework # install the framework here as the flash attention deps requires torch
 passenv = * # will pass the parent env, otherwise there are too many envs e.g. TRANSFORMERS that need to be set
+setenv =
+    # Needed because newer versions of triton do not allow JIT-compiled kernels to access global variables
+    # As a follow-up, consider changing the triton kernels to access only globals that are annotated as constexpr
+    # source: https://github.com/triton-lang/triton/blob/7b617bcc35c4cf06f61dd267fc049fe33b2851f9/python/triton/compiler/code_generator.py#L280
+    # Tracked as an issue at https://github.com/foundation-model-stack/fms-acceleration/issues/56
+    TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1
 commands =
     # need a version of fms-hf-tuning that has integrated the framework
     # NOTE: have to install this first coz havnt merged
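Note on the TRITON_ALLOW_NON_CONSTEXPR_GLOBALS setting added to tox.ini above: the snippet below is a minimal sketch of the constexpr-global pattern that the comment refers to, assuming a recent Triton (3.x) release and the public triton / triton.language APIs only. The kernel, launcher, and variable names (_scale_kernel, scale, BLOCK_SIZE) are illustrative and are not taken from this repository's fused kernels.

# Illustrative sketch only (not part of this patch): a module-level global
# annotated as tl.constexpr, read from inside a @triton.jit kernel, which is
# the pattern newer Triton accepts without the environment-variable override.
import torch
import triton
import triton.language as tl

BLOCK_SIZE: tl.constexpr = 128  # annotated global captured by the JIT as a compile-time constant

@triton.jit
def _scale_kernel(x_ptr, out_ptr, n_elements, scale):
    # each program instance handles one BLOCK_SIZE-wide slice of the input
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * scale, mask=mask)

def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    # host-side launcher; expects a CUDA tensor. BLOCK_SIZE is a plain int
    # here, so it can also be used for the grid computation.
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)
    _scale_kernel[grid](x, out, x.numel(), s)
    return out

Until the fused kernels are reworked along these lines, the environment variable set in tox.ini remains the lower-effort workaround.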