Fix Low CPU Memory Mode Issues for Quantized Peft (#90)
* address issue 2 in #83

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* properly handle broadcast of adapters

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* handle param_init_fn_tied_param

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* trl version error

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* tied weights fix and meta fix for autogptq

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* update readme

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fmt + lint

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* upgrade granite benches

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim authored Oct 17, 2024
1 parent 97fc3c1 commit fc78b55
Showing 8 changed files with 354 additions and 151 deletions.
7 changes: 4 additions & 3 deletions plugins/accelerated-peft/README.md
@@ -6,8 +6,8 @@ Currently only supports LoRA-related techniques, but more are in the pipeline to

Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
[autogptq](./src/fms_acceleration_peft/framework_plugin_autogptq.py) | Loads 4bit GPTQ-LoRA with quantized GPTQ as base | AutoGPTQ | ✅ | ✅
[bnb](./src/fms_acceleration_peft/framework_plugin_bnb.py) | Loads 4bit QLoRA with quantized bitsandbytes Linear4 | Huggingface<br>bitsandbytes | ✅ | ✅
[autogptq](./src/fms_acceleration_peft/framework_plugin_autogptq.py) | Loads 4bit GPTQ-LoRA with quantized GPTQ as base | AutoGPTQ | ✅ | ✅ | ✅
[bnb](./src/fms_acceleration_peft/framework_plugin_bnb.py) | Loads 4bit QLoRA with quantized bitsandbytes Linear4 | Huggingface<br>bitsandbytes | ✅ | ✅ | ✅


### Key Points
@@ -43,6 +43,7 @@ GPTQ-LORA depends on an AutoGPTQ backend to run. There are 2 backend options
## Known Issues
<!--
- Models with sliding windows (e.g., Mistral, Mixtral) will have [memory and throughput issues](https://github.com/huggingface/transformers/issues/30461).
-->
- GPTQ-LORA is sometimes observed to have `nan` grad norms at the beginning of training, but training otherwise proceeds well.
- `low_cpu_mem_usage` is temporarily disabled for AutoGPTQ until a bug with `make_sure_no_tensor_in_meta_device` is resolved.
@@ -29,6 +29,7 @@
from peft.tuners.lora.model import LoraModel
from transformers import AutoModelForCausalLM, TrainingArguments
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.import_utils import _is_package_available
import torch
import torch.distributed

@@ -61,10 +62,6 @@ def __init__(self, configurations: Dict[str, Dict]):
)

if self.use_external_lib:
# Third Party
from transformers.utils.import_utils import ( # pylint: disable=import-outside-toplevel
_is_package_available,
)

assert _is_package_available("auto_gptq") is True, (
"Unable to use external library, auto_gptq module not found. "
@@ -351,6 +348,48 @@ def augmentation(

return model, modifiable_args

def get_callbacks_and_ready_for_train(
self, model: torch.nn.Module = None, accelerator=None
):
callbacks = []
if (
accelerator is not None
and getattr(accelerator.state, "fsdp_plugin", None) is not None
):
_, _transformers_version = _is_package_available(
"transformers", return_version=True
)
_trl_installed, _trl_version = _is_package_available(
"trl", return_version=True
)

# the meta device fix for quantized models is available from this transformers version onwards,
# or, if trl is installed, only from this trl version onwards
if _transformers_version >= "4.45" and (
not _trl_installed or (_trl_installed and _trl_version >= "0.12")
):
# guarded
# NOTE: replace this later with a more specific accelerate version check
try:
# Third Party
# pylint: disable=import-outside-toplevel
from torch.distributed.utils import ensure_weights_retied

# then it's handled internally and there is nothing to do
except ImportError:
# need to use our internal version
# Local
from .fsdp_utils import ( # pylint: disable=import-outside-toplevel
ensure_weights_retied,
)

accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
accelerator.state.fsdp_plugin.param_init_fn,
model.get_base_model(),
accelerator.device,
)
return callbacks


# register
AccelerationPlugin.register_plugin(
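
The `get_callbacks_and_ready_for_train` hook added above exists because, under FSDP low-CPU-memory loading, ranks other than 0 hold parameters on the meta device and the FSDP `param_init_fn` re-allocates each module's parameters before `sync_module_states` broadcasts rank 0's values; that re-allocation silently breaks weight tying, which `ensure_weights_retied` repairs. As a rough illustration, the kind of callable being wrapped looks something like the sketch below (an assumption about accelerate/FSDP internals, not code from this commit); note that it must return the module, since the wrapper reads parameters off the returned object.

```python
# Illustrative sketch only: the shape of a per-module FSDP param_init_fn
# (assumed, not taken from this commit) that ensure_weights_retied can wrap.
import torch


def make_to_empty_init_fn(device: torch.device):
    def param_init_fn(module: torch.nn.Module) -> torch.nn.Module:
        # re-allocate this module's own (possibly meta) params on `device`;
        # FSDP's sync_module_states then broadcasts real values from rank 0
        return module.to_empty(device=device, recurse=False)

    return param_init_fn
```

When `to_empty` recreates a module's parameters, any other module that shared the old `Parameter` object keeps pointing at the stale tensor, which is exactly the situation the wrapper repairs.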
@@ -116,64 +116,12 @@ def model_loader(self, model_name: str, **kwargs):
except ValueError:
world_size = 1 # pg not init

patched_is_local_dist_rank_0 = None
if (
world_size > 1
and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
):
config_kwargs["bnb_4bit_quant_storage"] = torch_dtype

# - of course assume that this package must exist, simply need the version
_, _transformers_version = _is_package_available(
"transformers", return_version=True
)

# this is a workaround that disables low_cpu_mem_mode for quant QLORA
# - this issue was introduced in https://github.com/huggingface/transformers/pull/33154
# whereby the low_cpu_mem_mode was actually fixed.
# - However fixing it causes some problems with the current impl.
# 1. For lora fused ops, the adapters cannot be managed by FSDP, as
# forwards are not called. This causes issue 2) in
# https://github.com/foundation-model-stack/fms-acceleration/issues/83
# where the adapters are still sharded when passed in the fused-ops.
# However, if low_cpu_mem_mode=True, then we NEED FSDP to initialize
# their state, which contradicts the above point.
#
# 2. We have observed,
# see https://github.com/foundation-model-stack/fms-acceleration/pull/86
# that low_cpu_mem_mode=True can cause torch distributed primitives
# to hang.

if _transformers_version >= "4.45":

# pylint: disable=import-outside-toplevel
# Third Party
from fms_acceleration.model_patcher import patch_target_module
import transformers.modeling_utils

def _truthy():
return (
True # use this to always return True to is_local_dist_rank_0
)

# - we cannot use the model patcher and this needs to be called immediately below
# at the model_loader
# - but we immediately revert the patch after loading
patched_is_local_dist_rank_0 = (
transformers.modeling_utils.is_local_dist_rank_0
)
patch_target_module(
"transformers.modeling_utils.is_local_dist_rank_0",
_truthy,
)

warnings.warn(
"Disabling low_cpu_mem_mode in the BNBAccelerationPlugin as this may "
"potentiall cause problems with: "
"1. the fused-ops-and-kernels package, and, "
"2. the syncing of FSDP modules across devices."
)

elif world_size > 1:
warnings.warn(
"Running in distributed mode but bnb_4bit_quant_storage is not set. "
@@ -206,13 +154,6 @@ def _truthy():
attn_implementation=attn_implementation,
)

if patched_is_local_dist_rank_0 is not None:
# replace it
patch_target_module(
"transformers.modeling_utils.is_local_dist_rank_0",
patched_is_local_dist_rank_0,
)

return model

@property
@@ -252,6 +193,49 @@ def augmentation(
modifiable_args = (None,) # return a None
return model, modifiable_args

def get_callbacks_and_ready_for_train(
self, model: torch.nn.Module = None, accelerator=None
):
callbacks = []
if (
accelerator is not None
and getattr(accelerator.state, "fsdp_plugin", None) is not None
):
_, _transformers_version = _is_package_available(
"transformers", return_version=True
)
_trl_installed, _trl_version = _is_package_available(
"trl", return_version=True
)

# the meta device fix for quantized models is available from this transformers version onwards,
# or, if trl is installed, only from this trl version onwards
if _transformers_version >= "4.45" and (
not _trl_installed or (_trl_installed and _trl_version >= "0.12")
):
# guarded
# NOTE: replace this later with a more specific accelerate version check
try:
# Third Party
# pylint: disable=import-outside-toplevel
from torch.distributed.utils import ensure_weights_retied

# then it's handled internally and there is nothing to do
except ImportError:
# need to use our internal version
# Local
from .fsdp_utils import ( # pylint: disable=import-outside-toplevel
ensure_weights_retied,
)

accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
accelerator.state.fsdp_plugin.param_init_fn,
model if self._no_peft_model else model.get_base_model(),
accelerator.device,
)

return callbacks


# register
AccelerationPlugin.register_plugin(
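
Both plugins gate this fix on `_transformers_version >= "4.45"` and `_trl_version >= "0.12"`, which compares version strings lexicographically. That works for the values involved today, but would misorder versions such as `"4.100"` against `"4.45"`. A more defensive gate (an alternative sketch, not what this commit does) would parse the versions first:

```python
# Alternative sketch (not this commit's approach): compare parsed versions
# rather than raw strings, so e.g. "4.100" sorts after "4.45".
from packaging.version import Version
from transformers.utils.import_utils import _is_package_available


def meta_device_fix_available() -> bool:
    _, transformers_version = _is_package_available("transformers", return_version=True)
    trl_installed, trl_version = _is_package_available("trl", return_version=True)
    return Version(transformers_version) >= Version("4.45") and (
        not trl_installed or Version(trl_version) >= Version("0.12")
    )
```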
72 changes: 72 additions & 0 deletions plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py
@@ -0,0 +1,72 @@
# Standard
from collections import defaultdict

# Third Party
import torch

# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/


def ensure_weights_retied(
param_init_fn, model: torch.nn.Module, device: torch.cuda.device
):

_tied_names = model._tied_weights_keys
if not _tied_names:
# if no tied names just passthrough
return param_init_fn

# get map of parameter instances to params.
# - needed for replacement later
_tied_params = {}
for name in _tied_names:
name = name.split(".")
name, param_name = ".".join(name[:-1]), name[-1]
mod = model.get_submodule(name)
param = getattr(mod, param_name)

_tied_params[id(param)] = None # placeholder for the param first

# build param_init_fn for the case with tied params
def param_init_fn_tied_param(module: torch.nn.Module):

# track which params to tie
# - usually only 1, but for completeness consider > 1
params_to_tie = defaultdict(list)
for n, param in module.named_parameters(recurse=False):
if id(param) in _tied_params:
params_to_tie[id(param)].append(n)

# call the param init fn, which potentially re-allocates the
# parameters
module = param_init_fn(module)

# search the parameters again and tie them up again
for id_key, _param_names in params_to_tie.items():
for param_name in _param_names:
param = _tied_params[id_key]
if param is None:
# everything will be tied to the param the
# first time it is observed
_tied_params[id_key] = getattr(module, param_name)
else:
setattr(module, param_name, param) # tie

return module

return param_init_fn_tied_param
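
A minimal usage sketch of `ensure_weights_retied` with a toy module follows; `TiedToy` and `base_param_init_fn` are illustrative stand-ins (not part of this repository), and the model is assumed to expose the Hugging Face-style `_tied_weights_keys` attribute that the helper reads.

```python
# Usage sketch with hypothetical names; mirrors how the FSDP param_init_fn
# would be wrapped, but on CPU and without FSDP for brevity.
import torch

from fms_acceleration_peft.fsdp_utils import ensure_weights_retied


class TiedToy(torch.nn.Module):
    # mimic an HF-style model that ties lm_head.weight to embed.weight
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(8, 4)
        self.lm_head = torch.nn.Linear(4, 8, bias=False)
        self.lm_head.weight = self.embed.weight  # tie


def base_param_init_fn(module: torch.nn.Module) -> torch.nn.Module:
    # stand-in for the real FSDP init fn: re-allocate this module's params
    return module.to_empty(device=torch.device("cpu"), recurse=False)


model = TiedToy()
wrapped_init = ensure_weights_retied(base_param_init_fn, model, torch.device("cpu"))

# FSDP would call this once per wrapped submodule; after both calls the
# re-allocated weights are tied again
wrapped_init(model.embed)
wrapped_init(model.lm_head)
assert model.lm_head.weight is model.embed.weight
```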
@@ -32,7 +32,11 @@
PretrainedConfig,
PreTrainedModel,
)
from transformers.modeling_utils import no_init_weights, shard_checkpoint
from transformers.modeling_utils import (
is_local_dist_rank_0,
no_init_weights,
shard_checkpoint,
)
from transformers.utils.generic import ContextManagers
import accelerate
import torch
@@ -1105,45 +1109,50 @@ def skip(*args, **kwargs):
# prepares the model on gpu in `trainer.train` to avoid unnecessary gpu usage
device_map = {"": "cpu"}

load_checkpoint_in_model = False
# compat: runtime convert checkpoint gptq(v1) to gptq_v2 format
if quantize_config.format == FORMAT.GPTQ:
accelerate.load_checkpoint_in_model(
model,
dtype=torch_dtype,
# This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
checkpoint=model_save_name,
device_map=device_map,
)
# validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase
if (
not quantize_config.sym
and not quantize_config.is_quantized_or_packed_by_v2()
):
raise ValueError(
f"Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
# low_cpu_mem_usage fix by [email protected]
# - load the checkpoint only if not low_cpu_mem_usage
# - or if low_cpu_mem_usage then only in the rank_0
if not low_cpu_mem_usage or is_local_dist_rank_0():
load_checkpoint_in_model = False
# compat: runtime convert checkpoint gptq(v1) to gptq_v2 format
if quantize_config.format == FORMAT.GPTQ:
accelerate.load_checkpoint_in_model(
model,
dtype=torch_dtype,
# This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
checkpoint=model_save_name,
device_map=device_map,
)
# validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase
if (
not quantize_config.sym
and not quantize_config.is_quantized_or_packed_by_v2()
):
raise ValueError(
f"Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
)

logger.info(
f"Compatibility: converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to `{FORMAT.GPTQ_V2}`."
)
model = convert_gptq_v1_to_v2_format(
model,
quantize_config=quantize_config,
qlinear_kernel=preload_qlinear_kernel,
)
load_checkpoint_in_model = True
quantize_config.format = FORMAT.GPTQ_V2
logger.info(
f"Compatibility: converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to `{FORMAT.GPTQ_V2}`."
)
model = convert_gptq_v1_to_v2_format(
model,
quantize_config=quantize_config,
qlinear_kernel=preload_qlinear_kernel,
)
load_checkpoint_in_model = True
quantize_config.format = FORMAT.GPTQ_V2

if not load_checkpoint_in_model and backend == Backend.TRITON:
accelerate.load_checkpoint_in_model(
model,
dtype=torch_dtype, # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
checkpoint=model_save_name,
device_map=device_map,
)

if not load_checkpoint_in_model and backend == Backend.TRITON:
accelerate.load_checkpoint_in_model(
model,
dtype=torch_dtype, # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
checkpoint=model_save_name,
device_map=device_map,
)
# TODO: Why are we using this custom function and not dispatch_model?
model = simple_dispatch_model(model, device_map)
# TODO: Why are we using this custom function and not dispatch_model?
model = simple_dispatch_model(model, device_map)

qlinear_kernel = select_quant_linear(
bits=quantize_config.bits,
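
The hunk above wraps the checkpoint load in `if not low_cpu_mem_usage or is_local_dist_rank_0():`, so in low-CPU-memory mode only the local main process materializes real weights while other ranks keep meta tensors until FSDP wraps the model and `sync_module_states` broadcasts rank 0's values. A simplified conceptual sketch of that pattern (hypothetical helper, not this repository's code):

```python
# Conceptual sketch of the rank-0-only loading pattern; the helper name and
# load_fn argument are hypothetical, not part of this repository.
import os

import torch


def materialize_checkpoint(model: torch.nn.Module, load_fn, low_cpu_mem_usage: bool):
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if not low_cpu_mem_usage or local_rank == 0:
        # e.g. accelerate.load_checkpoint_in_model(model, checkpoint=..., device_map=...)
        load_fn(model)
    # other ranks keep parameters on the meta device; the FSDP param_init_fn
    # materializes them and sync_module_states broadcasts rank 0's weights
    return model
```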
