From 12e2cb63f4e894d54c0d02d8030507a618ecadab Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Fri, 25 Oct 2024 04:08:49 +0000
Subject: [PATCH 1/2] retie fix should be applied regardless of version for
 autogptq

Signed-off-by: Yu Chin Fabian Lim
---
 .../framework_plugin_autogptq.py              | 51 ++++++++-----------
 1 file changed, 21 insertions(+), 30 deletions(-)

diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index e1fd277..6e4e042 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -356,38 +356,29 @@ def get_callbacks_and_ready_for_train(
             accelerator is not None
             and getattr(accelerator.state, "fsdp_plugin", None) is not None
         ):
 
-            _, _transformers_version = _is_package_available(
-                "transformers", return_version=True
-            )
-            _trl_installed, _trl_version = _is_package_available(
-                "trl", return_version=True
-            )
-            # the meta device fix for quantized models is since this transformers version
-            # or if trl is installed then its only for this version
-            if _transformers_version >= "4.45" and (
-                not _trl_installed or (_trl_installed and _trl_version >= "0.12")
-            ):
-                # guarded
-                # NOTE: replace this later with a more specific accelerate version check
-                try:
-                    # Third Party
-                    # pylint: disable=import-outside-toplevel
-                    from torch.distributed.utils import ensure_weights_retied
-
-                    # then its handled internally and there is nothing to do
-                except ImportError:
-                    # need to use our internal version
-                    # Local
-                    from .fsdp_utils import (  # pylint: disable=import-outside-toplevel
-                        ensure_weights_retied,
-                    )
+            # for autogptq we will install the fix regardless of transformers or 
+            # trl version, because those fixes were only for BNB. Here we control
+            # our own model loading
+            # NOTE: guard this later with a more specific accelerate version check
+            try:
+                # Third Party
+                # pylint: disable=import-outside-toplevel
+                from torch.distributed.utils import ensure_weights_retied
+
+                # then its handled internally and there is nothing to do
+            except ImportError:
+                # need to use our internal version
+                # Local
+                from .fsdp_utils import (  # pylint: disable=import-outside-toplevel
+                    ensure_weights_retied,
+                )
 
-                accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
-                    accelerator.state.fsdp_plugin.param_init_fn,
-                    model.get_base_model(),
-                    accelerator.device,
-                )
+            accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
+                accelerator.state.fsdp_plugin.param_init_fn,
+                model.get_base_model(),
+                accelerator.device,
+            )
 
         return callbacks

From ba38790fc27af74799946bc1ed18abc3e1fd5fce Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Fri, 25 Oct 2024 04:31:45 +0000
Subject: [PATCH 2/2] fmt + lint

Signed-off-by: Yu Chin Fabian Lim
---
 .../src/fms_acceleration_peft/framework_plugin_autogptq.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index 6e4e042..8ce7c02 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -357,7 +357,7 @@ def get_callbacks_and_ready_for_train(
             and getattr(accelerator.state, "fsdp_plugin", None) is not None
         ):
 
-            # for autogptq we will install the fix regardless of transformers or 
+            # for autogptq we will install the fix regardless of transformers or
             # trl version, because those fixes were only for BNB. Here we control
             # our own model loading
            # NOTE: guard this later with a more specific accelerate version check
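
Note: the patch relies on an ensure_weights_retied helper, imported from torch.distributed.utils when available and otherwise from the plugin-local fsdp_utils fallback. That fallback is not part of this diff; the sketch below is only a reading aid for why the patch re-wraps accelerator.state.fsdp_plugin.param_init_fn. It assumes the helper wraps the FSDP param_init_fn so that tied weights (e.g. input embeddings shared with the LM head) are re-tied after parameters are materialized off the meta device. The _tied_weights_keys lookup and the retying strategy shown here are assumptions, not the plugin's actual code.

# Sketch only: assumed shape of an ensure_weights_retied fallback,
# not the implementation shipped in fsdp_utils.py.
from collections import defaultdict

import torch


def ensure_weights_retied(param_init_fn, model: torch.nn.Module, device):
    # `device` mirrors the call site in the patch (accelerator.device);
    # in this sketch placement is left to the wrapped param_init_fn.
    # Assumes an HF-style model exposing _tied_weights_keys, e.g. ["lm_head.weight"].
    tied_names = getattr(model, "_tied_weights_keys", None) or []
    if not tied_names:
        # nothing is tied, keep the original init function
        return param_init_fn

    # remember the identity of each tied parameter before FSDP re-allocates it
    tied_param_ids = {}
    for name in tied_names:
        module_name, _, param_name = name.rpartition(".")
        param = getattr(model.get_submodule(module_name), param_name)
        tied_param_ids[id(param)] = None  # filled in on first materialization

    def param_init_fn_tied(module: torch.nn.Module):
        # find which of this module's own parameters were tied
        to_retie = defaultdict(list)
        for n, p in module.named_parameters(recurse=False):
            if id(p) in tied_param_ids:
                to_retie[id(p)].append(n)

        # run the original init (e.g. to_empty onto the target device), which
        # may re-allocate the parameters and thereby break the tying
        maybe_module = param_init_fn(module)
        if maybe_module is not None:
            module = maybe_module

        # re-tie: the first materialized copy becomes the shared parameter
        for key, names in to_retie.items():
            for n in names:
                if tied_param_ids[key] is None:
                    tied_param_ids[key] = getattr(module, n)
                else:
                    setattr(module, n, tied_param_ids[key])
        return module

    return param_init_fn_tied

The real helper may discover and restore tied parameters differently; the point of the patch is simply that this retying wrapper is now installed for AutoGPTQ unconditionally, instead of being gated on transformers/trl versions whose fixes only covered BNB.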