From a682b1ce23da55251cf1b3f63389e83c02b44dc3 Mon Sep 17 00:00:00 2001
From: Guy Jacob
Date: Thu, 7 Nov 2024 10:36:07 +0200
Subject: [PATCH 1/2] Hyena wrapper: Weight decay override function

Still TODO: Apply it in non-test scenario

Signed-off-by: Guy Jacob
---
 nemo/collections/llm/gpt/model/hyena.py       | 8 ++++++++
 tests/collections/llm/gpt/model/test_hyena.py | 2 +-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py
index 2a6d3ca814dd..8f662db2e779 100644
--- a/nemo/collections/llm/gpt/model/hyena.py
+++ b/nemo/collections/llm/gpt/model/hyena.py
@@ -24,6 +24,7 @@
     from megatron.core import parallel_state
     from megatron.core.models.hyena import HyenaModel as MCoreHyenaModel
     from megatron.core.models.hyena.hyena_layer_specs import hyena_stack_spec
+    from megatron.core.ssm.hyena_utils import hyena_no_weight_decay_cond
 
     HAVE_MEGATRON_CORE_OR_TE = True
 
@@ -93,6 +94,11 @@ class HyenaConfig(TransformerConfig, io.IOMixin):
     tokenizer_model_path: str = None
     hyena_init_method: str = None
     hyena_output_layer_init_method: str = None
+    hyena_filter_no_wd: bool = True
+
+    def __post_init__(self):
+        super().__post_init__()
+        self.hyena_no_weight_decay_cond_fn = hyena_no_weight_decay_cond if self.hyena_filter_no_wd else None
 
     def configure_model(self, tokenizer) -> "MCoreHyenaModel":
         model = MCoreHyenaModel(
@@ -498,6 +504,7 @@ class HyenaTestConfig(HyenaConfig):
     recompute_num_layers: int = 2
     hyena_init_method: str = 'small_init'
     hyena_output_layer_init_method: str = 'wang_init'
+    hyena_filter_no_wd: bool = True
 
 
 @dataclass
@@ -531,6 +538,7 @@ class Hyena7bConfig(HyenaConfig):
     recompute_num_layers: int = 4
     hyena_init_method: str = 'small_init'
     hyena_output_layer_init_method: str = 'wang_init'
+    hyena_filter_no_wd: bool = True
 
 
 __all__ = [
diff --git a/tests/collections/llm/gpt/model/test_hyena.py b/tests/collections/llm/gpt/model/test_hyena.py
index a7b5ac5c3cc4..14ad2c63d342 100644
--- a/tests/collections/llm/gpt/model/test_hyena.py
+++ b/tests/collections/llm/gpt/model/test_hyena.py
@@ -126,7 +126,7 @@ def get_args():
         use_distributed_optimizer=True,
         bf16=True,
     )
-    opt = MegatronOptimizerModule(config=opt_config)
+    opt = MegatronOptimizerModule(config=opt_config, no_weight_decay_cond=hyena_config.hyena_no_weight_decay_cond_fn)
 
     trainer = nl.Trainer(
         devices=args.devices,

From c3c828be9e663b63c7824e338dc89c06987d5c5c Mon Sep 17 00:00:00 2001
From: guyjacob
Date: Thu, 7 Nov 2024 08:41:36 +0000
Subject: [PATCH 2/2] Apply isort and black reformatting

Signed-off-by: guyjacob
---
 nemo/collections/llm/gpt/model/hyena.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py
index 8f662db2e779..e6f4f6e98935 100644
--- a/nemo/collections/llm/gpt/model/hyena.py
+++ b/nemo/collections/llm/gpt/model/hyena.py
@@ -40,6 +40,7 @@
 from nemo.collections.llm.gpt.model.base import GPTModel, gpt_data_step
 from nemo.lightning import get_vocab_size, io, teardown
 
+
 def hyena_forward_step(model, batch) -> torch.Tensor:
     forward_args = {