support for kqv adapter weights in fp8
anmolgupt committed Mar 5, 2024
1 parent 21df976 commit 7ff6976
Showing 2 changed files with 75 additions and 18 deletions.
@@ -43,9 +43,14 @@
try:
from megatron.core import ModelParallelConfig
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.custom_layers.transformer_engine import (
TEColumnParallelLinear,
TERowParallelLinear,
)
from megatron.core.tensor_parallel.mappings import (
gather_from_sequence_parallel_region,
scatter_to_sequence_parallel_region,
gather_from_tensor_model_parallel_region,
)

HAVE_MEGATRON_CORE = True
@@ -138,6 +143,7 @@ def __init__(
gather_output: bool = True,
input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers
dropout: float = 0.0,
fp8_weights: bool = False,
model_parallel_config: Optional[ModelParallelConfig] = None,
alpha: float | None = None,
**kwargs,
@@ -154,14 +160,37 @@ def __init__(
self.dim = dim
self.alpha = alpha if alpha is not None else self.dim
self.input_is_parallel = input_is_parallel
self.fp8_weights = fp8_weights
self.gather_output = gather_output

# megatron_gpt_peft_models will provide this arg, but deprecated ones do not.
# in case this arg is not provided, use the dummy default config.
if model_parallel_config is None:
model_parallel_config = ModelParallelConfig()
self.fp8_weights = False
self._sequence_parallel = model_parallel_config.sequence_parallel
model_parallel_config.sequence_parallel = False # SP is irrelevant for the lora linear layer

if self.fp8_weights and kwargs.get('base_model_cfg', False) and kwargs['base_model_cfg'].get('fp8', False):
tp_size = model_parallel_config.tensor_model_parallel_size
#TODO move this check out of here.
if (dim / tp_size) % 16 != 0:
# TE doesn't support such tensor shapes with fp8 precision
self.fp8_weights = False
logging.info(f"LORA with FP8 weights cannot be supported for the given dim={dim} and TP size={tp_size}")
if input_is_parallel:
self.fp8_weights = False
logging.info(f"LORA with FP8 weights are not supported when input_is_parallel is set")

if gather_output:
self.fp8_weights = False
logging.info(f"LORA with FP8 weights are not supported when gather_output is set")
else:
self.fp8_weights = False

if fp8_weights and (not self.fp8_weights):
logging.info(f"LORA with FP8 weights cannot be supported for the configuration provided")

if input_is_parallel:
self.linear_in = RowParallelLinear(
in_features,
@@ -173,14 +202,26 @@ def __init__(
init_method=self._get_init_fn(column_init_method),
)
else:
self.linear_in = ColumnParallelLinear(
in_features,
dim,
config=model_parallel_config,
bias=False,
gather_output=True,
init_method=self._get_init_fn(column_init_method),
)
if self.fp8_weights:
self.linear_in = TEColumnParallelLinear(
in_features,
dim,
config=model_parallel_config,
bias=False,
gather_output=False,
init_method=self._get_init_fn(column_init_method),
skip_bias_add=True,
is_expert=False,
)
else:
self.linear_in = ColumnParallelLinear(
in_features,
dim,
config=model_parallel_config,
bias=False,
gather_output=True,
init_method=self._get_init_fn(column_init_method),
)
if gather_output:
self.linear_out = RowParallelLinear(
dim,
@@ -192,16 +233,28 @@ def __init__(
skip_bias_add=True,
)
else:
# (@adithyare) we use this option to mirror the behavior of a column parallel layer with two low-rank column parallel layers
# if the original column parallel layer uses gather_output=False, then we will use the self.linear_out layer defined below.
self.linear_out = ColumnParallelLinear(
dim,
out_features,
config=model_parallel_config,
bias=False,
gather_output=True if input_is_parallel else False,
init_method=self._get_init_fn(row_init_method),
)
if self.fp8_weights:
self.linear_out = TEColumnParallelLinear(
dim,
out_features,
config=model_parallel_config,
bias=False,
gather_output=False,
init_method=self._get_init_fn(row_init_method),
skip_bias_add=True,
is_expert=False,
)
else:
# (@adithyare) we use this option to mirror the behavior of a column parallel layer with two low-rank column parallel layers
# if the original column parallel layer uses gather_output=False, then we will use the self.linear_out layer defined below.
self.linear_out = ColumnParallelLinear(
dim,
out_features,
config=model_parallel_config,
bias=False,
gather_output=True if input_is_parallel else False,
init_method=self._get_init_fn(row_init_method),
)

if self.norm_position in ["pre", "post"]:
ln_features = in_features if self.norm_position == "pre" else out_features
@@ -261,6 +314,8 @@ def forward(self, x):

x, _ = self.linear_in(x) # (@adithyare) ColumnLinear returns output and bias, we are ignoring the bias term.
x = self.activation(x)
if self.fp8_weights:
x = gather_from_tensor_model_parallel_region(x)
x, _ = self.linear_out(x)

if self._sequence_parallel and self.input_is_parallel:
@@ -356,6 +411,7 @@ class Lora4HtoHAdapter(ParallelLinearAdapter):

@dataclass
class LoraKQVAdapterConfig(ParallelLinearAdapterConfig):
fp8_weights: bool = False
_target_: str = "{0}.{1}".format(LoraKQVAdapter.__module__, LoraKQVAdapter.__name__)


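For context, here is a minimal standalone sketch of the FP8 eligibility gating that the new __init__ code applies before swapping the adapter projections to Transformer Engine linears; the helper name fp8_lora_supported is illustrative and not part of the NeMo API. The FP8 path is taken only when the per-tensor-parallel-rank LoRA dimension is a multiple of 16 and the adapter is a plain column-parallel one (neither input_is_parallel nor gather_output). In that case both linear_in and linear_out become TEColumnParallelLinear with gather_output=False, hence the explicit gather_from_tensor_model_parallel_region call added in forward before linear_out.

import logging


def fp8_lora_supported(dim: int, tp_size: int, input_is_parallel: bool, gather_output: bool) -> bool:
    """Illustrative restatement of the checks added to the adapter's __init__ above."""
    if (dim / tp_size) % 16 != 0:
        # TE FP8 GEMMs require the per-rank LoRA dimension to be a multiple of 16.
        logging.info(f"LoRA with FP8 weights cannot be supported for dim={dim} and TP size={tp_size}")
        return False
    if input_is_parallel or gather_output:
        # The FP8 path only covers the plain column-parallel case.
        logging.info("LoRA with FP8 weights requires input_is_parallel=False and gather_output=False")
        return False
    return True


# Example: dim=32 with TP size 2 leaves 16 columns per rank, so the FP8 path is allowed.
assert fp8_lora_supported(dim=32, tp_size=2, input_is_parallel=False, gather_output=False)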
1 change: 1 addition & 0 deletions nemo/collections/nlp/parts/peft_config.py
@@ -183,6 +183,7 @@ def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_
"gather_output": False,
"dropout": lora_cfg.adapter_dropout,
"alpha": lora_cfg.get("alpha", lora_cfg.adapter_dim),
"fp8_weights": lora_cfg.get("fp8_weights", False),
}

if lora_cfg.weight_tying:
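On the config side, a hedged sketch of how the new flag could be switched on from a LoRA tuning config; fp8_weights is the only key introduced by this commit, the neighbouring keys (adapter_dim, adapter_dropout, alpha) already appear in _create_lora_config, and the exact nesting inside a full training YAML may differ.

from omegaconf import OmegaConf

lora_cfg = OmegaConf.create(
    {
        "adapter_dim": 32,
        "adapter_dropout": 0.0,
        "alpha": 32,
        "fp8_weights": True,  # picked up via lora_cfg.get("fp8_weights", False) above
    }
)
# _create_lora_config forwards the flag into the adapter config (e.g. LoraKQVAdapterConfig,
# whose dataclass now declares fp8_weights: bool = False), and from there into the
# adapter's __init__, where the eligibility checks shown earlier apply.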
