Fix Low CPU Memory Mode Issues for Quantized Peft (#90)
* address issue 2 in #83
* properly handle broadcast of adapters
* handle param_init_fn_tied_param
* trl version error
* tied weights fix and meta fix for autogptq
* update readme
* fmt + lint
* upgrade granite benches

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Showing 8 changed files with 354 additions and 151 deletions.
72 changes: 72 additions & 0 deletions
plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py

@@ -0,0 +1,72 @@
# Standard
from collections import defaultdict

# Third Party
import torch

# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/


def ensure_weights_retied(
    param_init_fn, model: torch.nn.Module, device: torch.cuda.device
):

    _tied_names = model._tied_weights_keys
    if not _tied_names:
        # if no tied names just passthrough
        return param_init_fn

    # get map of parameter instances to params.
    # - needed for replacement later
    _tied_params = {}
    for name in _tied_names:
        name = name.split(".")
        name, param_name = ".".join(name[:-1]), name[-1]
        mod = model.get_submodule(name)
        param = getattr(mod, param_name)

        _tied_params[id(param)] = None  # placeholder for the param first

    # build param_init_fn for the case with tied params
    def param_init_fn_tied_param(module: torch.nn.Module):

        # track which params to tie
        # - usually only 1, but for completeness consider > 1
        params_to_tie = defaultdict(list)
        for n, param in module.named_parameters(recurse=False):
            if id(param) in _tied_params:
                params_to_tie[id(param)].append(n)

        # call the param init fn, which potentially re-allocates the
        # parameters
        module = param_init_fn(module)

        # search the parameters again and tie them up again
        for id_key, _param_names in params_to_tie.items():
            for param_name in _param_names:
                param = _tied_params[id_key]
                if param is None:
                    # everything will be tied to the first time the
                    # param is observed
                    _tied_params[id_key] = getattr(module, param_name)
                else:
                    setattr(module, param_name, param)  # tie

        return module

    return param_init_fn_tied_param
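For orientation, here is a minimal sketch (not part of this commit) of where the wrapper would plug in, assuming a low-CPU-memory FSDP setup in which non-zero ranks hold the model on the meta device. The model id, `base_init_fn`, and the bare `FSDP(...)` call are illustrative stand-ins for what the plugin's framework code actually wires up, and the sketch assumes `torch.distributed` is already initialized with a CUDA device available.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM

from fms_acceleration_peft.fsdp_utils import ensure_weights_retied

# hypothetical model with tied input/output embeddings
model = AutoModelForCausalLM.from_pretrained("gpt2")

def base_init_fn(module: torch.nn.Module):
    # the usual meta-device materializer: give each parameter fresh, empty
    # storage on the local GPU (this is exactly what breaks weight tying)
    return module.to_empty(device=torch.cuda.current_device(), recurse=False)

# wrap it so parameters listed in model._tied_weights_keys share storage again
init_fn = ensure_weights_retied(base_init_fn, model, torch.cuda.current_device())

fsdp_model = FSDP(
    model,
    param_init_fn=init_fn,    # re-ties shared weights after materialization
    sync_module_states=True,  # rank 0 then broadcasts the real values
)

The wrapper keys tied parameters by `id()`, so after `to_empty` re-allocates storage per module, every name listed in `_tied_weights_keys` is pointed back at the first materialized copy.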
@@ -32,7 +32,11 @@
     PretrainedConfig,
     PreTrainedModel,
 )
-from transformers.modeling_utils import no_init_weights, shard_checkpoint
+from transformers.modeling_utils import (
+    is_local_dist_rank_0,
+    no_init_weights,
+    shard_checkpoint,
+)
 from transformers.utils.generic import ContextManagers
 import accelerate
 import torch
@@ -1105,45 +1109,50 @@ def skip(*args, **kwargs):
         # prepares the model on gpu in `trainer.train` to avoid unnecessary gpu usage
         device_map = {"": "cpu"}
 
-        load_checkpoint_in_model = False
-        # compat: runtime convert checkpoint gptq(v1) to gptq_v2 format
-        if quantize_config.format == FORMAT.GPTQ:
-            accelerate.load_checkpoint_in_model(
-                model,
-                dtype=torch_dtype,
-                # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
-                checkpoint=model_save_name,
-                device_map=device_map,
-            )
-            # validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase
-            if (
-                not quantize_config.sym
-                and not quantize_config.is_quantized_or_packed_by_v2()
-            ):
-                raise ValueError(
-                    f"Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
-                )
-
-            logger.info(
-                f"Compatibility: converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to `{FORMAT.GPTQ_V2}`."
-            )
-            model = convert_gptq_v1_to_v2_format(
-                model,
-                quantize_config=quantize_config,
-                qlinear_kernel=preload_qlinear_kernel,
-            )
-            load_checkpoint_in_model = True
-            quantize_config.format = FORMAT.GPTQ_V2
-
-        if not load_checkpoint_in_model and backend == Backend.TRITON:
-            accelerate.load_checkpoint_in_model(
-                model,
-                dtype=torch_dtype,  # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
-                checkpoint=model_save_name,
-                device_map=device_map,
-            )
-        # TODO: Why are we using this custom function and not dispatch_model?
-        model = simple_dispatch_model(model, device_map)
+        # low_cpu_mem_usage fix by [email protected]
+        # - load the checkpoint only if not low_cpu_mem_usage
+        # - or if low_cpu_mem_usage then only in the rank_0
+        if not low_cpu_mem_usage or is_local_dist_rank_0():
+            load_checkpoint_in_model = False
+            # compat: runtime convert checkpoint gptq(v1) to gptq_v2 format
+            if quantize_config.format == FORMAT.GPTQ:
+                accelerate.load_checkpoint_in_model(
+                    model,
+                    dtype=torch_dtype,
+                    # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
+                    checkpoint=model_save_name,
+                    device_map=device_map,
+                )
+                # validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase
+                if (
+                    not quantize_config.sym
+                    and not quantize_config.is_quantized_or_packed_by_v2()
+                ):
+                    raise ValueError(
+                        f"Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
+                    )
+
+                logger.info(
+                    f"Compatibility: converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to `{FORMAT.GPTQ_V2}`."
+                )
+                model = convert_gptq_v1_to_v2_format(
+                    model,
+                    quantize_config=quantize_config,
+                    qlinear_kernel=preload_qlinear_kernel,
+                )
+                load_checkpoint_in_model = True
+                quantize_config.format = FORMAT.GPTQ_V2
+
+            if not load_checkpoint_in_model and backend == Backend.TRITON:
+                accelerate.load_checkpoint_in_model(
+                    model,
+                    dtype=torch_dtype,  # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
+                    checkpoint=model_save_name,
+                    device_map=device_map,
+                )
+            # TODO: Why are we using this custom function and not dispatch_model?
+            model = simple_dispatch_model(model, device_map)
 
         qlinear_kernel = select_quant_linear(
             bits=quantize_config.bits,
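As a rough illustration of the pattern this hunk introduces (not code from the repository): when low_cpu_mem_usage is set, only the local rank-0 process reads the quantized checkpoint, while the other ranks keep meta parameters and receive the real values later via FSDP's module-state broadcast. In the sketch below, `maybe_load_checkpoint` and its arguments are hypothetical names; `is_local_dist_rank_0` and `accelerate.load_checkpoint_in_model` are the same calls used in the diff above.

import accelerate
from transformers.modeling_utils import is_local_dist_rank_0

def maybe_load_checkpoint(model, model_save_name, torch_dtype, low_cpu_mem_usage):
    # In low-CPU-memory mode only local rank 0 pays the memory cost of reading
    # the shards; the other ranks keep their meta parameters until FSDP's
    # sync_module_states broadcast fills them in.
    if not low_cpu_mem_usage or is_local_dist_rank_0():
        accelerate.load_checkpoint_in_model(
            model,
            checkpoint=model_save_name,
            device_map={"": "cpu"},  # keep weights on CPU until trainer.train moves them
            dtype=torch_dtype,
        )
    return model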