diff --git a/plugins/accelerated-peft/README.md b/plugins/accelerated-peft/README.md
index fc2cf62..8fb1afb 100644
--- a/plugins/accelerated-peft/README.md
+++ b/plugins/accelerated-peft/README.md
@@ -6,8 +6,8 @@ Currently only supports LoRA-related techniques, but more are in the pipeline to
Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
-[autogptq](./src/fms_acceleration_peft/framework_plugin_autogptq.py) | Loads 4bit GPTQ-LoRA with quantized GPTQ as base | AutoGPTQ | ✅ | ✅
-[bnb](./src/fms_acceleration_peft/framework_plugin_bnb.py) | Loads 4bit QLoRA with quantized bitsandbytes Linear4 | Huggingface<br>bitsandbytes | ✅ | ✅
+[autogptq](./src/fms_acceleration_peft/framework_plugin_autogptq.py) | Loads 4bit GPTQ-LoRA with quantized GPTQ as base | AutoGPTQ | ✅ | ✅ | ✅
+[bnb](./src/fms_acceleration_peft/framework_plugin_bnb.py) | Loads 4bit QLoRA with quantized bitsandbytes Linear4 | Huggingface<br>bitsandbytes | ✅ | ✅ | ✅
### Key Points
@@ -43,6 +43,7 @@ GPTQ-LORA depends on an AutoGPTQ backend to run. There are 2 backend options
## Known Issues
+
- GPTQ-LORA is sometimes observed to have `nan` grad norms at the beginning of training, but training proceeds well otherwise.
-- `low_cpu_mem_usage` temporarily disabled for AutoGPTQ until bug with `make_sure_no_tensor_in_meta_device` is resolved.
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index e03a40d..e1fd277 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -29,6 +29,7 @@
from peft.tuners.lora.model import LoraModel
from transformers import AutoModelForCausalLM, TrainingArguments
from transformers.modeling_utils import is_fsdp_enabled
+from transformers.utils.import_utils import _is_package_available
import torch
import torch.distributed
@@ -61,10 +62,6 @@ def __init__(self, configurations: Dict[str, Dict]):
)
if self.use_external_lib:
- # Third Party
- from transformers.utils.import_utils import ( # pylint: disable=import-outside-toplevel
- _is_package_available,
- )
assert _is_package_available("auto_gptq") is True, (
"Unable to use external library, auto_gptq module not found. "
@@ -351,6 +348,48 @@ def augmentation(
return model, modifiable_args
+ def get_callbacks_and_ready_for_train(
+ self, model: torch.nn.Module = None, accelerator=None
+ ):
+ callbacks = []
+ if (
+ accelerator is not None
+ and getattr(accelerator.state, "fsdp_plugin", None) is not None
+ ):
+ _, _transformers_version = _is_package_available(
+ "transformers", return_version=True
+ )
+ _trl_installed, _trl_version = _is_package_available(
+ "trl", return_version=True
+ )
+
+            # the meta device fix for quantized models is available since this
+            # transformers version; or, if trl is installed, only since this trl version
+ if _transformers_version >= "4.45" and (
+ not _trl_installed or (_trl_installed and _trl_version >= "0.12")
+ ):
+ # guarded
+ # NOTE: replace this later with a more specific accelerate version check
+ try:
+ # Third Party
+ # pylint: disable=import-outside-toplevel
+ from torch.distributed.utils import ensure_weights_retied
+
+                    # then it's handled internally and there is nothing to do
+ except ImportError:
+ # need to use our internal version
+ # Local
+ from .fsdp_utils import ( # pylint: disable=import-outside-toplevel
+ ensure_weights_retied,
+ )
+
+ accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
+ accelerator.state.fsdp_plugin.param_init_fn,
+ model.get_base_model(),
+ accelerator.device,
+ )
+ return callbacks
+
# register
AccelerationPlugin.register_plugin(
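
Both this plugin and the bnb plugin below gate the `param_init_fn` fix on the installed `transformers` and `trl` versions. A minimal sketch of that gate in isolation, assuming the same `_is_package_available` helper; the `_needs_param_init_fn_fix` name and the use of `packaging.version` are illustrative choices, the plugin itself compares the raw version strings.

```python
# Sketch only: _needs_param_init_fn_fix is a hypothetical helper, not part of the plugin.
from packaging.version import Version
from transformers.utils.import_utils import _is_package_available


def _needs_param_init_fn_fix() -> bool:
    _, transformers_version = _is_package_available("transformers", return_version=True)
    trl_installed, trl_version = _is_package_available("trl", return_version=True)

    # the meta device behaviour for quantized models appears from transformers 4.45;
    # when trl is installed, it must also be at least 0.12
    return Version(transformers_version) >= Version("4.45") and (
        not trl_installed or Version(trl_version) >= Version("0.12")
    )
```
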
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
index 1aadc08..3a4c12a 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
@@ -116,64 +116,12 @@ def model_loader(self, model_name: str, **kwargs):
except ValueError:
world_size = 1 # pg not init
- patched_is_local_dist_rank_0 = None
if (
world_size > 1
and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
):
config_kwargs["bnb_4bit_quant_storage"] = torch_dtype
- # - of course assume that this package must exist, simply need the version
- _, _transformers_version = _is_package_available(
- "transformers", return_version=True
- )
-
- # this is a workaround that disables low_cpu_mem_mode for quant QLORA
- # - this issue was introduced in https://github.com/huggingface/transformers/pull/33154
- # whereby the low_cpu_mem_mode was actually fixed.
- # - However fixing it causes some problems with the current impl.
- # 1. For lora fused ops, the adapters cannot be managed by FSDP, as
- # forwards are not called. This causes issue 2) in
- # https://github.com/foundation-model-stack/fms-acceleration/issues/83
- # where the adapters are still sharded when passed in the fused-ops.
- # However, if low_cpu_mem_mode=True, then we NEED FSDP to intialize
- # their state, which contradicts the above point.
- #
- # 2. We have observed,
- # see https://github.com/foundation-model-stack/fms-acceleration/pull/86
- # that low_cpu_mem_mode=True can cause torch distributed primitives
- # to hang.
-
- if _transformers_version >= "4.45":
-
- # pylint: disable=import-outside-toplevel
- # Third Party
- from fms_acceleration.model_patcher import patch_target_module
- import transformers.modeling_utils
-
- def _truthy():
- return (
- True # use this to always return True to is_local_dist_rank_0
- )
-
- # - we cannot use the model patcher and this needs to be called immediately below
- # at the model_loader
- # - but we immediately revert the patch after loading
- patched_is_local_dist_rank_0 = (
- transformers.modeling_utils.is_local_dist_rank_0
- )
- patch_target_module(
- "transformers.modeling_utils.is_local_dist_rank_0",
- _truthy,
- )
-
- warnings.warn(
- "Disabling low_cpu_mem_mode in the BNBAccelerationPlugin as this may "
- "potentiall cause problems with: "
- "1. the fused-ops-and-kernels package, and, "
- "2. the syncing of FSDP modules across devices."
- )
-
elif world_size > 1:
warnings.warn(
"Running in distributed mode but bnb_4bit_quant_storage is not set. "
@@ -206,13 +154,6 @@ def _truthy():
attn_implementation=attn_implementation,
)
- if patched_is_local_dist_rank_0 is not None:
- # replace it
- patch_target_module(
- "transformers.modeling_utils.is_local_dist_rank_0",
- patched_is_local_dist_rank_0,
- )
-
return model
@property
@@ -252,6 +193,49 @@ def augmentation(
modifiable_args = (None,) # return a None
return model, modifiable_args
+ def get_callbacks_and_ready_for_train(
+ self, model: torch.nn.Module = None, accelerator=None
+ ):
+ callbacks = []
+ if (
+ accelerator is not None
+ and getattr(accelerator.state, "fsdp_plugin", None) is not None
+ ):
+ _, _transformers_version = _is_package_available(
+ "transformers", return_version=True
+ )
+ _trl_installed, _trl_version = _is_package_available(
+ "trl", return_version=True
+ )
+
+            # the meta device fix for quantized models is available since this
+            # transformers version; or, if trl is installed, only since this trl version
+ if _transformers_version >= "4.45" and (
+ not _trl_installed or (_trl_installed and _trl_version >= "0.12")
+ ):
+ # guarded
+ # NOTE: replace this later with a more specific accelerate version check
+ try:
+ # Third Party
+ # pylint: disable=import-outside-toplevel
+ from torch.distributed.utils import ensure_weights_retied
+
+                    # then it's handled internally and there is nothing to do
+ except ImportError:
+ # need to use our internal version
+ # Local
+ from .fsdp_utils import ( # pylint: disable=import-outside-toplevel
+ ensure_weights_retied,
+ )
+
+ accelerator.state.fsdp_plugin.param_init_fn = ensure_weights_retied(
+ accelerator.state.fsdp_plugin.param_init_fn,
+ model if self._no_peft_model else model.get_base_model(),
+ accelerator.device,
+ )
+
+ return callbacks
+
# register
AccelerationPlugin.register_plugin(
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py
new file mode 100644
index 0000000..3086cf7
--- /dev/null
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py
@@ -0,0 +1,72 @@
+# Copyright The IBM Tuning Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# SPDX-License-Identifier: Apache-2.0
+# https://spdx.dev/learn/handling-license-info/
+
+# Standard
+from collections import defaultdict
+
+# Third Party
+import torch
+
+
+def ensure_weights_retied(
+ param_init_fn, model: torch.nn.Module, device: torch.cuda.device
+):
+
+ _tied_names = model._tied_weights_keys
+ if not _tied_names:
+        # if there are no tied names, just pass through
+ return param_init_fn
+
+    # build a map keyed by the ids of the tied parameters
+    # - the (shared) replacement param is filled in later, when retying
+ _tied_params = {}
+ for name in _tied_names:
+ name = name.split(".")
+ name, param_name = ".".join(name[:-1]), name[-1]
+ mod = model.get_submodule(name)
+ param = getattr(mod, param_name)
+
+ _tied_params[id(param)] = None # placeholder for the param first
+
+ # build param_init_fn for the case with tied params
+ def param_init_fn_tied_param(module: torch.nn.Module):
+
+ # track which params to tie
+ # - usually only 1, but for completeness consider > 1
+ params_to_tie = defaultdict(list)
+ for n, param in module.named_parameters(recurse=False):
+ if id(param) in _tied_params:
+ params_to_tie[id(param)].append(n)
+
+ # call the param init fn, which potentially re-allocates the
+ # parameters
+ module = param_init_fn(module)
+
+ # search the parameters again and tie them up again
+ for id_key, _param_names in params_to_tie.items():
+ for param_name in _param_names:
+ param = _tied_params[id_key]
+ if param is None:
+                    # everything will be tied to the first instance of the
+                    # param that is observed
+ _tied_params[id_key] = getattr(module, param_name)
+ else:
+ setattr(module, param_name, param) # tie
+
+ return module
+
+ return param_init_fn_tied_param
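
An FSDP `param_init_fn` materializes each module's parameters separately, which silently breaks weight tying (for example an embedding shared with an LM head), and `ensure_weights_retied` exists to restore that tie. A minimal, self-contained sketch of the effect and the fix on a toy model; `TinyLM` and `naive_init_fn` are illustrative stand-ins for the Hugging Face model and the `param_init_fn` that accelerate builds for FSDP low-CPU-memory loading.

```python
# Sketch only: TinyLM and naive_init_fn are illustrative, not part of the plugin.
import torch
import torch.nn as nn

from fms_acceleration_peft.fsdp_utils import ensure_weights_retied


class TinyLM(nn.Module):
    # mimic the HF convention that ensure_weights_retied reads
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(8, 4)
        self.lm_head = nn.Linear(4, 8, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the weights


model = TinyLM()


def naive_init_fn(module: nn.Module):
    # stand-in for the per-module init fn used with FSDP low_cpu_mem mode:
    # it re-allocates the module's parameters, which breaks the tie
    module.to_empty(device=torch.device("cpu"), recurse=False)
    return module


wrapped = ensure_weights_retied(naive_init_fn, model, torch.device("cpu"))
for mod in model.modules():
    wrapped(mod)

# after materialization the two parameters point at the same tensor again
assert model.lm_head.weight is model.embed.weight
```
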
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py
index 40f80cb..85c17ee 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/base.py
@@ -32,7 +32,11 @@
PretrainedConfig,
PreTrainedModel,
)
-from transformers.modeling_utils import no_init_weights, shard_checkpoint
+from transformers.modeling_utils import (
+ is_local_dist_rank_0,
+ no_init_weights,
+ shard_checkpoint,
+)
from transformers.utils.generic import ContextManagers
import accelerate
import torch
@@ -1105,45 +1109,50 @@ def skip(*args, **kwargs):
# prepares the model on gpu in `trainer.train` to avoid unnecessary gpu usage
device_map = {"": "cpu"}
- load_checkpoint_in_model = False
- # compat: runtime convert checkpoint gptq(v1) to gptq_v2 format
- if quantize_config.format == FORMAT.GPTQ:
- accelerate.load_checkpoint_in_model(
- model,
- dtype=torch_dtype,
- # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
- checkpoint=model_save_name,
- device_map=device_map,
- )
- # validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase
- if (
- not quantize_config.sym
- and not quantize_config.is_quantized_or_packed_by_v2()
- ):
- raise ValueError(
- f"Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
+        # low_cpu_mem_usage fix by flim@sg.ibm.com
+        # - load the checkpoint only if low_cpu_mem_usage is off,
+        # - or, if low_cpu_mem_usage is on, only on local rank 0
+ if not low_cpu_mem_usage or is_local_dist_rank_0():
+ load_checkpoint_in_model = False
+ # compat: runtime convert checkpoint gptq(v1) to gptq_v2 format
+ if quantize_config.format == FORMAT.GPTQ:
+ accelerate.load_checkpoint_in_model(
+ model,
+ dtype=torch_dtype,
+ # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
+ checkpoint=model_save_name,
+ device_map=device_map,
)
+ # validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase
+ if (
+ not quantize_config.sym
+ and not quantize_config.is_quantized_or_packed_by_v2()
+ ):
+ raise ValueError(
+ f"Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
+ )
- logger.info(
- f"Compatibility: converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to `{FORMAT.GPTQ_V2}`."
- )
- model = convert_gptq_v1_to_v2_format(
- model,
- quantize_config=quantize_config,
- qlinear_kernel=preload_qlinear_kernel,
- )
- load_checkpoint_in_model = True
- quantize_config.format = FORMAT.GPTQ_V2
+ logger.info(
+ f"Compatibility: converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to `{FORMAT.GPTQ_V2}`."
+ )
+ model = convert_gptq_v1_to_v2_format(
+ model,
+ quantize_config=quantize_config,
+ qlinear_kernel=preload_qlinear_kernel,
+ )
+ load_checkpoint_in_model = True
+ quantize_config.format = FORMAT.GPTQ_V2
+
+ if not load_checkpoint_in_model and backend == Backend.TRITON:
+ accelerate.load_checkpoint_in_model(
+ model,
+ dtype=torch_dtype, # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
+ checkpoint=model_save_name,
+ device_map=device_map,
+ )
- if not load_checkpoint_in_model and backend == Backend.TRITON:
- accelerate.load_checkpoint_in_model(
- model,
- dtype=torch_dtype, # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
- checkpoint=model_save_name,
- device_map=device_map,
- )
- # TODO: Why are we using this custom function and not dispatch_model?
- model = simple_dispatch_model(model, device_map)
+ # TODO: Why are we using this custom function and not dispatch_model?
+ model = simple_dispatch_model(model, device_map)
qlinear_kernel = select_quant_linear(
bits=quantize_config.bits,
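
The restructuring above follows the usual low_cpu_mem_usage recipe: only local rank 0 materializes the checkpoint in CPU memory, while the other ranks keep empty weights and later receive rank 0's values through FSDP (`sync_module_states`). A rough sketch of that recipe in isolation, assuming a plain state dict; `load_weights_low_cpu_mem` is a hypothetical helper, not part of the GPTQ loader.

```python
# Sketch only: load_weights_low_cpu_mem is hypothetical and not part of this codebase.
import torch
from transformers.modeling_utils import is_local_dist_rank_0


def load_weights_low_cpu_mem(
    model: torch.nn.Module, state_dict: dict, low_cpu_mem_usage: bool
) -> torch.nn.Module:
    if not low_cpu_mem_usage or is_local_dist_rank_0():
        # only one process per node pays the CPU memory cost of the checkpoint
        model.load_state_dict(state_dict, assign=True)
    # the remaining ranks keep their empty weights; FSDP broadcasts rank 0's
    # weights to them when sync_module_states=True
    return model
```
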
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
index 6ec4cd9..0558d6b 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
@@ -21,10 +21,10 @@
from peft import LoraConfig
from peft.tuners.lora.layer import LoraLayer
from transformers import TrainingArguments
+from transformers.modeling_utils import is_fsdp_enabled
import torch
import torch.distributed as dist
-
# consider moving this somewhere else later
def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin):
"""
@@ -54,11 +54,31 @@ def _all_reduce_hook(grad):
# because we will ignore these from FSDP, we need to manually
# move them to gpu if they are already not on them
+ # - if the adapters are on meta, we assume that this is for FSDP
+ # low_cpu_mem_mode purposes, and that the values will be synced over
+ # - So just initialize them to empty.
if not A.weight.is_cuda:
- set_module_tensor_to_device(A, "weight", "cuda")
+ value = A.weight
+
+ if is_fsdp_enabled() and value.device == torch.device('meta'):
+ # if low_cpu_mem_mode
+ value = torch.empty(*value.size(), dtype=value.dtype)
+
+ set_module_tensor_to_device(A, "weight", "cuda", value)
+
+ if is_fsdp_enabled():
+ dist.broadcast(A.weight, src=0)
+
if not B.weight.is_cuda:
- set_module_tensor_to_device(B, "weight", "cuda")
+ value = B.weight
+
+ if is_fsdp_enabled() and value.device == torch.device('meta'):
+ value = torch.empty(*value.size(), dtype=value.dtype)
+
+ set_module_tensor_to_device(B, "weight", "cuda", value)
+ if is_fsdp_enabled():
+ dist.broadcast(B.weight, src=0)
def register_foak_model_patch_rules(base_type):
# Third Party
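
The meta-device handling added above reduces to a three-step pattern per LoRA matrix: allocate an empty tensor if the weight is still on meta, move it onto the GPU, then broadcast rank 0's values so every rank starts from identical adapter weights. A compact sketch of that pattern for a single adapter module; `materialize_and_sync` is an illustrative helper, not part of the plugin.

```python
# Sketch only: materialize_and_sync is a hypothetical helper used for illustration.
import torch
import torch.distributed as dist
from accelerate.utils import set_module_tensor_to_device


def materialize_and_sync(lora_mod: torch.nn.Module):
    value = lora_mod.weight
    if value.device == torch.device("meta"):
        # under FSDP low_cpu_mem_mode only rank 0 holds real values, so the
        # other ranks just allocate a buffer of the right shape and dtype
        value = torch.empty(*value.size(), dtype=value.dtype)
    set_module_tensor_to_device(lora_mod, "weight", "cuda", value)
    # overwrite the empty buffers with rank 0's (real) adapter values
    dist.broadcast(lora_mod.weight, src=0)
```
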
diff --git a/scripts/benchmarks/refs/a100_80gb_granite.csv b/scripts/benchmarks/refs/a100_80gb_granite.csv
index 0637a4c..353efcd 100644
--- a/scripts/benchmarks/refs/a100_80gb_granite.csv
+++ b/scripts/benchmarks/refs/a100_80gb_granite.csv
@@ -1,29 +1,29 @@
-epoch,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,output_dir,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second
-0.14,none,2e-5,,,47683.0,35134836736,21089537536,ibm/PowerLM-3b,1,benchmark_outputs3/exp_0/hf,,4,,,bfloat16,0.9698258590698242,322.7029,1.24,0.31,5077.116
-0.14,none,2e-5,,,43779.0,35198512128,28135299584,ibm/PowerLM-3b,2,benchmark_outputs3/exp_1/hf,,2,,,bfloat16,0.929816370010376,185.4019,2.157,0.539,4418.51
-0.29,none,2e-5,,,52927.0,47059390976,21089930752,ibm/PowerLM-3b,1,benchmark_outputs3/exp_2/hf,,8,,,bfloat16,0.9660552883148193,629.0534,1.272,0.159,5209.097
-0.29,none,2e-5,,,45698.0,35169178112,28129300992,ibm/PowerLM-3b,2,benchmark_outputs3/exp_3/hf,,4,,,bfloat16,0.9272240352630615,334.0146,2.395,0.299,4905.174
-0.14,foak-fast-kernels,2e-5,,,43333.0,35133716992,21088791040,ibm/PowerLM-3b,1,benchmark_outputs3/exp_4/hf,,4,,,bfloat16,0.9692397594451905,273.7817,1.461,0.365,5984.329
-0.14,foak-fast-kernels,2e-5,,,42204.0,35208045568,28144833024,ibm/PowerLM-3b,2,benchmark_outputs3/exp_5/hf,,2,,,bfloat16,0.9299186515808106,160.0673,2.499,0.625,5117.847
-0.29,foak-fast-kernels,2e-5,,,50077.0,40316365824,21089184256,ibm/PowerLM-3b,1,benchmark_outputs3/exp_6/hf,,8,,,bfloat16,0.9664452648162842,532.0588,1.504,0.188,6158.718
-0.29,foak-fast-kernels,2e-5,,,47155.0,35173639168,28114620928,ibm/PowerLM-3b,2,benchmark_outputs3/exp_7/hf,,4,,,bfloat16,0.9271490287780761,285.7146,2.8,0.35,5734.394
-0.14,none,2e-4,16,0.1,24383.0,20193463808,7183026688,ibm/PowerLM-3b,1,benchmark_outputs3/exp_8/hf,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,1.0136762142181397,310.9178,1.287,0.322,5269.56
-0.14,none,2e-4,16,0.1,13849.0,10322532864,3625665536,ibm/PowerLM-3b,2,benchmark_outputs3/exp_9/hf,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,1.014874858856201,215.9738,1.852,0.463,3793.053
-0.29,none,2e-4,16,0.1,41093.0,33176473088,7183419904,ibm/PowerLM-3b,1,benchmark_outputs3/exp_10/hf,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,1.013114538192749,611.955,1.307,0.163,5354.642
-0.29,none,2e-4,16,0.1,22076.0,16835885568,3646137344,ibm/PowerLM-3b,2,benchmark_outputs3/exp_11/hf,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,1.0142618751525878,332.7819,2.404,0.3,4923.344
-0.14,foak-fast-kernels,2e-4,16,0.1,21311.0,16897035264,7183026688,ibm/PowerLM-3b,1,benchmark_outputs3/exp_12/hf,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,1.0154188919067382,262.241,1.525,0.381,6247.687
-0.14,foak-fast-kernels,2e-4,16,0.1,12241.0,8673304576,3624651776,ibm/PowerLM-3b,2,benchmark_outputs3/exp_13/hf,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,1.0135305118560791,195.6867,2.044,0.511,4186.284
-0.29,foak-fast-kernels,2e-4,16,0.1,34661.0,26585189376,7183419904,ibm/PowerLM-3b,1,benchmark_outputs3/exp_14/hf,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,1.0131592655181885,516.5368,1.549,0.194,6343.788
-0.29,foak-fast-kernels,2e-4,16,0.1,19014.0,13533374464,3640054784,ibm/PowerLM-3b,2,benchmark_outputs3/exp_15/hf,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,1.011543436050415,290.371,2.755,0.344,5642.437
-0.14,baseline-peft-bnb,2e-4,16,0.1,23127.0,18174161920,2143825920,ibm/PowerLM-3b,1,benchmark_outputs3/exp_16/hf,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,1.0207978439331056,391.2937,1.022,0.256,4187.137
-0.14,baseline-peft-bnb,2e-4,16,0.1,11254.0,7826759168,1129891840,ibm/PowerLM-3b,2,benchmark_outputs3/exp_17/hf,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,1.0213402462005616,220.9544,1.81,0.453,3707.553
-0.29,baseline-peft-bnb,2e-4,16,0.1,43327.0,34177070080,2144219136,ibm/PowerLM-3b,1,benchmark_outputs3/exp_18/hf,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,1.018005132675171,760.3699,1.052,0.132,4309.481
-0.29,baseline-peft-bnb,2e-4,16,0.1,19415.0,14319836672,1130088448,ibm/PowerLM-3b,2,benchmark_outputs3/exp_19/hf,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,1.0211043930053711,336.4976,2.377,0.297,4868.98
-0.14,accelerated-peft-bnb,2e-4,16,0.1,19379.0,15151677952,2141240832,ibm/PowerLM-3b,1,benchmark_outputs3/exp_20/hf,lora,4,16,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0230310726165772,296.6505,1.348,0.337,5522.998
-0.14,accelerated-peft-bnb,2e-4,16,0.1,11269.0,7826759168,1129891840,ibm/PowerLM-3b,2,benchmark_outputs3/exp_21/hf,lora,2,16,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0242974948883057,227.579,1.758,0.439,3599.629
-0.29,accelerated-peft-bnb,2e-4,16,0.1,35837.0,28134687232,2141634048,ibm/PowerLM-3b,1,benchmark_outputs3/exp_22/hf,lora,8,16,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.020608777999878,579.2578,1.381,0.173,5656.894
-0.29,accelerated-peft-bnb,2e-4,16,0.1,19406.0,14319836672,1130088448,ibm/PowerLM-3b,2,benchmark_outputs3/exp_23/hf,lora,4,16,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0224126720428466,334.7691,2.39,0.299,4894.12
-0.14,accelerated-peft-bnb-foak,2e-4,16,0.1,16347.0,11831656448,2141240832,ibm/PowerLM-3b,1,benchmark_outputs3/exp_24/hf,lora,4,16,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0216326427459717,247.0632,1.619,0.405,6631.502
-0.14,accelerated-peft-bnb-foak,2e-4,16,0.1,9705.0,6202137600,1129891840,ibm/PowerLM-3b,2,benchmark_outputs3/exp_25/hf,lora,2,16,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0239610195159912,137.0723,2.918,0.73,5976.408
-0.29,accelerated-peft-bnb-foak,2e-4,16,0.1,29777.0,21519810560,2141634048,ibm/PowerLM-3b,1,benchmark_outputs3/exp_26/hf,lora,8,16,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0218979930877685,480.7611,1.664,0.208,6815.859
-0.29,accelerated-peft-bnb-foak,2e-4,16,0.1,16454.0,11047001088,1130088448,ibm/PowerLM-3b,2,benchmark_outputs3/exp_27/hf,lora,4,16,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0228086185455323,252.2261,3.172,0.396,6495.76
+bf16,epoch,fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second
+,0.14,,none,2e-05,,,47683.0,35134836736,21089537536,ibm/PowerLM-3b,1,,4,,,bfloat16,0.9708438396453858,321.2872,1.245,0.311,5099.488
+,0.14,,none,2e-05,,,25563.0,17626672128,14116037120,ibm/PowerLM-3b,2,,2,,,bfloat16,0.9699858856201172,178.9566,2.235,0.559,4577.647
+,0.29,,none,2e-05,,,52927.0,47059390976,21089930752,ibm/PowerLM-3b,1,,8,,,bfloat16,0.9677363204956054,627.2531,1.275,0.159,5224.047
+,0.29,,none,2e-05,,,28962.0,23781051904,14120288768,ibm/PowerLM-3b,2,,4,,,bfloat16,0.967378387451172,328.2929,2.437,0.305,4990.665
+,0.14,,foak-fast-kernels,2e-05,,,43333.0,35133716992,21088791040,ibm/PowerLM-3b,1,,4,,,bfloat16,0.9708970737457276,272.6346,1.467,0.367,6009.509
+,0.14,,foak-fast-kernels,2e-05,,,24945.0,17618562048,14107927040,ibm/PowerLM-3b,2,,2,,,bfloat16,0.9701127910614014,155.0031,2.581,0.645,5285.056
+,0.29,,foak-fast-kernels,2e-05,,,50077.0,40316365824,21089184256,ibm/PowerLM-3b,1,,8,,,bfloat16,0.9673656368255616,530.8662,1.507,0.188,6172.553
+,0.29,,foak-fast-kernels,2e-05,,,25856.0,20396960768,14108123648,ibm/PowerLM-3b,2,,4,,,bfloat16,0.9672414493560793,280.7892,2.849,0.356,5834.982
+,0.14,,none,0.0002,16.0,0.1,24591.0,20169870848,7183026688,ibm/PowerLM-3b,1,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0160441780090332,325.5188,1.229,0.307,5033.196
+,0.14,,none,0.0002,16.0,0.1,13919.0,10321519104,3624651776,ibm/PowerLM-3b,2,lora,2,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0179708290100098,205.3199,1.948,0.487,3989.872
+,0.29,,none,0.0002,16.0,0.1,41237.0,33152880128,7183419904,ibm/PowerLM-3b,1,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.015061044692993,641.7595,1.247,0.156,5105.963
+,0.29,,none,0.0002,16.0,0.1,22182.0,16819665408,3629917184,ibm/PowerLM-3b,2,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0122298526763915,342.5583,2.335,0.292,4782.835
+,0.14,,foak-fast-kernels,0.0002,16.0,0.1,21375.0,16873442304,7183026688,ibm/PowerLM-3b,1,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0166709804534912,277.343,1.442,0.361,5907.487
+,0.14,,foak-fast-kernels,0.0002,16.0,0.1,12274.0,8673304576,3624651776,ibm/PowerLM-3b,2,lora,2,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0182081413269044,181.9223,2.199,0.55,4503.022
+,0.29,,foak-fast-kernels,0.0002,16.0,0.1,34803.0,26561596416,7183419904,ibm/PowerLM-3b,1,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0157270526885986,546.8807,1.463,0.183,5991.8
+,0.29,,foak-fast-kernels,0.0002,16.0,0.1,19151.0,13518168064,3624848384,ibm/PowerLM-3b,2,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0141748046875,299.1818,2.674,0.334,5476.269
+True,0.14,,baseline-peft-bnb,0.0002,16.0,0.1,23127.0,18174161920,2143825920,ibm/PowerLM-3b,1,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.020770034790039,390.2779,1.025,0.256,4198.034
+True,0.14,,baseline-peft-bnb,0.0002,16.0,0.1,11366.0,7826759168,1129891840,ibm/PowerLM-3b,2,lora,2,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.026328239440918,217.3217,1.841,0.46,3769.527
+True,0.29,,baseline-peft-bnb,0.0002,16.0,0.1,43327.0,34177070080,2144219136,ibm/PowerLM-3b,1,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0214418220520018,759.0108,1.054,0.132,4317.198
+True,0.29,,baseline-peft-bnb,0.0002,16.0,0.1,19662.0,14319836672,1130088448,ibm/PowerLM-3b,2,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0208434104919433,325.6275,2.457,0.307,5031.517
+True,0.14,,accelerated-peft-bnb,0.0002,16.0,0.1,19379.0,15151677952,2141240832,ibm/PowerLM-3b,1,lora,4,16.0,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0252878952026367,295.7236,1.353,0.338,5540.308
+True,0.14,,accelerated-peft-bnb,0.0002,16.0,0.1,11316.0,7826759168,1129891840,ibm/PowerLM-3b,2,lora,2,16.0,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0245559310913086,216.0996,1.851,0.463,3790.845
+True,0.29,,accelerated-peft-bnb,0.0002,16.0,0.1,35837.0,28134687232,2141634048,ibm/PowerLM-3b,1,lora,8,16.0,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0224976921081543,578.3764,1.383,0.173,5665.515
+True,0.29,,accelerated-peft-bnb,0.0002,16.0,0.1,19600.0,14319836672,1130088448,ibm/PowerLM-3b,2,lora,4,16.0,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0212884616851807,328.1765,2.438,0.305,4992.436
+True,0.14,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,16347.0,11831656448,2141240832,ibm/PowerLM-3b,1,lora,4,16.0,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0251048755645753,246.4572,1.623,0.406,6647.808
+True,0.14,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,9796.0,6202137600,1129891840,ibm/PowerLM-3b,2,lora,2,16.0,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0391525173187257,137.496,2.909,0.727,5957.992
+True,0.29,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,29777.0,21519810560,2141634048,ibm/PowerLM-3b,1,lora,8,16.0,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0225242614746093,479.5512,1.668,0.209,6833.055
+True,0.29,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,16722.0,11047001088,1130088448,ibm/PowerLM-3b,2,lora,4,16.0,q_proj k_proj v_proj o_proj c_attn,bfloat16,1.0340285396575928,252.7203,3.166,0.396,6483.058
diff --git a/scripts/benchmarks/refs/requirements_granite.txt b/scripts/benchmarks/refs/requirements_granite.txt
index 9d58b46..54221b0 100644
--- a/scripts/benchmarks/refs/requirements_granite.txt
+++ b/scripts/benchmarks/refs/requirements_granite.txt
@@ -1,41 +1,86 @@
-accelerate @ git+https://github.com/huggingface/accelerate.git@4305033f8035defad0a87cd38e5c918e78510ba5
-aiohappyeyeballs==2.4.0
-aiohttp==3.10.6
+accelerate==0.34.2
+aiohappyeyeballs==2.4.3
+aiohttp==3.10.9
aiosignal==1.3.1
+anyio==4.6.0
+argon2-cffi==23.1.0
+argon2-cffi-bindings==21.2.0
+arrow==1.3.0
+asttokens==2.4.1
+async-lru==2.0.4
async-timeout==4.0.3
attrs==24.2.0
+babel==2.16.0
+beautifulsoup4==4.12.3
bitsandbytes==0.43.3
+bleach==6.1.0
certifi==2024.8.30
-charset-normalizer==3.3.2
+cffi==1.17.1
+charset-normalizer==3.4.0
+comm==0.2.2
contourpy==1.3.0
cycler==0.12.1
datasets==2.21.0
+debugpy==1.8.6
+decorator==5.1.1
+defusedxml==0.7.1
dill==0.3.8
docstring_parser==0.16
einops==0.8.0
+exceptiongroup==1.2.2
+executing==2.1.0
+fastjsonschema==2.20.0
filelock==3.16.1
flash-attn==2.6.3
--e git+https://github.com/foundation-model-stack/fms-acceleration.git@36c7fa7cb4ef413ddfbf55b75eb6135b369de038#egg=fms_acceleration&subdirectory=plugins/framework
--e git+https://github.com/foundation-model-stack/fms-acceleration.git@36c7fa7cb4ef413ddfbf55b75eb6135b369de038#egg=fms_acceleration_aadp&subdirectory=plugins/attention-and-distributed-packing
--e git+https://github.com/foundation-model-stack/fms-acceleration.git@36c7fa7cb4ef413ddfbf55b75eb6135b369de038#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels
--e git+https://github.com/foundation-model-stack/fms-acceleration.git@36c7fa7cb4ef413ddfbf55b75eb6135b369de038#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft
-fms-hf-tuning @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@8676d0177c2f9436cf176342f38c3b5602f91751
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@10cc000c12a57774f68e9da861dbbc7eaa559816#egg=fms_acceleration&subdirectory=plugins/framework
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@10cc000c12a57774f68e9da861dbbc7eaa559816#egg=fms_acceleration_aadp&subdirectory=plugins/attention-and-distributed-packing
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@10cc000c12a57774f68e9da861dbbc7eaa559816#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@10cc000c12a57774f68e9da861dbbc7eaa559816#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft
+fms-hf-tuning @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@b33634c8ce5c85b6d10daa187fd1a069db969b67
fonttools==4.54.1
+fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.6.1
-huggingface-hub==0.25.1
+h11==0.14.0
+httpcore==1.0.6
+httpx==0.27.2
+huggingface-hub==0.25.2
idna==3.10
+ipykernel==6.29.5
+ipython==8.28.0
+isoduration==20.11.0
+jedi==0.19.1
Jinja2==3.1.4
+json5==0.9.25
+jsonpointer==3.0.0
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
+jupyter-events==0.10.0
+jupyter-lsp==2.2.5
+jupyter_client==8.6.3
+jupyter_core==5.7.2
+jupyter_server==2.14.2
+jupyter_server_terminals==0.5.3
+jupyterlab==4.2.5
+jupyterlab_pygments==0.3.0
+jupyterlab_server==2.27.3
kiwisolver==1.4.7
llvmlite==0.43.0
markdown-it-py==3.0.0
-MarkupSafe==2.1.5
+MarkupSafe==3.0.1
matplotlib==3.9.2
+matplotlib-inline==0.1.7
mdurl==0.1.2
+mistune==3.0.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
+nbclient==0.10.0
+nbconvert==7.16.4
+nbformat==5.10.4
+nest-asyncio==1.6.0
networkx==3.3
+notebook_shim==0.2.4
numba==0.60.0
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
@@ -48,39 +93,72 @@ nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
-nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvjitlink-cu12==12.6.77
nvidia-nvtx-cu12==12.1.105
+overrides==7.7.0
packaging==24.1
pandas==2.2.3
+pandocfilters==1.5.1
+parso==0.8.4
peft==0.12.0
+pexpect==4.9.0
pillow==10.4.0
+platformdirs==4.3.6
+prometheus_client==0.21.0
+prompt_toolkit==3.0.48
+propcache==0.2.0
protobuf==5.28.2
psutil==6.0.0
+ptyprocess==0.7.0
+pure_eval==0.2.3
pyarrow==17.0.0
+pycparser==2.22
Pygments==2.18.0
pyparsing==3.1.4
python-dateutil==2.9.0.post0
+python-json-logger==2.0.7
pytz==2024.2
PyYAML==6.0.2
+pyzmq==26.2.0
+referencing==0.35.1
regex==2024.9.11
requests==2.32.3
-rich==13.8.1
+rfc3339-validator==0.1.4
+rfc3986-validator==0.1.1
+rich==13.9.2
+rpds-py==0.20.0
safetensors==0.4.5
+Send2Trash==1.8.3
sentencepiece==0.2.0
shtab==1.7.1
simpleeval==0.9.13
six==1.16.0
+sniffio==1.3.1
+soupsieve==2.6
+stack-data==0.6.3
sympy==1.13.3
+tabulate==0.9.0
+terminado==0.18.1
threadpoolctl==3.5.0
+tinycss2==1.3.0
tokenizers==0.20.0
+tomli==2.0.2
torch==2.4.1
+tornado==6.4.1
tqdm==4.66.5
-transformers==4.45.0
+traitlets==5.14.3
+transformers==4.45.2
triton==3.0.0
-trl @ git+https://github.com/huggingface/trl.git@9af4734178d4436a8dc98a069042eedd2ccf178f
+trl @ git+https://github.com/huggingface/trl.git@9b80f3d50ccb98ceee94bab4145a36e7e58aa4eb
+types-python-dateutil==2.9.0.20241003
typing_extensions==4.12.2
tyro==0.8.11
tzdata==2024.2
+uri-template==1.3.0
urllib3==2.2.3
+wcwidth==0.2.13
+webcolors==24.8.0
+webencodings==0.5.1
+websocket-client==1.8.0
xxhash==3.5.0
-yarl==1.12.1
+yarl==1.14.0