Allow Kernels for Full FT and Non-Quantized PEFT (#79)
* add OR logic for plugin registration

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: 1000850000 user <[email protected]>

* add fast kernels plugin

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: 1000850000 user <[email protected]>

* prepare full-foak benchmarks

Signed-off-by: 1000850000 user <[email protected]>

* update benchmark logic to have empty framework_config

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: 1000850000 user <[email protected]>

* minor fixes to foak full

Signed-off-by: 1000850000 user <[email protected]>

* addressed code review changes

Signed-off-by: 1000850000 user <[email protected]>

* additional fixes from code review

Signed-off-by: 1000850000 user <[email protected]>

* minor fixes to standard peft

Signed-off-by: 1000850000 user <[email protected]>

* Apply suggestions from code review

Co-authored-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: 1000850000 user <[email protected]>

* changes to filtering function and modifications to allow flexibility of activating kernels

Signed-off-by: 1000850000 user <[email protected]>

* additional check in fastkernels and changes to FOAK README.md

Signed-off-by: 1000850000 user <[email protected]>

* fix syntax error

Signed-off-by: 1000850000 user <[email protected]>

* fix reloads on multiple patches

Signed-off-by: 1000850000 user <[email protected]>

* dtype changes to scenarios.yaml and README.md

Signed-off-by: 1000850000 user <[email protected]>

* changes to scenarios.yaml

Signed-off-by: 1000850000 user <[email protected]>

* additional comments

Signed-off-by: 1000850000 user <[email protected]>

* format and lint

Signed-off-by: 1000850000 user <[email protected]>

* fixes and updates to benchmark

Signed-off-by: 1000850000 user <[email protected]>

* fixes on reload

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: 1000850000 user <[email protected]>
Co-authored-by: 1000850000 user <[email protected]>
Co-authored-by: achew010 <[email protected]>
3 people authored Sep 16, 2024
1 parent a0ac97a commit 4e81c64
Showing 25 changed files with 890 additions and 179 deletions.
@@ -231,7 +231,6 @@ def model_loader(self, model_name: str, **kwargs):
# and there is a section of code that will be skipped if not set.
setattr(model, "is_loaded_in_4bit", True)
setattr(model, "quantization_method", "gptq")

return model

@property
@@ -275,6 +274,8 @@ def augmentation(

# some assertions
assert peft_config is not None, "need peft_config to install PEFT adapters"
# running this plugin in float16 is the most performant
# https://github.com/foundation-model-stack/fms-acceleration/issues/84
assert (
model.dtype == torch.float16 or train_args.fp16
), "need to run in fp16 mixed precision or load model in fp16"
@@ -324,6 +325,13 @@ def augmentation(
auto_find_all_linears=requires_installation_on_all_linears(peft_config),
train_mode=True, # install adapters for training
)

# We do not set `is_loaded_in_4bit` at this point because otherwise
# `accelerate.prepare_model` will think the device placement is finalized
# for the quantized model, and will raise an error.
# Reassign `quantization_method` after PEFT installation replaces the top-level class
setattr(model, "quantization_method", "gptq")

modifiable_args = (None,) # return a None for peft_config

if self.use_external_lib:
@@ -175,13 +175,20 @@ def _is_backbone(module: torch.nn.Module):
# Local
from .flash_attn import _flash_attention_forward_with_posids

# - we need to reload on the correct module
try:
# if it is peft
_module_path = model.get_base_model().__module__
except AttributeError:
_module_path = model.__module__

ModelPatcher.register(
ModelPatcherRule(
rule_id="flash_attn_forward",
import_and_maybe_reload=(
"transformers.modeling_flash_attention_utils._flash_attention_forward",
partial(_flash_attention_forward_with_posids, id(model)),
model.__module__,
_module_path,
),
),
)
1 change: 0 additions & 1 deletion plugins/framework/pyproject.toml
@@ -24,7 +24,6 @@ classifiers=[
dependencies = [
"numpy<2.0", # numpy needs to be bounded due to incompatiblity with current torch<2.3
"torch>2.2",
"transformers",
"peft",
"accelerate",
"pandas",
69 changes: 55 additions & 14 deletions plugins/framework/src/fms_acceleration/framework_plugin.py
@@ -29,7 +29,7 @@
class PluginRegistration:
plugin: "AccelerationPlugin"
AND: List[str] = None
# OR: List[str] = None # not implemented yet
OR: List[str] = None

# package metadata
package_name: str = None
@@ -53,28 +53,61 @@ def _trace_key_path(configuration: Dict, key: str):
def get_relevant_configuration_sections(configuration: Dict) -> Dict:
results = []

# this function updates cfg with content
# - equivalent to taking a union
def _update_config_contents(_cfg: Dict, content: Dict, key: str):
path = key.split(".")
n = len(path)
_cfg = relevant_config
while n > 1:
p = path.pop(0)
if p not in _cfg:
_cfg[p] = {}
_cfg = _cfg[p]
n -= 1

_cfg[path[0]] = content

# assume the registrations are all done with at least some default key
for registration in PLUGIN_REGISTRATIONS:
relevant_config = {}
# OR is not implemented yet

_and_keys = registration.AND
_or_keys = registration.OR
if _and_keys is None:
_and_keys = []
if _or_keys is None:
_or_keys = []

# go through AND paths then OR paths
# - if all AND paths are specified, then return the union of all their content
# - if any OR path is specified, then return the union of specified content
reject = False
for key in registration.AND:
for key in _and_keys:
content = _trace_key_path(configuration, key)
if content is None:
# for AND keys, if at least one of them is not
# specified, then reject and do not descend the config tree
reject = True
break

path = key.split(".")
n = len(path)
_cfg = relevant_config
while n > 1:
p = path.pop(0)
if p not in _cfg:
_cfg[p] = {}
_cfg = _cfg[p]
n -= 1
# update
_update_config_contents(relevant_config, content, key)

# if the AND keys were not all satisfied, then reset the config
if reject:
relevant_config = {}

for key in _or_keys:
content = _trace_key_path(configuration, key)
if content is not None:
if reject:
# it is an OR key, so if at least one of them is specified,
# then do not reject
reject = False

_cfg[path[0]] = content
# update all content that is not None
_update_config_contents(relevant_config, content, key)

if reject:
continue
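
The block below is a minimal, self-contained sketch (not the library code) of the AND/OR resolution that the comments above describe: a plugin's configuration content is returned when all of its AND paths are present, or when at least one of its OR paths is present. The helper names trace and resolve are illustrative, the configuration keys in the trailing example are placeholders, and the sketch stores matched content under flat dotted keys rather than rebuilding the nested structure the way _update_config_contents does.

from typing import Dict


def trace(cfg: Dict, key: str):
    """Walk a dot-separated key through a nested dict; None if any hop is missing."""
    node = cfg
    for part in key.split("."):
        if not isinstance(node, dict) or part not in node:
            return None
        node = node[part]
    return node


def resolve(cfg: Dict, and_keys=(), or_keys=()):
    relevant, reject = {}, False
    for key in and_keys:
        content = trace(cfg, key)
        if content is None:
            reject = True  # a single missing AND key rejects the plugin
            break
        relevant[key] = content
    if reject:
        relevant = {}  # discard partial AND matches
    for key in or_keys:
        content = trace(cfg, key)
        if content is not None:
            reject = False  # any present OR key re-admits the plugin
            relevant[key] = content
    return None if reject else relevant


# e.g. resolve({"training": {"fused_ops": {"base_layer": True}}},
#              or_keys=["training.fused_ops", "peft.quantization"])
# -> {"training.fused_ops": {"base_layer": True}}
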
@@ -91,7 +124,8 @@ class AccelerationPlugin:
@staticmethod
def register_plugin(
plugin: "AccelerationPlugin",
configuration_and_paths: List[str],
configuration_and_paths: List[str] = None,
configuration_or_paths: List[str] = None,
**kwargs,
):

@@ -101,6 +135,12 @@ def register_plugin(
# is done (global-variable-not-assigned)
# global PLUGIN_REGISTRATIONS

assert (
configuration_and_paths is not None and len(configuration_and_paths) > 0
) or (
configuration_or_paths is not None and len(configuration_or_paths) > 0
), "Specify at least one AND or OR path"

# get the package metadata
pkg_name = sys.modules[plugin.__module__].__package__
try:
@@ -112,6 +152,7 @@
PluginRegistration(
plugin=plugin,
AND=configuration_and_paths,
OR=configuration_or_paths,
package_name=pkg_name,
package_version=package_version,
)
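
A short, hedged usage sketch of the new registration keyword follows. The plugin class and both configuration paths are placeholders rather than names taken from this repository; the only point illustrated is that register_plugin now accepts configuration_or_paths and asserts that at least one AND or OR path is supplied.

from fms_acceleration.framework_plugin import AccelerationPlugin


class MyKernelsPlugin(AccelerationPlugin):  # hypothetical plugin class
    pass


# OR semantics: the plugin activates if either placeholder section
# appears in the acceleration framework configuration
AccelerationPlugin.register_plugin(
    MyKernelsPlugin,
    configuration_or_paths=[
        "training.fused_ops_and_kernels",           # placeholder path
        "peft.quantization.fused_ops_and_kernels",  # placeholder path
    ],
)
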
18 changes: 14 additions & 4 deletions plugins/framework/src/fms_acceleration/model_patcher.py
@@ -348,14 +348,24 @@ def _import_and_reload(model: torch.nn.Module):
key=lambda _rule: len(_rule.import_and_maybe_reload[2]),
reverse=False,
)
for rule_s in _with_reload:
for rule_l in _with_reload[1:]:

for i_s, rule_s in enumerate(_with_reload[:-1]):
for rule_l in _with_reload[i_s + 1 :]:
# if the target path of rule_s is a prefix of rule_l's, raise an error
_, _, _path_s = rule_s.import_and_maybe_reload
_name_s, _obj_s, _path_s = rule_s.import_and_maybe_reload
_, _, _path_l = rule_l.import_and_maybe_reload

if _path_s == _path_l:
# - in the event the target is exactly the same, we will
# only reload once
rule_s.import_and_maybe_reload = (_name_s, _obj_s, None)
continue

# - otherwise, we do not consider the cases where the target
# is a subpath, since this results in unpredictability.
assert not _path_l.startswith(
_path_s
), f"Attempting to reload same path `{_path_s}` multiple times in \
), f"Attempting to reload a subpath`{_path_s}` multiple times in \
{rule_s.rule_id} and {rule_l.rule_id}"

# handle those with reload first
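
As a reading aid, here is a simplified, standalone sketch (not the ModelPatcher implementation) of the pairwise check introduced above: exact duplicate reload targets collapse into a single reload, while a target that is a strict prefix of another is rejected because the resulting reload order would be unpredictable.

from typing import List


def check_reload_targets(targets: List[str]) -> List[str]:
    kept: List[str] = []
    # shortest (parent-most) paths first, mirroring the sort in _import_and_reload
    for target in sorted(targets, key=len):
        if target in kept:
            continue  # exact duplicate: keep a single reload
        for seen in kept:
            if target.startswith(seen):
                raise ValueError(
                    f"Attempting to reload a subpath of `{seen}` multiple times"
                )
        kept.append(target)
    return kept


# check_reload_targets(["a.b.c", "a.b.c"]) -> ["a.b.c"]          (deduplicated)
# check_reload_targets(["a.b", "a.b.c"])   -> raises ValueError  (subpath clash)
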
22 changes: 18 additions & 4 deletions plugins/framework/src/fms_acceleration/utils/test_utils.py
@@ -18,7 +18,7 @@
# Standard
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List, Set, Tuple, Type
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union

# Third Party
import torch
@@ -67,7 +67,14 @@ def configure_framework_from_json(
@contextmanager
def build_framework_and_maybe_instantiate(
plugins_to_be_registered: List[
Tuple[List[str], Type[AccelerationPlugin]] # and_paths, plugin_class
Union[
Tuple[List[str], Type[AccelerationPlugin]], # and_paths, plugin_class
Tuple[
List[str],
List[str], # and_or_paths
Type[AccelerationPlugin], # plugin_class
],
]
],
configuration_contents: Dict = None,
instantiate: bool = True,
@@ -89,10 +96,17 @@ def build_framework_and_maybe_instantiate(
AccelerationFramework.active_plugins = []
AccelerationFramework.plugins_require_custom_loading = []

for path, plugin in plugins_to_be_registered:
for paths_and_plugins in plugins_to_be_registered:
try:
and_paths, plugin = paths_and_plugins
or_paths = None
except ValueError:
and_paths, or_paths, plugin = paths_and_plugins

AccelerationPlugin.register_plugin(
plugin,
configuration_and_paths=path,
configuration_and_paths=and_paths,
configuration_or_paths=or_paths,
)

if instantiate:
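
For completeness, a hedged usage sketch of the updated helper: it now accepts either the original 2-tuple (and_paths, plugin_class) or the new 3-tuple (and_paths, or_paths, plugin_class). SamplePluginA, SamplePluginB, and the configuration keys are placeholders; only the keyword names and tuple shapes follow the signatures shown in this diff.

# assumes SamplePluginA and SamplePluginB are AccelerationPlugin subclasses
# defined elsewhere for testing
from fms_acceleration.utils.test_utils import build_framework_and_maybe_instantiate

with build_framework_and_maybe_instantiate(
    plugins_to_be_registered=[
        (["existing1"], SamplePluginA),                   # old 2-tuple: AND paths only
        ([], ["existing1", "existing2"], SamplePluginB),  # new 3-tuple: OR paths
    ],
    configuration_contents={"existing1": {"key1": 1}, "existing2": {"key1": 1}},
    instantiate=True,
) as framework:
    assert len(framework.active_plugins) >= 1
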
73 changes: 73 additions & 0 deletions plugins/framework/tests/test_framework.py
@@ -313,3 +313,76 @@ def _hook(
framework.augmentation(model, None, None)
for c, (n, _) in zip(plugin_activation_order, plugins_to_be_installed):
assert n in c


def test_plugin_registration_combination_logic():

plugin = create_plugin_cls(
restricted_models={"CausalLM"},
requires_agumentation=True,
agumentation=dummy_augmentation,
)

configuration_contents = {"existing1": {"key1": 1}, "existing2": {"key1": 1}}

# empty conditions
with pytest.raises(AssertionError, match="Specify at least one AND or OR path"):
with build_framework_and_instantiate(
plugins_to_be_registered=[
([], [], plugin),
],
configuration_contents=configuration_contents,
) as framework:
pass

# AND logic - happy
with build_framework_and_instantiate(
plugins_to_be_registered=[
(["existing1", "existing2"], plugin),
],
configuration_contents=configuration_contents,
) as framework:
# check 1.
assert len(PLUGIN_REGISTRATIONS) == 1

# check 2.
assert len(framework.active_plugins) == 1

# AND - sad path
with pytest.raises(
ValueError,
match="No plugins could be configured. Please check the acceleration",
):
with build_framework_and_instantiate(
plugins_to_be_registered=[
(["existing1", "non-existant"], plugin),
],
configuration_contents=configuration_contents,
) as framework:
pass

# OR logic
with build_framework_and_instantiate(
plugins_to_be_registered=[
([], ["existing1", "non-existant"], plugin),
],
configuration_contents=configuration_contents,
) as framework:
# check 1.
assert len(PLUGIN_REGISTRATIONS) == 1

# check 2.
assert len(framework.active_plugins) == 1

# OR - sad path
with pytest.raises(
ValueError,
match="No plugins could be configured. Please check the acceleration",
):
with build_framework_and_instantiate(
plugins_to_be_registered=[
(["non-existant", "non-existant2"], plugin),
],
configuration_contents=configuration_contents,
) as framework:
pass