save extra state dict for fp8 and support for lora fp8 weights for kqv adapter
anmolgupt committed Feb 29, 2024
1 parent cccd80d commit 19ecdda
Showing 3 changed files with 87 additions and 19 deletions.
@@ -43,9 +43,14 @@
try:
from megatron.core import ModelParallelConfig
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.custom_layers.transformer_engine import (
TEColumnParallelLinear,
TERowParallelLinear,
)
from megatron.core.tensor_parallel.mappings import (
gather_from_sequence_parallel_region,
scatter_to_sequence_parallel_region,
gather_from_tensor_model_parallel_region,
)

HAVE_MEGATRON_CORE = True
@@ -138,6 +143,7 @@ def __init__(
gather_output: bool = True,
input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers
dropout: float = 0.0,
fp8_weights: bool = False,
model_parallel_config: Optional[ModelParallelConfig] = None,
alpha: float | None = None,
**kwargs,
@@ -154,14 +160,37 @@ def __init__(
self.dim = dim
self.alpha = alpha if alpha is not None else self.dim
self.input_is_parallel = input_is_parallel
self.fp8_weights = fp8_weights
self.gather_output = gather_output

# megatron_gpt_peft_models will provide this arg, but deprecated ones do not.
# in case this arg is not provided, use the dummy default config.
if model_parallel_config is None:
model_parallel_config = ModelParallelConfig()
self.fp8_weights = False
self._sequence_parallel = model_parallel_config.sequence_parallel
model_parallel_config.sequence_parallel = False # SP is irrelevant for the lora linear layer

if self.fp8_weights and kwargs.get('base_model_cfg', False) and kwargs['base_model_cfg'].get('fp8', False):
tp_size = model_parallel_config.tensor_model_parallel_size
# TODO: move this check out of here.
if (dim / tp_size) % 16 != 0:
# TE doesn't support such tensor shapes with fp8 precision
self.fp8_weights = False
logging.info(f"LORA with FP8 weights cannot be supported for the given dim={dim} and TP size={tp_size}")
if input_is_parallel:
self.fp8_weights = False
logging.info(f"LORA with FP8 weights are not supported when input_is_parallel is set")

if gather_output:
self.fp8_weights = False
logging.info(f"LORA with FP8 weights are not supported when gather_output is set")
else:
self.fp8_weights = False

if fp8_weights and (not self.fp8_weights):
logging.info(f"LORA with FP8 weights cannot be supported for the configuration provided")

if input_is_parallel:
self.linear_in = RowParallelLinear(
in_features,
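
Transformer Engine only runs FP8 GEMMs on tensor shapes whose per-rank dimension is a multiple of 16, which is what the check above enforces before silently falling back to the regular LoRA path; the FP8 path is also skipped for input_is_parallel and gather_output layouts. A minimal sketch of that eligibility rule as a standalone helper (the helper name and example values are illustrative, not part of the commit):

# Hedged sketch of the eligibility rule enforced above; only the
# divisibility-by-16 check and the input_is_parallel / gather_output
# exclusions come from the diff, the helper itself is hypothetical.
def lora_fp8_supported(dim: int, tp_size: int, input_is_parallel: bool, gather_output: bool) -> bool:
    per_rank_dim = dim // tp_size
    if per_rank_dim % 16 != 0:
        return False  # TE FP8 GEMMs need per-rank dims divisible by 16
    if input_is_parallel or gather_output:
        return False  # FP8 LoRA weights only cover the plain column-parallel case
    return True

# dim=32 with TP size 2 leaves 16 rows per rank -> FP8 eligible;
# dim=8 with TP size 2 leaves 4 per rank -> falls back to non-FP8 LoRA weights.
print(lora_fp8_supported(32, 2, False, False))  # True
print(lora_fp8_supported(8, 2, False, False))   # False
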
@@ -173,14 +202,26 @@ def __init__(
init_method=self._get_init_fn(column_init_method),
)
else:
self.linear_in = ColumnParallelLinear(
in_features,
dim,
config=model_parallel_config,
bias=False,
gather_output=True,
init_method=self._get_init_fn(column_init_method),
)
if self.fp8_weights:
self.linear_in = TEColumnParallelLinear(
in_features,
dim,
config=model_parallel_config,
bias=False,
gather_output=False,
init_method=self._get_init_fn(column_init_method),
skip_bias_add=True,
is_expert=False,
)
else:
self.linear_in = ColumnParallelLinear(
in_features,
dim,
config=model_parallel_config,
bias=False,
gather_output=True,
init_method=self._get_init_fn(column_init_method),
)
if gather_output:
self.linear_out = RowParallelLinear(
dim,
@@ -192,16 +233,28 @@ def __init__(
skip_bias_add=True,
)
else:
# (@adithyare) we use this option to mirror the behavior of a column parallel layer with two low-rank column parallel layers
# if the original column parallel layer uses gather_output=False, then we will use the self.linear_out layer defined below.
self.linear_out = ColumnParallelLinear(
dim,
out_features,
config=model_parallel_config,
bias=False,
gather_output=True if input_is_parallel else False,
init_method=self._get_init_fn(row_init_method),
)
if self.fp8_weights:
self.linear_out = TEColumnParallelLinear(
dim,
out_features,
config=model_parallel_config,
bias=False,
gather_output=False,
init_method=self._get_init_fn(row_init_method),
skip_bias_add=True,
is_expert=False,
)
else:
# (@adithyare) we use this option to mirror the behavior of a column parallel layer with two low-rank column parallel layers
# if the original column parallel layer uses gather_output=False, then we will use the self.linear_out layer defined below.
self.linear_out = ColumnParallelLinear(
dim,
out_features,
config=model_parallel_config,
bias=False,
gather_output=True if input_is_parallel else False,
init_method=self._get_init_fn(row_init_method),
)

if self.norm_position in ["pre", "post"]:
ln_features = in_features if self.norm_position == "pre" else out_features
@@ -261,6 +314,8 @@ def forward(self, x):

x, _ = self.linear_in(x) # (@adithyare) ColumnLinear returns output and bias, we are ignoring the bias term.
x = self.activation(x)
if self.fp8_weights:
x = gather_from_tensor_model_parallel_region(x)
x, _ = self.linear_out(x)

if self._sequence_parallel and self.input_is_parallel:
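
In the FP8 path both LoRA projections are TEColumnParallelLinear layers created with gather_output=False, so after linear_in each tensor-parallel rank holds only its slice of the dim-sized intermediate activation; the explicit gather above reassembles the full activation before it is fed to linear_out. A framework-free sketch of why concatenating the per-rank partial outputs recovers the full result, using plain PyTorch and toy sizes (this is not NeMo code):

import torch

# Column-parallel split of linear_in: each "rank" owns dim // tp_size output rows.
torch.manual_seed(0)
in_features, dim, tp_size = 8, 4, 2
x = torch.randn(3, in_features)        # toy activations
w_in = torch.randn(dim, in_features)   # full linear_in weight

shards = w_in.chunk(tp_size, dim=0)     # per-rank weight slices
partials = [x @ w.t() for w in shards]  # each is (3, dim // tp_size)

# gather_output=False means a rank only sees its own partial; concatenating
# along the feature dimension is what the tensor-parallel gather does, and it
# reproduces the output of the unsplit layer.
gathered = torch.cat(partials, dim=-1)  # (3, dim)
assert torch.allclose(gathered, x @ w_in.t(), atol=1e-6)
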
@@ -356,6 +411,7 @@ class Lora4HtoHAdapter(ParallelLinearAdapter):

@dataclass
class LoraKQVAdapterConfig(ParallelLinearAdapterConfig):
fp8_weights: bool = False
_target_: str = "{0}.{1}".format(LoraKQVAdapter.__module__, LoraKQVAdapter.__name__)


13 changes: 12 additions & 1 deletion nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
@@ -72,6 +72,8 @@ def __init__(self, *args, **kwargs):
else:
self.model_prefix = "model.module." if self.cfg.get('megatron_amp_O2', False) else "model."

self.is_fp8_enabled = self.cfg.get('fp8', False)

self.use_mcore_gpt = hasattr(self, 'mcore_gpt') and self.mcore_gpt
if self.use_mcore_gpt:
assert HAVE_MEGATRON_CORE, "You set `mcore_gpt` as True but megatron core is not found."
@@ -355,6 +357,11 @@ def get_peft_state_dict(self):
# state_dict keys needs to be in non-O2 format and will be corrected in PEFTSaveRestoreConnector if O2=True
new_k = k.replace("model.module.", "model.", 1)
peft_state_dict[new_k] = state_dict[k]
if self.is_fp8_enabled:
for key in state_dict.keys():
if "_extra_state" in key:
new_k = key.replace("model.module.", "model.", 1)
peft_state_dict[new_k] = state_dict[key]
return peft_state_dict
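
Transformer Engine modules serialize their FP8 bookkeeping (scaling factors and amax history) under "_extra_state" entries, so the adapter checkpoint has to carry those alongside the LoRA weights for FP8 runs to restore correctly; that is what the extra loop above adds. A toy sketch of the resulting key selection, with made-up key names (only the "_extra_state" filter and the "model.module." prefix normalization mirror the change):

# Hypothetical state-dict keys, for illustration only.
state_dict = {
    "model.module.layers.0.adapter.lora_kqv_adapter.linear_in.weight": "lora weight",
    "model.module.layers.0.adapter.lora_kqv_adapter.linear_in._extra_state": "fp8 metadata",
    "model.module.layers.0.mlp.linear_fc1.weight": "frozen base weight",
}

adapter_keys = {k for k in state_dict if "adapter" in k and "_extra_state" not in k}
extra_state_keys = {k for k in state_dict if "_extra_state" in k}

# With fp8 enabled, both groups land in the PEFT checkpoint, with the O2
# "model.module." prefix normalized back to "model.".
peft_keys = {k.replace("model.module.", "model.", 1) for k in adapter_keys | extra_state_keys}
print(sorted(peft_keys))
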

def state_dict(self, destination=None, prefix=None, keep_vars=False):
@@ -383,7 +390,11 @@ def load_state_dict(self, state_dict, strict: bool = True):
# setting strict=False will ignore the missing keys (which are not being updated anyway)
# explicitly check if state_dict.keys matches all the expected self.adapter_keys since we don't have the
# safety in strict=True anymore.
assert set(state_dict.keys()) == self.adapter_keys.union(self.tunable_base_param_keys)
if not self.is_fp8_enabled:
assert set(state_dict.keys()) == self.adapter_keys.union(self.tunable_base_param_keys)
else:
state_dict_keys = [k for k in state_dict.keys() if "_extra_state" not in k]
assert set(state_dict_keys) == self.adapter_keys.union(self.tunable_base_param_keys)
super().load_state_dict(state_dict, strict=False)
else:
super().load_state_dict(state_dict, strict=True)
1 change: 1 addition & 0 deletions nemo/collections/nlp/parts/peft_config.py
@@ -183,6 +183,7 @@ def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_
"gather_output": False,
"dropout": lora_cfg.adapter_dropout,
"alpha": lora_cfg.get("alpha", lora_cfg.adapter_dim),
"fp8_weights": lora_cfg.get("fp8_weights", False),
}

if lora_cfg.weight_tying:
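
The new key is read with lora_cfg.get("fp8_weights", False), so it stays off unless explicitly enabled in the LoRA section of the PEFT config. A hedged usage sketch follows; the surrounding keys mirror typical NeMo LoRA configs but are assumptions here, and only fp8_weights itself is introduced by this commit:

from omegaconf import OmegaConf

# Assumed LoRA config layout; only the fp8_weights key comes from this commit.
lora_cfg = OmegaConf.create(
    {
        "adapter_dim": 32,
        "adapter_dropout": 0.0,
        "alpha": 32,
        "fp8_weights": True,  # opt in to FP8 LoRA weights for the KQV adapter
    }
)

adapter_cfg = {
    "gather_output": False,
    "dropout": lora_cfg.adapter_dropout,
    "alpha": lora_cfg.get("alpha", lora_cfg.adapter_dim),
    "fp8_weights": lora_cfg.get("fp8_weights", False),
}
print(adapter_cfg)
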
