diff --git a/configs/OLMoE-200m-80m.yml b/configs/OLMoE-200m-80m.yml new file mode 100644 index 000000000..49778ff58 --- /dev/null +++ b/configs/OLMoE-200m-80m.yml @@ -0,0 +1,107 @@ +run_name: OLMoE +seed: 6198 +dry_run: false + +wandb: + name: ${run_name} + project: olmoe + group: null + +model: + d_model: 512 + n_heads: 8 + n_layers: 10 + mlp_ratio: 4 # 4 vs 8 (for swiglu) + weight_tying: true + alibi: false + rope: true + flash_attention: true + attention_dropout: 0.0 + attention_layer_norm: false + multi_query_attention: false + include_bias: false + block_type: moe + layer_norm_type: default + layer_norm_with_affine: false + bias_for_layer_norm: false + attention_layer_norm_with_affine: false + activation_type: gelu # gelu vs swiglu + residual_dropout: 0.0 + embedding_dropout: 0.0 + max_sequence_length: 2048 + vocab_size: 50280 + embedding_size: 50304 + eos_token_id: 50279 + pad_token_id: 1 + init_device: meta + init_fn: normal # mitchell vs normal + +compile: null # causes instability on AMD GPUs + +optimizer: + name: adamw + learning_rate: 4.0e-4 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + metrics_log_interval: 10 + +scheduler: + name: cosine_with_warmup + t_warmup: 2000 + alpha_f: 0.1 + +tokenizer: + identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json + truncate_direction: right + +save_folder: /data/niklas/olmoe +save_overwrite: false +# Sharded checkpoints (best for restarts) +save_interval: 1000 +save_num_checkpoints_to_keep: 9 +# Unsharded checkpoints (for final storage) +save_interval_unsharded: 10000 +save_num_unsharded_checkpoints_to_keep: -1 + +load_path: null + +max_duration: 10e9T # 10B tokens +global_train_batch_size: 2048 +device_train_microbatch_size: 8 + +precision: amp_bf16 + +fsdp: + wrapping_strategy: null + precision: mixed + +max_grad_norm: 1.0 +max_grad_norm_ratio: null + +speed_monitor: + window_size: 20 + +eval_interval: ${save_interval} +eval_subset_num_batches: -1 +device_eval_batch_size: ${device_train_microbatch_size} + +data: + pad_direction: right + num_workers: 0 + drop_last: true + pin_memory: true + prefetch_factor: 16 + persistent_workers: true + timeout: 0 + paths: + - /data/niklas/llm/data/part-000-00000.npy + - /data/niklas/llm/data/part-000-00001.npy + - /data/niklas/llm/data/part-001-00000.npy + - /data/niklas/llm/data/part-002-00000.npy + - /data/niklas/llm/data/part-003-00000.npy + - /data/niklas/llm/data/part-004-00000.npy + - /data/niklas/llm/data/part-004-00001.npy + - /data/niklas/llm/data/part-005-00000.npy + - /data/niklas/llm/data/part-005-00001.npy \ No newline at end of file diff --git a/configs/OLMoE-200m.yml b/configs/OLMoE-200m.yml new file mode 100644 index 000000000..a7d79ac98 --- /dev/null +++ b/configs/OLMoE-200m.yml @@ -0,0 +1,107 @@ +run_name: OLMoE +seed: 6198 +dry_run: false + +wandb: + name: ${run_name} + project: olmoe + group: null + +model: + d_model: 896 + n_heads: 14 + n_layers: 16 + mlp_ratio: 4 # 4 vs 8 (for swiglu) + weight_tying: true + alibi: false + rope: true + flash_attention: true + attention_dropout: 0.0 + attention_layer_norm: false + multi_query_attention: false + include_bias: false + block_type: sequential + layer_norm_type: default + layer_norm_with_affine: false + bias_for_layer_norm: false + attention_layer_norm_with_affine: false + activation_type: gelu # gelu vs swiglu + residual_dropout: 0.0 + embedding_dropout: 0.0 + max_sequence_length: 2048 + vocab_size: 50280 + embedding_size: 50304 + eos_token_id: 50279 + pad_token_id: 1 + init_device: meta + init_fn: normal # mitchell 
vs normal + +compile: null # causes instability on AMD GPUs + +optimizer: + name: adamw + learning_rate: 4.0e-4 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + metrics_log_interval: 10 + +scheduler: + name: cosine_with_warmup + t_warmup: 2000 + alpha_f: 0.1 + +tokenizer: + identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json + truncate_direction: right + +save_folder: /data/niklas/olmoe +save_overwrite: false +# Sharded checkpoints (best for restarts) +save_interval: 1000 +save_num_checkpoints_to_keep: 9 +# Unsharded checkpoints (for final storage) +save_interval_unsharded: 10000 +save_num_unsharded_checkpoints_to_keep: -1 + +load_path: null + +max_duration: 10e9T # 10B tokens +global_train_batch_size: 2048 +device_train_microbatch_size: 8 + +precision: amp_bf16 + +fsdp: + wrapping_strategy: null + precision: mixed + +max_grad_norm: 1.0 +max_grad_norm_ratio: null + +speed_monitor: + window_size: 20 + +eval_interval: ${save_interval} +eval_subset_num_batches: -1 +device_eval_batch_size: ${device_train_microbatch_size} + +data: + pad_direction: right + num_workers: 0 + drop_last: true + pin_memory: true + prefetch_factor: 16 + persistent_workers: true + timeout: 0 + paths: + - /data/niklas/llm/data/part-000-00000.npy + - /data/niklas/llm/data/part-000-00001.npy + - /data/niklas/llm/data/part-001-00000.npy + - /data/niklas/llm/data/part-002-00000.npy + - /data/niklas/llm/data/part-003-00000.npy + - /data/niklas/llm/data/part-004-00000.npy + - /data/niklas/llm/data/part-004-00001.npy + - /data/niklas/llm/data/part-005-00000.npy + - /data/niklas/llm/data/part-005-00001.npy \ No newline at end of file diff --git a/configs/OLMoE-600m-200m.yml b/configs/OLMoE-600m-200m.yml new file mode 100644 index 000000000..bc25092db --- /dev/null +++ b/configs/OLMoE-600m-200m.yml @@ -0,0 +1,107 @@ +run_name: OLMoE +seed: 6198 +dry_run: false + +wandb: + name: ${run_name} + project: olmoe + group: null + +model: + d_model: 768 + n_heads: 12 + n_layers: 14 + mlp_ratio: 4 # 4 vs 8 (for swiglu) + weight_tying: true + alibi: false + rope: true + flash_attention: true + attention_dropout: 0.0 + attention_layer_norm: false + multi_query_attention: false + include_bias: false + block_type: moe + layer_norm_type: default + layer_norm_with_affine: false + bias_for_layer_norm: false + attention_layer_norm_with_affine: false + activation_type: gelu # gelu vs swiglu + residual_dropout: 0.0 + embedding_dropout: 0.0 + max_sequence_length: 2048 + vocab_size: 50280 + embedding_size: 50304 + eos_token_id: 50279 + pad_token_id: 1 + init_device: meta + init_fn: normal # mitchell vs normal + +compile: null # causes instability on AMD GPUs + +optimizer: + name: adamw + learning_rate: 4.0e-4 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + metrics_log_interval: 10 + +scheduler: + name: cosine_with_warmup + t_warmup: 2000 + alpha_f: 0.1 + +tokenizer: + identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json + truncate_direction: right + +save_folder: /data/niklas/olmoe +save_overwrite: false +# Sharded checkpoints (best for restarts) +save_interval: 1000 +save_num_checkpoints_to_keep: 9 +# Unsharded checkpoints (for final storage) +save_interval_unsharded: 10000 +save_num_unsharded_checkpoints_to_keep: -1 + +load_path: null + +max_duration: 10e9T # 10B tokens +global_train_batch_size: 2048 +device_train_microbatch_size: 8 + +precision: amp_bf16 + +fsdp: + wrapping_strategy: null + precision: mixed + +max_grad_norm: 1.0 +max_grad_norm_ratio: null + +speed_monitor: + window_size: 20 + +eval_interval: 
${save_interval} +eval_subset_num_batches: -1 +device_eval_batch_size: ${device_train_microbatch_size} + +data: + pad_direction: right + num_workers: 0 + drop_last: true + pin_memory: true + prefetch_factor: 16 + persistent_workers: true + timeout: 0 + paths: + - /data/niklas/llm/data/part-000-00000.npy + - /data/niklas/llm/data/part-000-00001.npy + - /data/niklas/llm/data/part-001-00000.npy + - /data/niklas/llm/data/part-002-00000.npy + - /data/niklas/llm/data/part-003-00000.npy + - /data/niklas/llm/data/part-004-00000.npy + - /data/niklas/llm/data/part-004-00001.npy + - /data/niklas/llm/data/part-005-00000.npy + - /data/niklas/llm/data/part-005-00001.npy diff --git a/configs/OLMoE-test.yaml b/configs/OLMoE-test.yaml new file mode 100644 index 000000000..e79f36a32 --- /dev/null +++ b/configs/OLMoE-test.yaml @@ -0,0 +1,100 @@ +run_name: OLMoE-1B +seed: 6198 +dry_run: false + +wandb: + name: ${run_name} + project: olmoe + group: null + +model: + d_model: 640 + n_heads: 10 + n_layers: 16 + mlp_ratio: 4 # TODO: 4 or 8 + weight_tying: true + alibi: false + rope: true + flash_attention: true + attention_dropout: 0.0 + attention_layer_norm: false + multi_query_attention: false + include_bias: false + block_type: moe + layer_norm_type: default + layer_norm_with_affine: false + bias_for_layer_norm: false + attention_layer_norm_with_affine: false + activation_type: gelu # TODO: Experiment with gelu vs swiglu + residual_dropout: 0.0 + embedding_dropout: 0.0 + max_sequence_length: 2048 + vocab_size: 50280 + embedding_size: 50304 + eos_token_id: 50279 + pad_token_id: 1 + init_device: meta + init_fn: mitchell # TODO: Experiment with regular normal + +compile: null # causes instability on AMD GPUs + +optimizer: + name: adamw + learning_rate: 4.0e-4 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + metrics_log_interval: 10 + +scheduler: + name: cosine_with_warmup + t_warmup: 2000 + alpha_f: 0.1 + +tokenizer: + identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json + truncate_direction: right + +save_folder: /data/niklas/olmoe +save_overwrite: false +# Sharded checkpoints (best for restarts) +save_interval: 1000 +save_num_checkpoints_to_keep: 9 +# Unsharded checkpoints (for final storage) +save_interval_unsharded: 10000 +save_num_unsharded_checkpoints_to_keep: -1 + +load_path: null + +max_duration: 739_328 # 3.1T tokens +global_train_batch_size: 2048 +device_train_microbatch_size: 8 + +precision: amp_bf16 + +fsdp: + wrapping_strategy: null + precision: mixed + +max_grad_norm: 1.0 +max_grad_norm_ratio: null + +speed_monitor: + window_size: 20 + +eval_interval: ${save_interval} +eval_subset_num_batches: -1 +device_eval_batch_size: ${device_train_microbatch_size} + +data: + pad_direction: right + num_workers: 0 + drop_last: true + pin_memory: true + prefetch_factor: 16 + persistent_workers: true + timeout: 0 + paths: + - /data/niklas/llm/data/part-000-00000.npy + - /data/niklas/llm/data/part-000-00001.npy \ No newline at end of file diff --git a/olmo/config.py b/olmo/config.py index d6bb3179b..2d56c0e72 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -181,6 +181,8 @@ class BlockType(StrEnum): sequential = "sequential" llama = "llama" + + moe = "moe" """ A block similar to the sequential block with slightly different implementations of operations like attention to imitate the behavior of Llama. @@ -433,6 +435,36 @@ class ModelConfig(BaseConfig): See :data:`TrainConfig.precision` instead. """ + moe_num_experts: Optional[int] = 8 + """ + The number of experts to use in the MoE block. 
+ """ + + moe_top_k: Optional[int] = 2 + """ + The number of top experts to use in the MoE block. + """ + + moe_capacity_factor: Optional[float] = 1.25 + """ + The capacity factor to use in the MoE block. + """ + + moe_loss_weight: Optional[float] = 0.1 + """ + The weight to use for the MoE loss. + """ + + moe_expert_model_parallelism: Optional[bool] = False + """ + Whether to use model parallelism for the MoE experts. + """ + + moe_dropless: Optional[bool] = False + """ + Whether to use dMoE (https://arxiv.org/abs/2211.15841) + """ + @property def effective_n_kv_heads(self) -> int: if self.n_kv_heads is None: diff --git a/olmo/initialization.py b/olmo/initialization.py index 260e94757..775d867d2 100644 --- a/olmo/initialization.py +++ b/olmo/initialization.py @@ -18,8 +18,8 @@ class ModuleType(StrEnum): def init_weights( - config: ModelConfig, module: Union[nn.Linear, nn.Embedding], + config: ModelConfig, d: Optional[int] = None, layer_id: Optional[int] = None, std_factor: float = 1.0, @@ -40,19 +40,34 @@ def init_weights( std = config.init_std * std_factor if config.init_cutoff_factor is not None: cutoff_value = config.init_cutoff_factor * std - nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value) + if hasattr(module, "weight"): + nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value) + else: + nn.init.trunc_normal_(module, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value) else: - nn.init.normal_(module.weight, mean=0.0, std=std) + if hasattr(module, "weight"): + nn.init.normal_(module.weight, mean=0.0, std=std) + else: + nn.init.normal_(module, mean=0.0, std=std) elif config.init_fn == InitFnType.mitchell: std = std_factor / math.sqrt(d) if layer_id is not None: std = std / math.sqrt(2 * (layer_id + 1)) - nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std) + if hasattr(module, "weight"): + nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std) + else: + nn.init.trunc_normal_(module, mean=0.0, std=std, a=-3 * std, b=3 * std) elif config.init_fn == InitFnType.kaiming_normal: - nn.init.kaiming_normal_(module.weight, nonlinearity="relu") + if hasattr(module, "weight"): + nn.init.kaiming_normal_(module.weight, nonlinearity="relu") + else: + nn.init.kaiming_normal_(module, nonlinearity="relu") elif config.init_fn == InitFnType.fan_in: std = std_factor / math.sqrt(d) - nn.init.normal_(module.weight, mean=0.0, std=std) + if hasattr(module, "weight"): + nn.init.normal_(module.weight, mean=0.0, std=std) + else: + nn.init.normal_(module, mean=0.0, std=std) elif config.init_fn == InitFnType.full_megatron: if type_of_module is None: raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.") @@ -76,13 +91,22 @@ def init_weights( std = config.d_model**-0.5 else: raise RuntimeError(f"Unknown module type '{type_of_module}'") - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=std, - a=-cutoff_factor * std, - b=cutoff_factor * std, - ) + if hasattr(module, "weight"): + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + else: + nn.init.trunc_normal_( + module, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) else: raise NotImplementedError(config.init_fn) diff --git a/olmo/model.py b/olmo/model.py index 555e0ca81..83d9c4da9 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -69,9 +69,14 @@ "OLMoGenerateOutput", ] - log = 
logging.getLogger(__name__) +try: + from megablocks.layers.arguments import Arguments as MoEArgs + from megablocks.layers.dmoe import dMoE + from megablocks.layers.moe import MoE +except ImportError: + log.warning("megablocks not installed, MoE layers will not be available.") def activation_checkpoint_function(cfg: ModelConfig): preserve_rng_state = ( @@ -474,15 +479,15 @@ def reset_parameters(self): if self.q_norm is not None: self.q_norm.reset_parameters() init_weights( - self.config, self.attn_out, + self.config, d=self.config.d_model, layer_id=self.layer_id, type_of_module=ModuleType.out_module, ) init_weights( - self.config, self.ff_out, + self.config, d=self.ff_out.in_features, layer_id=self.layer_id, type_of_module=ModuleType.out_module, @@ -626,10 +631,213 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBl return OLMoSequentialBlock(layer_id, config, cache) elif config.block_type == BlockType.llama: return OLMoLlamaBlock(layer_id, config, cache) + elif config.block_type == BlockType.moe: + return OLMoEBlock(layer_id, config, cache) else: raise NotImplementedError(f"Unknown block type: '{config.block_type}'") +class OLMoEBlock(OLMoBlock): + """ + This is a a transformer MoE block where the output is computed as ``MoE(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): + nn.Module.__init__(self) + self.layer_id = layer_id + self.config = config + self.hidden_size = ( + config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model + ) + self.__cache = cache + assert config.d_model % config.n_heads == 0 + + self._activation_checkpoint_fn = None + + # Dropout. + self.dropout = Dropout(config.residual_dropout) + + # Layer norms. + self.k_norm: Optional[LayerNormBase] = None + self.q_norm: Optional[LayerNormBase] = None + if config.attention_layer_norm: + assert config.effective_n_kv_heads is not None + self.k_norm = LayerNormBase.build( + config, + size=(config.d_model // config.n_heads) * config.effective_n_kv_heads, + elementwise_affine=config.attention_layer_norm_with_affine, + ) + self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine) + + # Make sure QKV clip coefficient is positive, otherwise it's not well-defined. + if config.clip_qkv is not None: + assert config.clip_qkv > 0 + + # Activation function. + self.act = Activation.build(config) + assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 + + # Attention output projection. 
+ self.attn_out = nn.Linear( + config.d_model, config.d_model, bias=config.include_bias, device=config.init_device + ) + + # MoE Block + self.moe_args = MoEArgs( + activation_fn=F.silu if 'glu' in config.activation_type.lower() else self.act, + mlp_type='glu' if 'glu' in config.activation_type.lower() else 'mlp', + # Recommended for H100s by megablocks but found it 4x slower on H100s + mlp_impl='grouped', + hidden_size=config.d_model, + ffn_hidden_size=int(self.act.output_multiplier * self.hidden_size), + moe_num_experts=config.moe_num_experts, + # Handled by FSDP (https://github.com/databricks/megablocks/issues/57#issuecomment-1854594483) + moe_weight_parallelism=False, + moe_expert_model_parallelism=config.moe_expert_model_parallelism, + moe_top_k=config.moe_top_k, + moe_capacity_factor=config.moe_capacity_factor, + moe_loss_weight=config.moe_loss_weight, + device=config.init_device, + # Handled by FSDP + bf16=False, + fp16=False, + bias=self.config.include_bias, + return_bias=False, + ) + if self.config.moe_dropless: + self.ffn = dMoE(self.moe_args) + else: + self.ffn = MoE(self.moe_args) + + # Rotary embeddings. + if self.config.rope: + self.rotary_emb = RotaryEmbedding(config, self.__cache) + + self.flash_attn_func = None + if config.flash_attention: + try: + from flash_attn import flash_attn_func # type: ignore + + self.flash_attn_func = flash_attn_func + except ModuleNotFoundError: + pass + + self.attn_norm = LayerNorm.build(config) + self.ff_norm = LayerNorm.build(config) + + # Attention input projection. Projects x -> (q, k, v) + head_dim = config.d_model // config.n_heads + self.fused_dims = ( + config.d_model, + config.effective_n_kv_heads * head_dim, + config.effective_n_kv_heads * head_dim, + ) + self.att_proj = nn.Linear( + config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device + ) + + def reset_parameters(self): + if self.k_norm is not None: + self.k_norm.reset_parameters() + if self.q_norm is not None: + self.q_norm.reset_parameters() + + # NOTE: the standard deviation for these weights does not depend on the layer. + init_weights( + self.att_proj, + self.config, + d=self.config.d_model, + layer_id=None, + type_of_module=ModuleType.in_module + ) + init_weights( + self.attn_out, + self.config, + d=self.config.d_model, + layer_id=self.layer_id, + type_of_module=ModuleType.out_module, + ) + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + init_weights( + self.ffn.experts.mlp.w1, + self.config, + d=self.config.d_model, + layer_id=None, + type_of_module=ModuleType.in_module, + ) + init_weights( + self.ffn.experts.mlp.w2, + self.config, + d=int(self.act.output_multiplier * self.hidden_size), + layer_id=self.layer_id, + type_of_module=ModuleType.out_module, + ) + if self.ffn.experts.bias is not None: + torch.nn.init.zeros_(self.ffn.experts.bias) + init_weights( + self.ffn.router.layer, + self.config, + d=self.config.d_model, + layer_id=None, + type_of_module=ModuleType.out_module, + ) + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Get query, key, value projections. 
+ # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + # - for group query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_kv_heads) + if self._activation_checkpoint_fn is not None: + qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)) + else: + qkv = self.att_proj(self.attn_norm(x)) + + if self.config.clip_qkv is not None: + qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + q, k, v = qkv.split(self.fused_dims, dim=-1) + + # Get attention scores. + if self._activation_checkpoint_fn is not None: + att, cache = self._activation_checkpoint_fn( # type: ignore + self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + # Add feed-forward projection. + # shape: (batch_size, seq_len, d_model) + og_x = x + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore + else: + x = self.ff_norm(x) + + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ffn, x) # type: ignore + else: + x = self.ffn(x) + + x = self.dropout(x) + x = og_x + x + + return x, cache + + class OLMoSequentialBlock(OLMoBlock): """ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` @@ -663,10 +871,10 @@ def reset_parameters(self): self.ff_norm.reset_parameters() # NOTE: the standard deviation for these weights does not depend on the layer. init_weights( - self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module + self.att_proj, self.config, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module ) init_weights( - self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module + self.ff_proj, self.config, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module ) def forward( @@ -768,10 +976,10 @@ def reset_parameters(self): self.attn_norm.reset_parameters() self.ff_norm.reset_parameters() # NOTE: the standard deviation for these weights does not depend on the layer. - init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None) - init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None) - init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None) - init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None) + init_weights(self.q_proj, self.config, d=self.config.d_model, layer_id=None) + init_weights(self.k_proj, self.config, d=self.config.d_model, layer_id=None) + init_weights(self.v_proj, self.config, d=self.config.d_model, layer_id=None) + init_weights(self.ff_proj, self.config, d=self.config.d_model, layer_id=None) def _scaled_dot_product_attention( self, @@ -1021,20 +1229,20 @@ def reset_parameters(self): log.info("Initializing model parameters...") # Top-level embeddings / linear layers. 
init_weights( - self.config, self.transformer.wte, # type: ignore + self.config, std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0, type_of_module=ModuleType.emb, ) if hasattr(self.transformer, "wpe"): - init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore + init_weights(self.transformer.wpe, self.config, type_of_module=ModuleType.emb) # type: ignore # Top-level layer norm. self.transformer.ln_f.reset_parameters() # type: ignore # Output weights. if hasattr(self.transformer, "ff_out"): - init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore + init_weights(self.transformer.ff_out, self.config, type_of_module=ModuleType.final_out) # type: ignore # Let the blocks handle themselves. if self.config.block_group_size == 1: @@ -1329,7 +1537,7 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): else: raise NotImplementedError(wrap_strategy) - def num_params(self, include_embedding: bool = True) -> int: + def num_params(self, include_embedding: bool = True, include_inactivated_experts: bool = True) -> int: """ Get the total number of parameters. """ @@ -1339,6 +1547,14 @@ def num_params(self, include_embedding: bool = True) -> int: lambda np: ".wte." not in np[0] and ".wpe." not in np[0], params, ) + if not include_inactivated_experts: + # Keep only the experts that are actually routed to, i.e. moe_top_k of them + # e.g. 'transformer.blocks.0.ffn.experts.mlp.w1' has shape (total_experts, in_dim, out_dim) + # which is sliced to (moe_top_k, in_dim, out_dim) before counting + params = [ + (np[0], np[1][: self.config.moe_top_k] if "experts.mlp" in np[0] else np[1]) + for np in params + ] return sum(p.numel() for _, p in params) @property diff --git a/olmo/optim.py b/olmo/optim.py index 2f2634238..23ffd6ac8 100644 --- a/olmo/optim.py +++ b/olmo/optim.py @@ -14,6 +14,12 @@ from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig from .torch_util import get_default_device, is_distributed +try: + from megablocks.layers.mlp import MLP, SparseMLP + megablocks_available = True +except ImportError: + megablocks_available = False + __all__ = [ "Optimizer", "LionW", @@ -607,6 +613,7 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] """ Separate parameters into weight decay and non weight decay groups. 
""" + from megablocks.layers.mlp import MLP param_groups: List[Dict[str, Any]] param_group_defaults = { "sharded": isinstance(model, FullyShardedDataParallel), @@ -646,6 +653,8 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] decay.add(fpn) else: no_decay.add(fpn) + elif megablocks_available and pn.endswith(("w1", "w2")) and (isinstance(m, MLP) or isinstance(m, SparseMLP)): + decay.add(fpn) # Validate that we've considered every parameter inter_params = decay & no_decay diff --git a/olmo/train.py b/olmo/train.py index ab4d871b2..e36e5e534 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -26,6 +26,7 @@ from .aliases import PathOrStr from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer from .config import ( + BlockType, CheckpointType, SchedulerUnits, ShardedCheckpointerType, @@ -54,6 +55,12 @@ log = logging.getLogger(__name__) +try: + from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss, get_load_balancing_loss + from megablocks.layers.arguments import Arguments as MoEArgs +except ImportError: + log.warning(f"Megablocks not installed. To train MoE, install with pip install megablocks.") + @dataclass class SpeedMonitor: @@ -145,6 +152,8 @@ class Trainer: def __post_init__(self): if self.cfg.fused_loss: + if self.model.config.block_type == BlockType.moe: + raise NotImplementedError("Fused loss is not implemented for MoE models.") from flash_attn.ops.triton.cross_entropy import ( # type: ignore cross_entropy_loss, ) @@ -186,6 +195,23 @@ def fused_loss_fn( self.loss_fn = fused_loss_fn + if self.model.config.block_type == BlockType.moe: + # these MoEArgs are necessary for logging load balancing. + self.moe_args = MoEArgs( + hidden_size=self.model.config.d_model, + ffn_hidden_size=self.model.config.d_model * 4, + moe_num_experts=self.model.config.moe_num_experts, + num_layers=self.model.config.n_layers, + # Not tested for now + moe_expert_model_parallelism=False, + moe_top_k=self.model.config.moe_top_k, + device=self.model.config.init_device, + moe_capacity_factor=self.model.config.moe_capacity_factor, + moe_loss_weight=self.model.config.moe_loss_weight, + fp16=False, + bf16=False, + ) + @property def dataset(self) -> IterableDataset: assert isinstance(self.train_loader.dataset, IterableDataset) @@ -643,6 +669,9 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor ce_batch_loss = torch.tensor(0.0, device=self.device) z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device) + lb_batch_loss = None if self.model.config.block_type != BlockType.moe else torch.tensor(0.0, device=self.device) + # Keep this one on CPU to save memory + expert_assignments = None if self.model.config.block_type != BlockType.moe else torch.zeros((self.model.config.n_layers, self.model.config.moe_num_experts)) for micro_batch in micro_batches: with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision): # Run forward pass. @@ -669,12 +698,22 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor else: loss = ce_loss + if self.model.config.block_type == BlockType.moe: + lb_loss = batched_load_balancing_loss(self.moe_args) / len(micro_batches) + + tokens_per_expert, _ = zip(*get_load_balancing_loss()) + expert_assignments += torch.stack(tokens_per_expert, dim=0).cpu() + + clear_load_balancing_loss() + loss += lb_loss + lb_batch_loss += lb_loss.detach() + del logits # Run backward pass. 
loss.backward() - return ce_batch_loss, z_batch_loss + return ce_batch_loss, z_batch_loss, lb_batch_loss, expert_assignments def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]: metrics: Dict[str, float] = {} @@ -691,7 +730,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> batch = move_to_device(batch, self.device) # Run forward-backward pass. - ce_batch_loss, z_batch_loss = self.train_batch(batch) + ce_batch_loss, z_batch_loss, lb_batch_loss, expert_assignments = self.train_batch(batch) # Collect loss, potentially reducing over all ranks. if reduce_global_loss: @@ -700,6 +739,9 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> if z_batch_loss is not None: dist.reduce(z_batch_loss, 0) z_batch_loss.div_(get_world_size()) + if lb_batch_loss is not None: + dist.reduce(lb_batch_loss, 0) + lb_batch_loss.div_(get_world_size()) # Clip gradient norms and collect param/gradient/optim metrics. should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step() @@ -732,9 +774,11 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> # Collect metrics and check for NaN loss. # NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this. if torch.isnan(ce_batch_loss): - raise ValueError("nan loss encountered") + raise ValueError("nan ce loss encountered") if z_batch_loss is not None and torch.isnan(z_batch_loss): - raise ValueError("nan loss encountered") + raise ValueError("nan z loss encountered") + if lb_batch_loss is not None and torch.isnan(lb_batch_loss): + raise ValueError("nan lb loss encountered") for key, value in optim_metrics.items(): metrics[f"optim/{key}"] = value.item() self.cur_train_loss = ce_batch_loss.item() @@ -743,6 +787,14 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> metrics["train/Perplexity"] = math.exp(self.cur_train_loss) if z_batch_loss is not None: metrics["train/ZLoss"] = z_batch_loss.item() + if lb_batch_loss is not None: + metrics["train/LoadBalancingLoss"] = lb_batch_loss.item() + # Log assignment metrics. + for layer_idx, expert_assignments_layer in enumerate(expert_assignments): + total_tokens = expert_assignments_layer.sum().item() + for expert_idx, expert_assignment in enumerate(expert_assignments_layer): + metrics[f"train/TokensPercentage/layer{layer_idx}/expert{expert_idx}"] = (expert_assignment.item() / total_tokens) * 100 + metrics[f"train/TokensTotal/layer{layer_idx}/expert{expert_idx}"] = expert_assignment.item() # Maybe collect post-step optimizer-specific metrics. if should_log_optim_metrics_this_step: diff --git a/scripts/train.py b/scripts/train.py index 23471ca94..0f0b4da46 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -119,6 +119,8 @@ def main(cfg: TrainConfig) -> None: olmo_model = OLMo(cfg.model) log.info(f"Total number of parameters: {olmo_model.num_params():,d}") log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}") + if olmo_model.config.block_type == "moe": + log.info(f"Number of active parameters: {olmo_model.num_params(include_inactivated_experts=False):,d}") log.info(f"Peak GPU Memory (MB) before FSDP: {int(peak_gpu_memory() or 0)}") olmo_model.set_activation_checkpointing(cfg.activation_checkpointing)
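
The notes below are illustrative sketches of mechanisms this patch relies on; they are not part of the diff itself and rely on hedged assumptions where stated. First, the configs pair cosine_with_warmup with t_warmup: 2000 and alpha_f: 0.1. A minimal sketch of that learning-rate shape, assuming alpha_f is the final LR expressed as a fraction of the peak (the repo's scheduler may differ in details):

import math


def cosine_with_warmup_lr(step: int, max_steps: int, max_lr: float,
                          t_warmup: int = 2000, alpha_f: float = 0.1) -> float:
    """Illustrative LR curve for the scheduler settings in these configs.

    Assumes linear warmup to max_lr over t_warmup steps, then cosine decay to
    alpha_f * max_lr; the repo's scheduler implementation may differ in details.
    """
    if step < t_warmup:
        return max_lr * step / t_warmup
    progress = min(1.0, (step - t_warmup) / max(1, max_steps - t_warmup))
    min_lr = alpha_f * max_lr
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for s in (0, 1000, 2000, 100_000, 200_000):
        print(s, cosine_with_warmup_lr(s, max_steps=200_000, max_lr=4.0e-4))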
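
The new moe_capacity_factor and moe_top_k fields in ModelConfig bound how many tokens each expert may receive when moe_dropless is false (the dropless dMoE path has no such cap). A sketch of the usual capacity arithmetic, assuming the common capacity_factor * top_k * tokens / num_experts definition rather than megablocks' exact internals:

def expert_capacity(num_tokens: int, num_experts: int, top_k: int, capacity_factor: float) -> int:
    """Tokens each expert may accept before overflow tokens get dropped.

    Illustrative: uses the common capacity_factor * top_k * tokens / experts
    definition; megablocks may round or clamp differently.
    """
    return int(capacity_factor * top_k * num_tokens / num_experts)


if __name__ == "__main__":
    # ModelConfig defaults: 8 experts, top-2 routing, capacity factor 1.25,
    # one sequence at max_sequence_length = 2048 tokens.
    print(expert_capacity(2048, num_experts=8, top_k=2, capacity_factor=1.25))  # 640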
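
init_weights now takes the module first and accepts bare parameters (e.g. the stacked expert weights ffn.experts.mlp.w1) as well as modules with a .weight, dispatching on hasattr(module, "weight"). A self-contained sketch of that dispatch pattern, with a hypothetical helper name rather than the repo's API:

import math
from typing import Union

import torch
import torch.nn as nn


def trunc_normal_init(module_or_param: Union[nn.Module, torch.Tensor], std: float) -> None:
    """Initialize either module.weight or a bare parameter, in place.

    Mirrors the hasattr(module, "weight") dispatch in init_weights above.
    """
    target = module_or_param.weight if hasattr(module_or_param, "weight") else module_or_param
    nn.init.trunc_normal_(target, mean=0.0, std=std, a=-3 * std, b=3 * std)


if __name__ == "__main__":
    d_model = 512
    linear = nn.Linear(d_model, d_model)
    expert_w1 = nn.Parameter(torch.empty(8, d_model, 2048))  # stacked expert weights
    for obj in (linear, expert_w1):
        trunc_normal_init(obj, std=1.0 / math.sqrt(d_model))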
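
OLMoEBlock.forward keeps the sequential block's residual structure, x + Attn(LN(x)) followed by x + FFN(LN(x)), and swaps the dense MLP for a megablocks MoE/dMoE. As a rough mental model of what that FFN computes, here is a dense top-k routing sketch (no capacity limit, no expert parallelism; this is not megablocks' implementation):

import torch
import torch.nn as nn


class ToyTopKMoE(nn.Module):
    """Dense reference implementation of top-k expert routing (illustrative only)."""

    def __init__(self, d_model: int, ffn_hidden: int, num_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, ffn_hidden), nn.GELU(), nn.Linear(ffn_hidden, d_model))
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = x.reshape(-1, x.shape[-1])                      # (T, d_model)
        probs = self.router(tokens).softmax(dim=-1)              # (T, num_experts)
        weights, picks = probs.topk(self.top_k, dim=-1)          # both (T, top_k)
        out = torch.zeros_like(tokens)
        for slot in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = picks[:, slot] == e
                if mask.any():
                    out[mask] += weights[mask, slot, None] * expert(tokens[mask])
        return out.reshape_as(x)


if __name__ == "__main__":
    moe = ToyTopKMoE(d_model=512, ffn_hidden=2048, num_experts=8, top_k=2)
    print(moe(torch.randn(2, 16, 512)).shape)  # torch.Size([2, 16, 512])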
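
num_params(include_inactivated_experts=False) approximates the active parameter count by slicing the leading expert dimension of "experts.mlp" weights down to moe_top_k. The same idea as a standalone sketch over a hypothetical name-to-tensor mapping:

import torch


def count_params(named_params: dict, top_k: int, active_only: bool = False) -> int:
    """Count parameters, optionally keeping only top_k experts per MoE weight.

    Assumes expert weights contain "experts.mlp" in their name and stack the
    expert dimension first, i.e. (num_experts, in_dim, out_dim), as described
    in the num_params comment above.
    """
    total = 0
    for name, p in named_params.items():
        if active_only and "experts.mlp" in name:
            p = p[:top_k]
        total += p.numel()
    return total


if __name__ == "__main__":
    params = {
        "transformer.blocks.0.ffn.experts.mlp.w1": torch.empty(8, 512, 2048),
        "transformer.blocks.0.attn_out.weight": torch.empty(512, 512),
    }
    print(count_params(params, top_k=2))                    # all 8 experts counted
    print(count_params(params, top_k=2, active_only=True))  # only the 2 routed experts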
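
get_param_groups adds a branch so megablocks expert weights (w1/w2 on MLP/SparseMLP modules) land in the weight-decay group; they are raw nn.Parameters rather than nn.Linear weights and would otherwise be missed. A simplified sketch of that kind of name- and type-based grouping (toy rules and module names, not the repo's):

import torch
import torch.nn as nn


def split_decay_groups(model: nn.Module, expert_weight_suffixes=("w1", "w2")):
    """Split parameter names into weight-decay / no-decay sets.

    Simplified sketch: Linear weights and bare expert weights (named w1/w2
    here, as in megablocks) get weight decay; biases, norms, and embeddings
    do not.
    """
    decay, no_decay = set(), set()
    for mn, m in model.named_modules():
        for pn, _ in m.named_parameters(recurse=False):
            fpn = f"{mn}.{pn}" if mn else pn
            if pn.endswith("bias") or isinstance(m, (nn.LayerNorm, nn.Embedding)):
                no_decay.add(fpn)
            elif pn.endswith("weight") and isinstance(m, nn.Linear):
                decay.add(fpn)
            elif pn.endswith(expert_weight_suffixes):
                decay.add(fpn)  # stacked expert weights are plain nn.Parameters
            else:
                no_decay.add(fpn)
    return decay, no_decay


if __name__ == "__main__":
    class ToyExperts(nn.Module):
        def __init__(self):
            super().__init__()
            self.w1 = nn.Parameter(torch.randn(8, 4, 16))
            self.w2 = nn.Parameter(torch.randn(8, 16, 4))

    model = nn.Sequential(nn.Linear(4, 4), ToyExperts())
    decay, no_decay = split_decay_groups(model)
    print(sorted(decay))     # ['0.weight', '1.w1', '1.w2']
    print(sorted(no_decay))  # ['0.bias']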
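
train_batch adds batched_load_balancing_loss(self.moe_args) / len(micro_batches) to the loss. That objective is in the family of the Switch Transformer auxiliary loss, roughly moe_loss_weight * num_experts * sum_e f_e * P_e, where f_e is the fraction of tokens routed to expert e and P_e is the mean router probability for e. A hedged single-layer sketch (megablocks accumulates its own batched variant across layers and microbatches):

import torch


def load_balancing_loss(router_probs: torch.Tensor, expert_index: torch.Tensor,
                        num_experts: int, loss_weight: float = 0.1) -> torch.Tensor:
    """Switch-Transformer-style auxiliary loss: weight * E * sum_e f_e * P_e.

    router_probs: (tokens, num_experts) softmax outputs of the router.
    expert_index: (tokens,) top-1 assignments (a top-k version averages slots).
    Illustrative only; not megablocks' internal implementation.
    """
    num_tokens = router_probs.shape[0]
    # f_e: fraction of tokens dispatched to expert e.
    f = torch.bincount(expert_index, minlength=num_experts).float() / num_tokens
    # P_e: mean router probability mass placed on expert e.
    p = router_probs.mean(dim=0)
    return loss_weight * num_experts * torch.sum(f * p)


if __name__ == "__main__":
    probs = torch.rand(1024, 8).softmax(dim=-1)
    idx = probs.argmax(dim=-1)
    print(load_balancing_loss(probs, idx, num_experts=8))  # ~0.1 for near-uniform routing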
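
Finally, train_step folds the accumulated (n_layers, moe_num_experts) token counts into per-layer, per-expert metrics. A small sketch of that reshaping using the same metric keys:

import torch


def expert_assignment_metrics(expert_assignments: torch.Tensor) -> dict:
    """Flatten a (n_layers, n_experts) token-count tensor into metric entries.

    Mirrors the TokensPercentage / TokensTotal logging added in train_step.
    """
    metrics = {}
    for layer_idx, counts in enumerate(expert_assignments):
        total = counts.sum().item()
        for expert_idx, count in enumerate(counts):
            share = (count.item() / total) * 100 if total > 0 else 0.0
            metrics[f"train/TokensPercentage/layer{layer_idx}/expert{expert_idx}"] = share
            metrics[f"train/TokensTotal/layer{layer_idx}/expert{expert_idx}"] = count.item()
    return metrics


if __name__ == "__main__":
    counts = torch.tensor([[100.0, 300.0], [250.0, 150.0]])  # 2 layers, 2 experts
    for key, value in expert_assignment_metrics(counts).items():
        print(key, value)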