fmt + lint
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Oct 14, 2024
1 parent dce3c00 commit 343c852
Showing 4 changed files with 55 additions and 34 deletions.
@@ -62,10 +62,6 @@ def __init__(self, configurations: Dict[str, Dict]):
)

if self.use_external_lib:
# Third Party
from transformers.utils.import_utils import ( # pylint: disable=import-outside-toplevel
_is_package_available,
)

assert _is_package_available("auto_gptq") is True, (
"Unable to use external library, auto_gptq module not found. "
@@ -360,28 +356,37 @@ def get_callbacks_and_ready_for_train(
accelerator is not None
and getattr(accelerator.state, "fsdp_plugin", None) is not None
):
_, _transformers_version = _is_package_available("transformers", return_version=True)
_trl_installed, _trl_version = _is_package_available("trl", return_version=True)
_, _transformers_version = _is_package_available(
"transformers", return_version=True
)
_trl_installed, _trl_version = _is_package_available(
"trl", return_version=True
)

# the meta device fix for quantized models is since this transformers version
# or if trl is installed then its only for this version
if (
_transformers_version >= "4.45" and (
not _trl_installed or (_trl_installed and _trl_version >= "0.12")
)
if _transformers_version >= "4.45" and (
not _trl_installed or (_trl_installed and _trl_version >= "0.12")
):
# guarded
# NOTE: replace this later with a more specific accelerate version check
try:
# Third Party
# pylint: disable=import-outside-toplevel
from torch.distributed.utils import ensure_weights_retied

# then its handled internally and there is nothing to do
except ImportError:
# need to use our internal version
from .fsdp_utils import ensure_weights_retied
# Local
from .fsdp_utils import ( # pylint: disable=import-outside-toplevel
ensure_weights_retied,
)

accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
accelerator.state.fsdp_plugin.param_init_fn,
accelerator.state.fsdp_plugin.param_init_fn,
model.get_base_model(),
accelerator.device
accelerator.device,
)
return callbacks

@@ -201,28 +201,37 @@ def get_callbacks_and_ready_for_train(
accelerator is not None
and getattr(accelerator.state, "fsdp_plugin", None) is not None
):
_, _transformers_version = _is_package_available("transformers", return_version=True)
_trl_installed, _trl_version = _is_package_available("trl", return_version=True)
_, _transformers_version = _is_package_available(
"transformers", return_version=True
)
_trl_installed, _trl_version = _is_package_available(
"trl", return_version=True
)

# the meta device fix for quantized models is since this transformers version
# or if trl is installed then its only for this version
if (
_transformers_version >= "4.45" and (
not _trl_installed or (_trl_installed and _trl_version >= "0.12")
)
if _transformers_version >= "4.45" and (
not _trl_installed or (_trl_installed and _trl_version >= "0.12")
):
# guarded
# NOTE: replace this later with a more specific accelerate version check
try:
# Third Party
# pylint: disable=import-outside-toplevel
from torch.distributed.utils import ensure_weights_retied

# then its handled internally and there is nothing to do
except ImportError:
# need to use our internal version
from .fsdp_utils import ensure_weights_retied
# Local
from .fsdp_utils import ( # pylint: disable=import-outside-toplevel
ensure_weights_retied,
)

accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
accelerator.state.fsdp_plugin.param_init_fn,
accelerator.state.fsdp_plugin.param_init_fn,
model if self._no_peft_model else model.get_base_model(),
accelerator.device
accelerator.device,
)

return callbacks
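The two hunks above converge on the same pattern: prefer ensure_weights_retied from torch.distributed.utils when it is available, otherwise fall back to the local copy in fsdp_utils.py, and only apply the fix when transformers >= 4.45 (and, if trl is installed, trl >= 0.12). A consolidated sketch of that pattern, assuming an `accelerator` carrying an FSDP plugin and a PEFT-wrapped `model` as in the surrounding get_callbacks_and_ready_for_train methods:

# sketch only: `accelerator` and `model` are assumed to exist in the caller,
# as in the plugin methods patched above
from transformers.utils.import_utils import _is_package_available

_, _transformers_version = _is_package_available("transformers", return_version=True)
_trl_installed, _trl_version = _is_package_available("trl", return_version=True)

if _transformers_version >= "4.45" and (
    not _trl_installed or _trl_version >= "0.12"
):
    try:
        # newer torch may provide the helper, in which case nothing local is needed
        from torch.distributed.utils import ensure_weights_retied
    except ImportError:
        # fall back to the in-repo implementation shown in fsdp_utils.py below
        from fms_acceleration_peft.fsdp_utils import ensure_weights_retied

    accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
        accelerator.state.fsdp_plugin.param_init_fn,
        model.get_base_model(),  # or `model` itself when no PEFT wrapper is used
        accelerator.device,
    )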
25 changes: 14 additions & 11 deletions plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py
@@ -1,6 +1,8 @@
# Standard
from collections import defaultdict
import torch

# Third Party
import torch

# Copyright The IBM Tuning Team
#
@@ -19,6 +21,7 @@
# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/


def ensure_weights_retied(
param_init_fn, model: torch.nn.Module, device: torch.cuda.device
):
@@ -28,28 +31,28 @@ def ensure_weights_retied(
# if no tied names just passthrough
return param_init_fn

# get map of parameter instances to params.
# get map of parameter instances to params.
# - needed for replacement later
_tied_params = {}
for name in _tied_names:
name = name.split('.')
name, param_name = '.'.join(name[:-1]), name[-1]
name = name.split(".")
name, param_name = ".".join(name[:-1]), name[-1]
mod = model.get_submodule(name)
param = getattr(mod, param_name)

_tied_params[id(param)] = None # placeholder for the param first
_tied_params[id(param)] = None # placeholder for the param first

# build param_init_fn for the case with tied params
def param_init_fn_tied_param(module: torch.nn.Module):

# track which params to tie
# track which params to tie
# - usually only 1, but for completeness consider > 1
params_to_tie = defaultdict(list)
for n, param in module.named_parameters(recurse=False):
if id(param) in _tied_params:
params_to_tie[id(param)].append(n)

# call the param init fn, which potentially re-allocates the
# call the param init fn, which potentially re-allocates the
# parameters
module = param_init_fn(module)

@@ -62,8 +65,8 @@ def param_init_fn_tied_param(module: torch.nn.Module):
# param is observed
_tied_params[id_key] = getattr(module, param_name)
else:
setattr(module, param_name, param) # tie
setattr(module, param_name, param) # tie

return module

return param_init_fn_tied_param
return param_init_fn_tied_param
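
To make the re-tying logic concrete, here is a minimal, self-contained illustration (not from the repository) of the problem ensure_weights_retied guards against: an init fn that re-allocates a module's parameters silently breaks weight tying, and re-assigning the shared parameter afterwards restores it, which is what param_init_fn_tied_param does per module:

# standalone illustration only; not part of the patch
import torch

emb = torch.nn.Embedding(10, 4)
head = torch.nn.Linear(4, 10, bias=False)
head.weight = emb.weight              # tie both modules to one Parameter
assert head.weight is emb.weight

def naive_param_init_fn(module: torch.nn.Module):
    # stand-in for an FSDP param_init_fn that materializes fresh tensors
    for name, param in list(module.named_parameters(recurse=False)):
        setattr(module, name, torch.nn.Parameter(torch.empty_like(param)))
    return module

naive_param_init_fn(emb)
assert head.weight is not emb.weight  # the tie is silently broken

head.weight = emb.weight              # re-tie, as the wrapper above does via setattr
assert head.weight is emb.weight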
@@ -32,7 +32,11 @@
PretrainedConfig,
PreTrainedModel,
)
from transformers.modeling_utils import no_init_weights, shard_checkpoint, is_local_dist_rank_0
from transformers.modeling_utils import (
is_local_dist_rank_0,
no_init_weights,
shard_checkpoint,
)
from transformers.utils.generic import ContextManagers
import accelerate
import torch
