From f907988a0c1693aa56025a487d8f0d933e2c23bb Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Tue, 9 Apr 2024 22:26:40 +0200 Subject: [PATCH 01/13] Preliminary Megablocks --- olmo/config.py | 2 + olmo/model.py | 160 +++++++++++++++++++++++++++++++++++++++++++++++++ olmo/train.py | 44 ++++++++++++-- 3 files changed, 202 insertions(+), 4 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index 042c704ce..bbc0000b7 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -181,6 +181,8 @@ class BlockType(StrEnum): sequential = "sequential" llama = "llama" + + moe = "moe" """ A block similar to the sequential block with slightly different implementations of operations like attention to imitate the behavior of Llama. diff --git a/olmo/model.py b/olmo/model.py index 555e0ca81..e6cd7df38 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -72,6 +72,12 @@ log = logging.getLogger(__name__) +try: + from megablocks.layers.moe import MoE + from megablocks.layers.arguments import Arguments as MoEArgs +except ImportError: + log.warning("megablocks not installed, MoE layers will not be available.") + def activation_checkpoint_function(cfg: ModelConfig): preserve_rng_state = ( @@ -626,10 +632,164 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBl return OLMoSequentialBlock(layer_id, config, cache) elif config.block_type == BlockType.llama: return OLMoLlamaBlock(layer_id, config, cache) + elif config.block_type == BlockType.moe: + return OLMoEBlock(layer_id, config, cache) else: raise NotImplementedError(f"Unknown block type: '{config.block_type}'") +class OLMoEBlock(OLMoBlock): + """ + This is a a transformer MoE block where the output is computed as ``MoE(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): + super().__init__() + self.layer_id = layer_id + self.config = config + self.hidden_size = ( + config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model + ) + self.__cache = cache + assert config.d_model % config.n_heads == 0 + + self._activation_checkpoint_fn = None + + # Dropout. + self.dropout = Dropout(config.residual_dropout) + + # Layer norms. + self.k_norm: Optional[LayerNormBase] = None + self.q_norm: Optional[LayerNormBase] = None + if config.attention_layer_norm: + assert config.effective_n_kv_heads is not None + self.k_norm = LayerNormBase.build( + config, + size=(config.d_model // config.n_heads) * config.effective_n_kv_heads, + elementwise_affine=config.attention_layer_norm_with_affine, + ) + self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine) + + # Make sure QKV clip coefficient is positive, otherwise it's not well-defined. + if config.clip_qkv is not None: + assert config.clip_qkv > 0 + + # Activation function. + self.act = Activation.build(config) + assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 + + # Attention output projection. 
+ self.attn_out = nn.Linear( + config.d_model, config.d_model, bias=config.include_bias, device=config.init_device + ) + + # MoE Block + moe_args = MoEArgs( + hidden_size=config.d_model, + ffn_hidden_size=config.d_model*4,#int(self.act.output_multiplier * self.hidden_size), + moe_num_experts=8,#config.moe_num_experts, + moe_weight_parallelism=False,#config.moe_weight_parallelism, + moe_expert_model_parallelism=False,#config.moe_expert_model_parallelism, + moe_top_k=2,#config.moe_top_k, + moe_capacity_factor=1.25,#config.moe_capacity_factor, + moe_loss_weight=0.1,#config.moe_loss_weight, + device=torch.cuda.current_device(), + # Handled by FSDP + bf16=False, + fp16=False, + ) + self.ffn = MoE(moe_args) + + # Rotary embeddings. + if self.config.rope: + self.rotary_emb = RotaryEmbedding(config, self.__cache) + + self.flash_attn_func = None + if config.flash_attention: + try: + from flash_attn import flash_attn_func # type: ignore + + self.flash_attn_func = flash_attn_func + except ModuleNotFoundError: + pass + + def reset_parameters(self): + if self.k_norm is not None: + self.k_norm.reset_parameters() + if self.q_norm is not None: + self.q_norm.reset_parameters() + init_weights( + self.config, + self.attn_out, + d=self.config.d_model, + layer_id=self.layer_id, + type_of_module=ModuleType.out_module, + ) + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + # NOTE: the standard deviation for these weights does not depend on the layer. + init_weights( + self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module + ) + init_weights( + self.config, self.ffn, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module + ) + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Get query, key, value projections. + # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + # - for group query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_kv_heads) + if self._activation_checkpoint_fn is not None: + qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)) + else: + qkv = self.att_proj(self.attn_norm(x)) + + if self.config.clip_qkv is not None: + qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + q, k, v = qkv.split(self.fused_dims, dim=-1) + + # Get attention scores. + if self._activation_checkpoint_fn is not None: + att, cache = self._activation_checkpoint_fn( # type: ignore + self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + # Add feed-forward projection. 
+ # shape: (batch_size, seq_len, d_model) + og_x = x + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore + else: + x = self.ff_norm(x) + + if self._activation_checkpoint_fn is not None: + x, _ = self._activation_checkpoint_fn(self.ffn, x) # type: ignore + else: + x, _ = self.ffn(x) + + x = self.dropout(x) + x = og_x + x + + return x, cache + + class OLMoSequentialBlock(OLMoBlock): """ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` diff --git a/olmo/train.py b/olmo/train.py index 71a45312e..3e1623bd0 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -26,6 +26,7 @@ from .aliases import PathOrStr from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer from .config import ( + BlockType, CheckpointType, SchedulerUnits, ShardedCheckpointerType, @@ -54,6 +55,12 @@ log = logging.getLogger(__name__) +try: + from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss + from megablocks.layers.arguments import Arguments as MoEArgs +except ImportError: + log.warning(f"Megablocks not installed. To train MoE, install with pip install megablocks.") + @dataclass class SpeedMonitor: @@ -186,6 +193,22 @@ def fused_loss_fn( self.loss_fn = fused_loss_fn + if self.cfg.block_type == BlockType.moe:#self.cfg.moe_freq > 0: + # these MoEArgs are necessary for logging load balancing. + self.moe_args = MoEArgs( + hidden_size=self.cfg.d_model, + ffn_hidden_size=self.cfg.d_model * 4, + moe_num_experts=8,#self.cfg.moe_num_experts, + num_layers=self.cfg.n_layers,#if params.moe_freq > 0 and layer_id % params.moe_freq == 0: + moe_expert_model_parallelism=True, + moe_top_k=2,#self.cfg.moe_top_k, + device=torch.cuda.current_device(), + moe_capacity_factor=1.25,#self.cfg.moe_capacity_factor, + moe_loss_weight=0.1,#self.cfg.moe_loss_weight, + fp16=False, + bf16=False, + ) + @property def dataset(self) -> IterableDataset: assert isinstance(self.train_loader.dataset, IterableDataset) @@ -643,6 +666,7 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor ce_batch_loss = torch.tensor(0.0, device=self.device) z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device) + lb_batch_loss = None if self.cfg.block_type != BlockType.moe else torch.tensor(0.0, device=self.device) for micro_batch in micro_batches: with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision): # Run forward pass. @@ -669,12 +693,17 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor else: loss = ce_loss + if self.cfg.block_type == BlockType.moe: + lb_batch_loss = batched_load_balancing_loss(self.moe_args) + clear_load_balancing_loss() + loss += lb_batch_loss + del logits # Run backward pass. loss.backward() - return ce_batch_loss, z_batch_loss + return ce_batch_loss, z_batch_loss, lb_batch_loss def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]: metrics: Dict[str, float] = {} @@ -691,7 +720,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> batch = move_to_device(batch, self.device) # Run forward-backward pass. - ce_batch_loss, z_batch_loss = self.train_batch(batch) + ce_batch_loss, z_batch_loss, lb_batch_loss = self.train_batch(batch) # Collect loss, potentially reducing over all ranks. 
if reduce_global_loss: @@ -700,6 +729,9 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> if z_batch_loss is not None: dist.reduce(z_batch_loss, 0) z_batch_loss.div_(get_world_size()) + if lb_batch_loss is not None: + dist.reduce(lb_batch_loss, 0) + lb_batch_loss.div_(get_world_size()) # Clip gradient norms and collect param/gradient/optim metrics. should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step() @@ -728,9 +760,11 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> # Collect metrics and check for NaN loss. # NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this. if torch.isnan(ce_batch_loss): - raise ValueError("nan loss encountered") + raise ValueError("nan ce loss encountered") if z_batch_loss is not None and torch.isnan(z_batch_loss): - raise ValueError("nan loss encountered") + raise ValueError("nan z loss encountered") + if lb_batch_loss is not None and torch.isnan(lb_batch_loss): + raise ValueError("nan lb loss encountered") for key, value in optim_metrics.items(): metrics[f"optim/{key}"] = value.item() self.cur_train_loss = ce_batch_loss.item() @@ -739,6 +773,8 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> metrics["train/Perplexity"] = math.exp(self.cur_train_loss) if z_batch_loss is not None: metrics["train/ZLoss"] = z_batch_loss.item() + if lb_batch_loss is not None: + metrics["train/LoadBalancingLoss"] = lb_batch_loss.item() # Maybe collect post-step optimizer-specific metrics. if should_log_optim_metrics_this_step: From 761d36a06a17f4daf514c37497b3646f9e8fb70a Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Tue, 16 Apr 2024 00:53:49 +0000 Subject: [PATCH 02/13] Make MoE configurable --- olmo/config.py | 20 +++++++++++ olmo/initialization.py | 7 ++-- olmo/model.py | 80 +++++++++++++++++++++++++----------------- olmo/optim.py | 9 +++++ olmo/train.py | 26 +++++++------- 5 files changed, 96 insertions(+), 46 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index bbc0000b7..3a5d5de27 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -435,6 +435,26 @@ class ModelConfig(BaseConfig): See :data:`TrainConfig.precision` instead. """ + moe_num_experts: Optional[int] = 8 + """ + The number of experts to use in the MoE block. + """ + + moe_top_k: Optional[int] = 2 + """ + The number of top experts to use in the MoE block. + """ + + moe_capacity_factor: Optional[float] = 1.25 + """ + The capacity factor to use in the MoE block. + """ + + moe_loss_weight: Optional[float] = 0.1 + """ + The weight to use for the MoE loss. 
+ """ + @property def effective_n_kv_heads(self) -> int: if self.n_kv_heads is None: diff --git a/olmo/initialization.py b/olmo/initialization.py index 260e94757..df1f947f4 100644 --- a/olmo/initialization.py +++ b/olmo/initialization.py @@ -18,8 +18,8 @@ class ModuleType(StrEnum): def init_weights( - config: ModelConfig, module: Union[nn.Linear, nn.Embedding], + config: ModelConfig, d: Optional[int] = None, layer_id: Optional[int] = None, std_factor: float = 1.0, @@ -47,7 +47,10 @@ def init_weights( std = std_factor / math.sqrt(d) if layer_id is not None: std = std / math.sqrt(2 * (layer_id + 1)) - nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std) + if hasattr(module, "weight"): + nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std) + else: + nn.init.trunc_normal_(module, mean=0.0, std=std, a=-3 * std, b=3 * std) elif config.init_fn == InitFnType.kaiming_normal: nn.init.kaiming_normal_(module.weight, nonlinearity="relu") elif config.init_fn == InitFnType.fan_in: diff --git a/olmo/model.py b/olmo/model.py index e6cd7df38..b694fd87b 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -53,6 +53,12 @@ else: raise SystemExit("This script supports Python 3.8 or higher") +try: + from megablocks.layers.moe import MoE + from megablocks.layers.arguments import Arguments as MoEArgs +except ImportError: + log.warning("megablocks not installed, MoE layers will not be available.") + __all__ = [ "LayerNormBase", "LayerNorm", @@ -72,12 +78,6 @@ log = logging.getLogger(__name__) -try: - from megablocks.layers.moe import MoE - from megablocks.layers.arguments import Arguments as MoEArgs -except ImportError: - log.warning("megablocks not installed, MoE layers will not be available.") - def activation_checkpoint_function(cfg: ModelConfig): preserve_rng_state = ( @@ -480,15 +480,15 @@ def reset_parameters(self): if self.q_norm is not None: self.q_norm.reset_parameters() init_weights( - self.config, self.attn_out, + self.config, d=self.config.d_model, layer_id=self.layer_id, type_of_module=ModuleType.out_module, ) init_weights( - self.config, self.ff_out, + self.config, d=self.ff_out.in_features, layer_id=self.layer_id, type_of_module=ModuleType.out_module, @@ -633,7 +633,7 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBl elif config.block_type == BlockType.llama: return OLMoLlamaBlock(layer_id, config, cache) elif config.block_type == BlockType.moe: - return OLMoEBlock(layer_id, config, cache) + return OLMoEBlock(layer_id, config, cache) else: raise NotImplementedError(f"Unknown block type: '{config.block_type}'") @@ -644,7 +644,7 @@ class OLMoEBlock(OLMoBlock): (plus another skip connection). 
""" def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): - super().__init__() + nn.Module.__init__(self) self.layer_id = layer_id self.config = config self.hidden_size = ( @@ -685,18 +685,23 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): # MoE Block moe_args = MoEArgs( + activation_fn=F.silu if 'glu' in config.activation_type.lower() else self.act, + mlp_type='glu' if 'glu' in config.activation_type.lower() else 'mlp', hidden_size=config.d_model, - ffn_hidden_size=config.d_model*4,#int(self.act.output_multiplier * self.hidden_size), - moe_num_experts=8,#config.moe_num_experts, - moe_weight_parallelism=False,#config.moe_weight_parallelism, - moe_expert_model_parallelism=False,#config.moe_expert_model_parallelism, - moe_top_k=2,#config.moe_top_k, - moe_capacity_factor=1.25,#config.moe_capacity_factor, - moe_loss_weight=0.1,#config.moe_loss_weight, - device=torch.cuda.current_device(), + ffn_hidden_size=int(self.act.output_multiplier * self.hidden_size), + moe_num_experts=config.moe_num_experts, + # Handled by FSDP (https://github.com/databricks/megablocks/issues/57#issuecomment-1854594483) + moe_weight_parallelism=False, + # Not tested for now + moe_expert_model_parallelism=False, + moe_top_k=config.moe_top_k, + moe_capacity_factor=config.moe_capacity_factor, + moe_loss_weight=config.moe_loss_weight, + device=config.init_device, # Handled by FSDP bf16=False, fp16=False, + init_method=partial(init_weights, config=config, d=config.d_model, layer_id=None, type_of_module=ModuleType.in_module), ) self.ffn = MoE(moe_args) @@ -713,14 +718,28 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): except ModuleNotFoundError: pass + self.attn_norm = LayerNorm.build(config) + self.ff_norm = LayerNorm.build(config) + + # Attention input projection. Projects x -> (q, k, v) + head_dim = config.d_model // config.n_heads + self.fused_dims = ( + config.d_model, + config.effective_n_kv_heads * head_dim, + config.effective_n_kv_heads * head_dim, + ) + self.att_proj = nn.Linear( + config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device + ) + def reset_parameters(self): if self.k_norm is not None: self.k_norm.reset_parameters() if self.q_norm is not None: self.q_norm.reset_parameters() init_weights( - self.config, self.attn_out, + self.config, d=self.config.d_model, layer_id=self.layer_id, type_of_module=ModuleType.out_module, @@ -729,10 +748,7 @@ def reset_parameters(self): self.ff_norm.reset_parameters() # NOTE: the standard deviation for these weights does not depend on the layer. init_weights( - self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module - ) - init_weights( - self.config, self.ffn, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module + self.att_proj, self.config, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module ) def forward( @@ -823,10 +839,10 @@ def reset_parameters(self): self.ff_norm.reset_parameters() # NOTE: the standard deviation for these weights does not depend on the layer. 
init_weights( - self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module + self.att_proj, self.config, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module ) init_weights( - self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module + self.ff_proj, self.config, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module ) def forward( @@ -928,10 +944,10 @@ def reset_parameters(self): self.attn_norm.reset_parameters() self.ff_norm.reset_parameters() # NOTE: the standard deviation for these weights does not depend on the layer. - init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None) - init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None) - init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None) - init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None) + init_weights(self.q_proj, self.config, d=self.config.d_model, layer_id=None) + init_weights(self.k_proj, self.config, d=self.config.d_model, layer_id=None) + init_weights(self.v_proj, self.config, d=self.config.d_model, layer_id=None) + init_weights(self.ff_proj, self.config, d=self.config.d_model, layer_id=None) def _scaled_dot_product_attention( self, @@ -1181,20 +1197,20 @@ def reset_parameters(self): log.info("Initializing model parameters...") # Top-level embeddings / linear layers. init_weights( - self.config, self.transformer.wte, # type: ignore + self.config, std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0, type_of_module=ModuleType.emb, ) if hasattr(self.transformer, "wpe"): - init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore + init_weights(self.transformer.wpe, self.config, type_of_module=ModuleType.emb) # type: ignore # Top-level layer norm. self.transformer.ln_f.reset_parameters() # type: ignore # Output weights. if hasattr(self.transformer, "ff_out"): - init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore + init_weights(self.transformer.ff_out, self.config, type_of_module=ModuleType.final_out) # type: ignore # Let the blocks handle themselves. if self.config.block_group_size == 1: diff --git a/olmo/optim.py b/olmo/optim.py index 2c2d988df..d8eac62d0 100644 --- a/olmo/optim.py +++ b/olmo/optim.py @@ -14,6 +14,12 @@ from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig from .torch_util import get_default_device, is_distributed +try: + from megablocks.layers.mlp import MLP, SparseMLP + megablocks_available = True +except ImportError: + megablocks_available = False + __all__ = [ "Optimizer", "LionW", @@ -588,6 +594,7 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] """ Separate parameters into weight decay and non weight decay groups. 
""" + from megablocks.layers.mlp import MLP param_groups: List[Dict[str, Any]] param_group_defaults = { "sharded": isinstance(model, FullyShardedDataParallel), @@ -627,6 +634,8 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] decay.add(fpn) else: no_decay.add(fpn) + elif megablocks_available and pn.endswith(("w1", "w2")) and (isinstance(m, MLP) or isinstance(m, SparseMLP)): + decay.add(fpn) # Validate that we've considered every parameter inter_params = decay & no_decay diff --git a/olmo/train.py b/olmo/train.py index 3e1623bd0..60376eca8 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -193,18 +193,20 @@ def fused_loss_fn( self.loss_fn = fused_loss_fn - if self.cfg.block_type == BlockType.moe:#self.cfg.moe_freq > 0: + print(self.cfg) + if self.model.config.block_type == BlockType.moe: # these MoEArgs are necessary for logging load balancing. self.moe_args = MoEArgs( - hidden_size=self.cfg.d_model, - ffn_hidden_size=self.cfg.d_model * 4, - moe_num_experts=8,#self.cfg.moe_num_experts, - num_layers=self.cfg.n_layers,#if params.moe_freq > 0 and layer_id % params.moe_freq == 0: - moe_expert_model_parallelism=True, - moe_top_k=2,#self.cfg.moe_top_k, - device=torch.cuda.current_device(), - moe_capacity_factor=1.25,#self.cfg.moe_capacity_factor, - moe_loss_weight=0.1,#self.cfg.moe_loss_weight, + hidden_size=self.model.config.d_model, + ffn_hidden_size=self.model.config.d_model * 4, + moe_num_experts=self.model.config.moe_num_experts, + num_layers=self.model.config.n_layers, + # Not tested for nowe + moe_expert_model_parallelism=False, + moe_top_k=self.model.config.moe_top_k, + device=self.model.config.init_device, + moe_capacity_factor=self.model.config.moe_capacity_factor, + moe_loss_weight=self.model.config.moe_loss_weight, fp16=False, bf16=False, ) @@ -666,7 +668,7 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor ce_batch_loss = torch.tensor(0.0, device=self.device) z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device) - lb_batch_loss = None if self.cfg.block_type != BlockType.moe else torch.tensor(0.0, device=self.device) + lb_batch_loss = None if self.model.config.block_type != BlockType.moe else torch.tensor(0.0, device=self.device) for micro_batch in micro_batches: with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision): # Run forward pass. @@ -693,7 +695,7 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor else: loss = ce_loss - if self.cfg.block_type == BlockType.moe: + if self.model.config.block_type == BlockType.moe: lb_batch_loss = batched_load_balancing_loss(self.moe_args) clear_load_balancing_loss() loss += lb_batch_loss From f898fd9a02be65521aaf31d9dade49c2d5a3907a Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Tue, 16 Apr 2024 14:04:53 +0000 Subject: [PATCH 03/13] Log active params --- olmo/model.py | 10 +++++++++- olmo/train.py | 1 - scripts/train.py | 2 ++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index b694fd87b..ba17252b0 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -1505,7 +1505,7 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): else: raise NotImplementedError(wrap_strategy) - def num_params(self, include_embedding: bool = True) -> int: + def num_params(self, include_embedding: bool = True, include_inactivated_experts: bool = True) -> int: """ Get the total number of parameters. 
""" @@ -1515,6 +1515,14 @@ def num_params(self, include_embedding: bool = True) -> int: lambda np: ".wte." not in np[0] and ".wpe." not in np[0], params, ) + if not include_inactivated_experts: + # Need to reduce blocks the number of experts that are selected + # e.g. 'transformer.blocks.0.ffn.experts.mlp.w1' has shape (total_experts, in_dim, out_dim) + # change to 'transformer.blocks.0.ffn.experts.mlp.w1' has shape (selected_experts, in_dim, out_dim) + params = [ + (np[0], np[1][: self.config.moe_top_k] if "experts.mlp" in np[0] else np[1]) + for np in params + ] return sum(p.numel() for _, p in params) @property diff --git a/olmo/train.py b/olmo/train.py index 60376eca8..6ab5377dc 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -193,7 +193,6 @@ def fused_loss_fn( self.loss_fn = fused_loss_fn - print(self.cfg) if self.model.config.block_type == BlockType.moe: # these MoEArgs are necessary for logging load balancing. self.moe_args = MoEArgs( diff --git a/scripts/train.py b/scripts/train.py index f93734c0b..67444fc31 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -117,6 +117,8 @@ def main(cfg: TrainConfig) -> None: olmo_model = OLMo(cfg.model) log.info(f"Total number of parameters: {olmo_model.num_params():,d}") log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}") + if olmo_model.config.block_type == "moe": + log.info(f"Number of active parameters: {olmo_model.num_params(include_inactivated_experts=False):,d}") log.info(f"Peak GPU Memory (MB) before FSDP: {int(peak_gpu_memory() or 0)}") olmo_model.set_activation_checkpointing(cfg.activation_checkpointing) From 3b30044ff0eea25884dea15f19d7e251338d22cb Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Tue, 16 Apr 2024 17:20:35 +0000 Subject: [PATCH 04/13] Confs --- configs/OLMoE-200m-80m.yml | 145 ++++++++++++++++++++++++++++++++++++ configs/OLMoE-200m.yml | 107 ++++++++++++++++++++++++++ configs/OLMoE-600m-200m.yml | 145 ++++++++++++++++++++++++++++++++++++ configs/OLMoE-test.yaml | 100 +++++++++++++++++++++++++ olmo/initialization.py | 43 ++++++++--- 5 files changed, 529 insertions(+), 11 deletions(-) create mode 100644 configs/OLMoE-200m-80m.yml create mode 100644 configs/OLMoE-200m.yml create mode 100644 configs/OLMoE-600m-200m.yml create mode 100644 configs/OLMoE-test.yaml diff --git a/configs/OLMoE-200m-80m.yml b/configs/OLMoE-200m-80m.yml new file mode 100644 index 000000000..1af6faa5c --- /dev/null +++ b/configs/OLMoE-200m-80m.yml @@ -0,0 +1,145 @@ +run_name: OLMoE +seed: 6198 +dry_run: false + +wandb: + name: ${run_name} + project: olmoe + group: null + +model: + d_model: 512 + n_heads: 8 + n_layers: 10 + mlp_ratio: 4 # 4 vs 8 (for swiglu) + weight_tying: true + alibi: false + rope: true + flash_attention: true + attention_dropout: 0.0 + attention_layer_norm: false + multi_query_attention: false + include_bias: false + block_type: moe + layer_norm_type: default + layer_norm_with_affine: false + bias_for_layer_norm: false + attention_layer_norm_with_affine: false + activation_type: gelu # gelu vs swiglu + residual_dropout: 0.0 + embedding_dropout: 0.0 + max_sequence_length: 2048 + vocab_size: 50280 + embedding_size: 50304 + eos_token_id: 50279 + pad_token_id: 1 + init_device: meta + init_fn: normal # mitchell vs normal + +compile: null # causes instability on AMD GPUs + +optimizer: + name: adamw + learning_rate: 4.0e-4 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + metrics_log_interval: 10 + +scheduler: + name: cosine_with_warmup + t_warmup: 2000 + alpha_f: 0.1 + +tokenizer: + 
identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json + truncate_direction: right + +save_folder: /data/niklas/olmoe +save_overwrite: false +# Sharded checkpoints (best for restarts) +save_interval: 1000 +save_num_checkpoints_to_keep: 9 +# Unsharded checkpoints (for final storage) +save_interval_unsharded: 10000 +save_num_unsharded_checkpoints_to_keep: -1 + +load_path: null + +max_duration: 10e9T # 10B tokens +global_train_batch_size: 2048 +device_train_microbatch_size: 8 + +precision: amp_bf16 + +fsdp: + wrapping_strategy: null + precision: mixed + +max_grad_norm: 1.0 +max_grad_norm_ratio: null + +speed_monitor: + window_size: 20 + +eval_interval: ${save_interval} +eval_subset_num_batches: -1 +device_eval_batch_size: ${device_train_microbatch_size} +evaluators: + # lump all the small datasets together (we still get separate metrics). + - label: v3-small-ppl-validation + data: + num_workers: 0 + drop_last: true + datasets: + v3-small-c4_en-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/c4_en/val/part-0-00000.npy + v3-small-dolma_books-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_books/val/part-0-00000.npy + v3-small-dolma_common-crawl-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_common-crawl/val/part-0-00000.npy + v3-small-dolma_pes2o-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_pes2o/val/part-0-00000.npy + v3-small-dolma_reddit-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_reddit/val/part-0-00000.npy + v3-small-dolma_stack-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_stack/val/part-0-00000.npy + v3-small-dolma_wiki-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_wiki/val/part-0-00000.npy + v3-small-ice-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/ice/val/part-0-00000.npy + v3-small-m2d2_s2orc-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/m2d2_s2orc/val/part-0-00000.npy + v3-small-pile-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/pile/val/part-0-00000.npy + v3-small-wikitext_103-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/wikitext_103/val/part-0-00000.npy + +data: + pad_direction: right + num_workers: 0 + drop_last: true + pin_memory: true + prefetch_factor: 16 + persistent_workers: true + timeout: 0 + paths: + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-001-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-002-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-003-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-006-00000.npy + - 
https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-006-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-007-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-008-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-008-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-009-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-009-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-010-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-010-00001.npy diff --git a/configs/OLMoE-200m.yml b/configs/OLMoE-200m.yml new file mode 100644 index 000000000..911791f79 --- /dev/null +++ b/configs/OLMoE-200m.yml @@ -0,0 +1,107 @@ +run_name: OLMoE +seed: 6198 +dry_run: false + +wandb: + name: ${run_name} + project: olmoe + group: null + +model: + d_model: 896 + n_heads: 14 + n_layers: 16 + mlp_ratio: 4 # 4 vs 8 (for swiglu) + weight_tying: true + alibi: false + rope: true + flash_attention: true + attention_dropout: 0.0 + attention_layer_norm: false + multi_query_attention: false + include_bias: false + block_type: sequential + layer_norm_type: default + layer_norm_with_affine: false + bias_for_layer_norm: false + attention_layer_norm_with_affine: false + activation_type: gelu # gelu vs swiglu + residual_dropout: 0.0 + embedding_dropout: 0.0 + max_sequence_length: 2048 + vocab_size: 50280 + embedding_size: 50304 + eos_token_id: 50279 + pad_token_id: 1 + init_device: meta + init_fn: normal # mitchell vs normal + +compile: null # causes instability on AMD GPUs + +optimizer: + name: adamw + learning_rate: 4.0e-4 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + metrics_log_interval: 10 + +scheduler: + name: cosine_with_warmup + t_warmup: 2000 + alpha_f: 0.1 + +tokenizer: + identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json + truncate_direction: right + +save_folder: /data/niklas/olmoe +save_overwrite: false +# Sharded checkpoints (best for restarts) +save_interval: 1000 +save_num_checkpoints_to_keep: 9 +# Unsharded checkpoints (for final storage) +save_interval_unsharded: 10000 +save_num_unsharded_checkpoints_to_keep: -1 + +load_path: null + +max_duration: 10e9T # 10B tokens +global_train_batch_size: 2048 +device_train_microbatch_size: 8 + +precision: amp_bf16 + +fsdp: + wrapping_strategy: null + precision: mixed + +max_grad_norm: 1.0 +max_grad_norm_ratio: null + +speed_monitor: + window_size: 20 + +eval_interval: ${save_interval} +eval_subset_num_batches: -1 +device_eval_batch_size: ${device_train_microbatch_size} + +data: + pad_direction: right + num_workers: 0 + drop_last: true + pin_memory: true + prefetch_factor: 16 + persistent_workers: true + timeout: 0 + paths: + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-001-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-002-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-003-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00000.npy + - 
https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00001.npy \ No newline at end of file diff --git a/configs/OLMoE-600m-200m.yml b/configs/OLMoE-600m-200m.yml new file mode 100644 index 000000000..f5c476be9 --- /dev/null +++ b/configs/OLMoE-600m-200m.yml @@ -0,0 +1,145 @@ +run_name: OLMoE +seed: 6198 +dry_run: false + +wandb: + name: ${run_name} + project: olmoe + group: null + +model: + d_model: 768 + n_heads: 12 + n_layers: 14 + mlp_ratio: 4 # 4 vs 8 (for swiglu) + weight_tying: true + alibi: false + rope: true + flash_attention: true + attention_dropout: 0.0 + attention_layer_norm: false + multi_query_attention: false + include_bias: false + block_type: moe + layer_norm_type: default + layer_norm_with_affine: false + bias_for_layer_norm: false + attention_layer_norm_with_affine: false + activation_type: gelu # gelu vs swiglu + residual_dropout: 0.0 + embedding_dropout: 0.0 + max_sequence_length: 2048 + vocab_size: 50280 + embedding_size: 50304 + eos_token_id: 50279 + pad_token_id: 1 + init_device: meta + init_fn: normal # mitchell vs normal + +compile: null # causes instability on AMD GPUs + +optimizer: + name: adamw + learning_rate: 4.0e-4 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + metrics_log_interval: 10 + +scheduler: + name: cosine_with_warmup + t_warmup: 2000 + alpha_f: 0.1 + +tokenizer: + identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json + truncate_direction: right + +save_folder: /data/niklas/olmoe +save_overwrite: false +# Sharded checkpoints (best for restarts) +save_interval: 1000 +save_num_checkpoints_to_keep: 9 +# Unsharded checkpoints (for final storage) +save_interval_unsharded: 10000 +save_num_unsharded_checkpoints_to_keep: -1 + +load_path: null + +max_duration: 10e9T # 10B tokens +global_train_batch_size: 2048 +device_train_microbatch_size: 8 + +precision: amp_bf16 + +fsdp: + wrapping_strategy: null + precision: mixed + +max_grad_norm: 1.0 +max_grad_norm_ratio: null + +speed_monitor: + window_size: 20 + +eval_interval: ${save_interval} +eval_subset_num_batches: -1 +device_eval_batch_size: ${device_train_microbatch_size} +evaluators: + # lump all the small datasets together (we still get separate metrics). 
+ - label: v3-small-ppl-validation + data: + num_workers: 0 + drop_last: true + datasets: + v3-small-c4_en-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/c4_en/val/part-0-00000.npy + v3-small-dolma_books-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_books/val/part-0-00000.npy + v3-small-dolma_common-crawl-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_common-crawl/val/part-0-00000.npy + v3-small-dolma_pes2o-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_pes2o/val/part-0-00000.npy + v3-small-dolma_reddit-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_reddit/val/part-0-00000.npy + v3-small-dolma_stack-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_stack/val/part-0-00000.npy + v3-small-dolma_wiki-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_wiki/val/part-0-00000.npy + v3-small-ice-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/ice/val/part-0-00000.npy + v3-small-m2d2_s2orc-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/m2d2_s2orc/val/part-0-00000.npy + v3-small-pile-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/pile/val/part-0-00000.npy + v3-small-wikitext_103-validation: + - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/wikitext_103/val/part-0-00000.npy + +data: + pad_direction: right + num_workers: 0 + drop_last: true + pin_memory: true + prefetch_factor: 16 + persistent_workers: true + timeout: 0 + paths: + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-001-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-002-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-003-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-006-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-006-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-007-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-008-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-008-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-009-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-009-00001.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-010-00000.npy + - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-010-00001.npy diff --git a/configs/OLMoE-test.yaml b/configs/OLMoE-test.yaml new file mode 100644 index 
000000000..e79f36a32 --- /dev/null +++ b/configs/OLMoE-test.yaml @@ -0,0 +1,100 @@ +run_name: OLMoE-1B +seed: 6198 +dry_run: false + +wandb: + name: ${run_name} + project: olmoe + group: null + +model: + d_model: 640 + n_heads: 10 + n_layers: 16 + mlp_ratio: 4 # TODO: 4 or 8 + weight_tying: true + alibi: false + rope: true + flash_attention: true + attention_dropout: 0.0 + attention_layer_norm: false + multi_query_attention: false + include_bias: false + block_type: moe + layer_norm_type: default + layer_norm_with_affine: false + bias_for_layer_norm: false + attention_layer_norm_with_affine: false + activation_type: gelu # TODO: Experiment with gelu vs swiglu + residual_dropout: 0.0 + embedding_dropout: 0.0 + max_sequence_length: 2048 + vocab_size: 50280 + embedding_size: 50304 + eos_token_id: 50279 + pad_token_id: 1 + init_device: meta + init_fn: mitchell # TODO: Experiment with regular normal + +compile: null # causes instability on AMD GPUs + +optimizer: + name: adamw + learning_rate: 4.0e-4 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + metrics_log_interval: 10 + +scheduler: + name: cosine_with_warmup + t_warmup: 2000 + alpha_f: 0.1 + +tokenizer: + identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json + truncate_direction: right + +save_folder: /data/niklas/olmoe +save_overwrite: false +# Sharded checkpoints (best for restarts) +save_interval: 1000 +save_num_checkpoints_to_keep: 9 +# Unsharded checkpoints (for final storage) +save_interval_unsharded: 10000 +save_num_unsharded_checkpoints_to_keep: -1 + +load_path: null + +max_duration: 739_328 # 3.1T tokens +global_train_batch_size: 2048 +device_train_microbatch_size: 8 + +precision: amp_bf16 + +fsdp: + wrapping_strategy: null + precision: mixed + +max_grad_norm: 1.0 +max_grad_norm_ratio: null + +speed_monitor: + window_size: 20 + +eval_interval: ${save_interval} +eval_subset_num_batches: -1 +device_eval_batch_size: ${device_train_microbatch_size} + +data: + pad_direction: right + num_workers: 0 + drop_last: true + pin_memory: true + prefetch_factor: 16 + persistent_workers: true + timeout: 0 + paths: + - /data/niklas/llm/data/part-000-00000.npy + - /data/niklas/llm/data/part-000-00001.npy \ No newline at end of file diff --git a/olmo/initialization.py b/olmo/initialization.py index df1f947f4..775d867d2 100644 --- a/olmo/initialization.py +++ b/olmo/initialization.py @@ -40,9 +40,15 @@ def init_weights( std = config.init_std * std_factor if config.init_cutoff_factor is not None: cutoff_value = config.init_cutoff_factor * std - nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value) + if hasattr(module, "weight"): + nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value) + else: + nn.init.trunc_normal_(module, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value) else: - nn.init.normal_(module.weight, mean=0.0, std=std) + if hasattr(module, "weight"): + nn.init.normal_(module.weight, mean=0.0, std=std) + else: + nn.init.normal_(module, mean=0.0, std=std) elif config.init_fn == InitFnType.mitchell: std = std_factor / math.sqrt(d) if layer_id is not None: @@ -52,10 +58,16 @@ def init_weights( else: nn.init.trunc_normal_(module, mean=0.0, std=std, a=-3 * std, b=3 * std) elif config.init_fn == InitFnType.kaiming_normal: - nn.init.kaiming_normal_(module.weight, nonlinearity="relu") + if hasattr(module, "weight"): + nn.init.kaiming_normal_(module.weight, nonlinearity="relu") + else: + nn.init.kaiming_normal_(module, nonlinearity="relu") elif config.init_fn == 
InitFnType.fan_in: std = std_factor / math.sqrt(d) - nn.init.normal_(module.weight, mean=0.0, std=std) + if hasattr(module, "weight"): + nn.init.normal_(module.weight, mean=0.0, std=std) + else: + nn.init.normal_(module, mean=0.0, std=std) elif config.init_fn == InitFnType.full_megatron: if type_of_module is None: raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.") @@ -79,13 +91,22 @@ def init_weights( std = config.d_model**-0.5 else: raise RuntimeError(f"Unknown module type '{type_of_module}'") - nn.init.trunc_normal_( - module.weight, - mean=0.0, - std=std, - a=-cutoff_factor * std, - b=cutoff_factor * std, - ) + if hasattr(module, "weight"): + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + else: + nn.init.trunc_normal_( + module, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) else: raise NotImplementedError(config.init_fn) From fbe843724b2fc5d95ed70195c2ca7cc87c8f997c Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Thu, 18 Apr 2024 15:52:47 +0000 Subject: [PATCH 05/13] Fix LB loss --- olmo/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/olmo/train.py b/olmo/train.py index 6ab5377dc..1afe53750 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -152,6 +152,8 @@ class Trainer: def __post_init__(self): if self.cfg.fused_loss: + if self.model.config.block_type == BlockType.moe: + raise NotImplementedError("Fused loss is not implemented for MoE models.") from flash_attn.ops.triton.cross_entropy import ( # type: ignore cross_entropy_loss, ) @@ -200,7 +202,7 @@ def fused_loss_fn( ffn_hidden_size=self.model.config.d_model * 4, moe_num_experts=self.model.config.moe_num_experts, num_layers=self.model.config.n_layers, - # Not tested for nowe + # Not tested for now moe_expert_model_parallelism=False, moe_top_k=self.model.config.moe_top_k, device=self.model.config.init_device, @@ -695,9 +697,10 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor loss = ce_loss if self.model.config.block_type == BlockType.moe: - lb_batch_loss = batched_load_balancing_loss(self.moe_args) + lb_loss = batched_load_balancing_loss(self.moe_args) / len(micro_batches) clear_load_balancing_loss() - loss += lb_batch_loss + loss += lb_loss + lb_batch_loss += lb_loss.detach() del logits From f44570a546dd1a343fd11d5b84c39c9654405634 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Thu, 18 Apr 2024 17:34:08 +0000 Subject: [PATCH 06/13] Adapt confs --- configs/OLMoE-200m-80m.yml | 56 ++++++------------------------------- configs/OLMoE-200m.yml | 18 ++++++------ configs/OLMoE-600m-200m.yml | 56 ++++++------------------------------- olmo/model.py | 1 + 4 files changed, 28 insertions(+), 103 deletions(-) diff --git a/configs/OLMoE-200m-80m.yml b/configs/OLMoE-200m-80m.yml index 1af6faa5c..49778ff58 100644 --- a/configs/OLMoE-200m-80m.yml +++ b/configs/OLMoE-200m-80m.yml @@ -86,35 +86,6 @@ speed_monitor: eval_interval: ${save_interval} eval_subset_num_batches: -1 device_eval_batch_size: ${device_train_microbatch_size} -evaluators: - # lump all the small datasets together (we still get separate metrics). 
- - label: v3-small-ppl-validation - data: - num_workers: 0 - drop_last: true - datasets: - v3-small-c4_en-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/c4_en/val/part-0-00000.npy - v3-small-dolma_books-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_books/val/part-0-00000.npy - v3-small-dolma_common-crawl-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_common-crawl/val/part-0-00000.npy - v3-small-dolma_pes2o-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_pes2o/val/part-0-00000.npy - v3-small-dolma_reddit-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_reddit/val/part-0-00000.npy - v3-small-dolma_stack-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_stack/val/part-0-00000.npy - v3-small-dolma_wiki-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_wiki/val/part-0-00000.npy - v3-small-ice-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/ice/val/part-0-00000.npy - v3-small-m2d2_s2orc-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/m2d2_s2orc/val/part-0-00000.npy - v3-small-pile-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/pile/val/part-0-00000.npy - v3-small-wikitext_103-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/wikitext_103/val/part-0-00000.npy data: pad_direction: right @@ -125,21 +96,12 @@ data: persistent_workers: true timeout: 0 paths: - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-001-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-002-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-003-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-006-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-006-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-007-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-008-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-008-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-009-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-009-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-010-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-010-00001.npy + - /data/niklas/llm/data/part-000-00000.npy + - /data/niklas/llm/data/part-000-00001.npy + - /data/niklas/llm/data/part-001-00000.npy + - 
/data/niklas/llm/data/part-002-00000.npy + - /data/niklas/llm/data/part-003-00000.npy + - /data/niklas/llm/data/part-004-00000.npy + - /data/niklas/llm/data/part-004-00001.npy + - /data/niklas/llm/data/part-005-00000.npy + - /data/niklas/llm/data/part-005-00001.npy \ No newline at end of file diff --git a/configs/OLMoE-200m.yml b/configs/OLMoE-200m.yml index 911791f79..a7d79ac98 100644 --- a/configs/OLMoE-200m.yml +++ b/configs/OLMoE-200m.yml @@ -96,12 +96,12 @@ data: persistent_workers: true timeout: 0 paths: - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-001-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-002-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-003-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00001.npy \ No newline at end of file + - /data/niklas/llm/data/part-000-00000.npy + - /data/niklas/llm/data/part-000-00001.npy + - /data/niklas/llm/data/part-001-00000.npy + - /data/niklas/llm/data/part-002-00000.npy + - /data/niklas/llm/data/part-003-00000.npy + - /data/niklas/llm/data/part-004-00000.npy + - /data/niklas/llm/data/part-004-00001.npy + - /data/niklas/llm/data/part-005-00000.npy + - /data/niklas/llm/data/part-005-00001.npy \ No newline at end of file diff --git a/configs/OLMoE-600m-200m.yml b/configs/OLMoE-600m-200m.yml index f5c476be9..bc25092db 100644 --- a/configs/OLMoE-600m-200m.yml +++ b/configs/OLMoE-600m-200m.yml @@ -86,35 +86,6 @@ speed_monitor: eval_interval: ${save_interval} eval_subset_num_batches: -1 device_eval_batch_size: ${device_train_microbatch_size} -evaluators: - # lump all the small datasets together (we still get separate metrics). 
- - label: v3-small-ppl-validation - data: - num_workers: 0 - drop_last: true - datasets: - v3-small-c4_en-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/c4_en/val/part-0-00000.npy - v3-small-dolma_books-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_books/val/part-0-00000.npy - v3-small-dolma_common-crawl-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_common-crawl/val/part-0-00000.npy - v3-small-dolma_pes2o-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_pes2o/val/part-0-00000.npy - v3-small-dolma_reddit-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_reddit/val/part-0-00000.npy - v3-small-dolma_stack-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_stack/val/part-0-00000.npy - v3-small-dolma_wiki-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_wiki/val/part-0-00000.npy - v3-small-ice-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/ice/val/part-0-00000.npy - v3-small-m2d2_s2orc-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/m2d2_s2orc/val/part-0-00000.npy - v3-small-pile-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/pile/val/part-0-00000.npy - v3-small-wikitext_103-validation: - - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/wikitext_103/val/part-0-00000.npy data: pad_direction: right @@ -125,21 +96,12 @@ data: persistent_workers: true timeout: 0 paths: - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-001-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-002-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-003-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-006-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-006-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-007-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-008-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-008-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-009-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-009-00001.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-010-00000.npy - - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-010-00001.npy + - /data/niklas/llm/data/part-000-00000.npy + - /data/niklas/llm/data/part-000-00001.npy + - /data/niklas/llm/data/part-001-00000.npy + - 
/data/niklas/llm/data/part-002-00000.npy + - /data/niklas/llm/data/part-003-00000.npy + - /data/niklas/llm/data/part-004-00000.npy + - /data/niklas/llm/data/part-004-00001.npy + - /data/niklas/llm/data/part-005-00000.npy + - /data/niklas/llm/data/part-005-00001.npy diff --git a/olmo/model.py b/olmo/model.py index ba17252b0..08af8f78f 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -687,6 +687,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): moe_args = MoEArgs( activation_fn=F.silu if 'glu' in config.activation_type.lower() else self.act, mlp_type='glu' if 'glu' in config.activation_type.lower() else 'mlp', + # mlp_impl='grouped', # 4x slower on H100s hidden_size=config.d_model, ffn_hidden_size=int(self.act.output_multiplier * self.hidden_size), moe_num_experts=config.moe_num_experts, From 62fc4e54be5d7be400b61ce43edcfc7f9d026440 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sun, 21 Apr 2024 21:13:32 +0000 Subject: [PATCH 07/13] Log assignments & init MoE --- olmo/model.py | 35 ++++++++++++++++++++++++++++++----- olmo/train.py | 16 +++++++++++++--- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index 08af8f78f..eeb1c3df6 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -684,7 +684,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): ) # MoE Block - moe_args = MoEArgs( + self.moe_args = MoEArgs( activation_fn=F.silu if 'glu' in config.activation_type.lower() else self.act, mlp_type='glu' if 'glu' in config.activation_type.lower() else 'mlp', # mlp_impl='grouped', # 4x slower on H100s @@ -702,9 +702,8 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): # Handled by FSDP bf16=False, fp16=False, - init_method=partial(init_weights, config=config, d=config.d_model, layer_id=None, type_of_module=ModuleType.in_module), ) - self.ffn = MoE(moe_args) + self.ffn = MoE(self.moe_args) # Rotary embeddings. if self.config.rope: @@ -738,6 +737,15 @@ def reset_parameters(self): self.k_norm.reset_parameters() if self.q_norm is not None: self.q_norm.reset_parameters() + + # NOTE: the standard deviation for these weights does not depend on the layer. + init_weights( + self.att_proj, + self.config, + d=self.config.d_model, + layer_id=None, + type_of_module=ModuleType.in_module + ) init_weights( self.attn_out, self.config, @@ -747,9 +755,26 @@ def reset_parameters(self): ) self.attn_norm.reset_parameters() self.ff_norm.reset_parameters() - # NOTE: the standard deviation for these weights does not depend on the layer. 
init_weights( - self.att_proj, self.config, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module + self.ffn.experts.mlp.w1, + self.config, + d=self.config.d_model, + layer_id=None, + type_of_module=ModuleType.in_module, + ) + init_weights( + self.ffn.experts.mlp.w2, + self.config, + d=int(self.act.output_multiplier * self.hidden_size), + layer_id=self.layer_id, + type_of_module=ModuleType.out_module, + ) + init_weights( + self.ffn.router.layer, + self.config, + d=self.config.d_model, + layer_id=None, + type_of_module=ModuleType.out_module, ) def forward( diff --git a/olmo/train.py b/olmo/train.py index 1afe53750..55b4342c7 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -56,7 +56,7 @@ log = logging.getLogger(__name__) try: - from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss + from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss, get_load_balancing_loss from megablocks.layers.arguments import Arguments as MoEArgs except ImportError: log.warning(f"Megablocks not installed. To train MoE, install with pip install megablocks.") @@ -670,6 +670,8 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor ce_batch_loss = torch.tensor(0.0, device=self.device) z_batch_loss = None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device) lb_batch_loss = None if self.model.config.block_type != BlockType.moe else torch.tensor(0.0, device=self.device) + # Keep this one on CPU to save memory + expert_assignments = None if self.model.config.block_type != BlockType.moe else torch.zeros((self.model.config.n_layers, self.model.config.moe_num_experts)) for micro_batch in micro_batches: with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision): # Run forward pass. @@ -698,6 +700,10 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor if self.model.config.block_type == BlockType.moe: lb_loss = batched_load_balancing_loss(self.moe_args) / len(micro_batches) + + tokens_per_expert, _ = zip(*get_load_balancing_loss()) + expert_assignments += torch.stack(tokens_per_expert, dim=0).cpu() + clear_load_balancing_loss() loss += lb_loss lb_batch_loss += lb_loss.detach() @@ -707,7 +713,7 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor # Run backward pass. loss.backward() - return ce_batch_loss, z_batch_loss, lb_batch_loss + return ce_batch_loss, z_batch_loss, lb_batch_loss, expert_assignments def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]: metrics: Dict[str, float] = {} @@ -724,7 +730,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> batch = move_to_device(batch, self.device) # Run forward-backward pass. - ce_batch_loss, z_batch_loss, lb_batch_loss = self.train_batch(batch) + ce_batch_loss, z_batch_loss, lb_batch_loss, expert_assignments = self.train_batch(batch) # Collect loss, potentially reducing over all ranks. if reduce_global_loss: @@ -779,6 +785,10 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> metrics["train/ZLoss"] = z_batch_loss.item() if lb_batch_loss is not None: metrics["train/LoadBalancingLoss"] = lb_batch_loss.item() + # Log assignment metrics. 
+ for layer_idx, expert_assignments_layer in enumerate(expert_assignments): + for expert_idx, expert_assignment in enumerate(expert_assignments_layer): + metrics[f"train/LoadBalancing/layer{layer_idx}/expert{expert_idx}"] = expert_assignment.item() # Maybe collect post-step optimizer-specific metrics. if should_log_optim_metrics_this_step: From e594d6240b4d8a4883ae77ac88a671192d389cb1 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sun, 21 Apr 2024 23:19:02 +0000 Subject: [PATCH 08/13] Add % --- olmo/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/olmo/train.py b/olmo/train.py index 55b4342c7..0c7d509f6 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -787,8 +787,10 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> metrics["train/LoadBalancingLoss"] = lb_batch_loss.item() # Log assignment metrics. for layer_idx, expert_assignments_layer in enumerate(expert_assignments): + total_tokens = expert_assignments_layer.sum().item() for expert_idx, expert_assignment in enumerate(expert_assignments_layer): - metrics[f"train/LoadBalancing/layer{layer_idx}/expert{expert_idx}"] = expert_assignment.item() + metrics[f"train/TokensPercentage/layer{layer_idx}/expert{expert_idx}"] = (expert_assignment.item() / total_tokens) * 100 + metrics[f"train/TokensTotal/layer{layer_idx}/expert{expert_idx}"] = expert_assignment.item() # Maybe collect post-step optimizer-specific metrics. if should_log_optim_metrics_this_step: From 34a389bf5ba07322c63236f841fba53f7c46f5f1 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Mon, 22 Apr 2024 02:24:31 +0000 Subject: [PATCH 09/13] Configure bias --- olmo/model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index eeb1c3df6..9d4e6ef92 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -702,6 +702,8 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): # Handled by FSDP bf16=False, fp16=False, + bias=self.config.include_bias, + return_bias=False, ) self.ffn = MoE(self.moe_args) @@ -769,6 +771,8 @@ def reset_parameters(self): layer_id=self.layer_id, type_of_module=ModuleType.out_module, ) + if self.ffn.experts.bias is not None: + torch.nn.init.zeros_(self.ffn.experts.bias) init_weights( self.ffn.router.layer, self.config, @@ -822,9 +826,9 @@ def forward( x = self.ff_norm(x) if self._activation_checkpoint_fn is not None: - x, _ = self._activation_checkpoint_fn(self.ffn, x) # type: ignore + x = self._activation_checkpoint_fn(self.ffn, x) # type: ignore else: - x, _ = self.ffn(x) + x = self.ffn(x) x = self.dropout(x) x = og_x + x From 6afe3003d95f02dc9d2bb1ff737ea63783a804a9 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Tue, 23 Apr 2024 00:01:10 +0000 Subject: [PATCH 10/13] Make EP configurable --- olmo/config.py | 5 +++++ olmo/model.py | 6 +++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index bea36c435..0ff754d51 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -455,6 +455,11 @@ class ModelConfig(BaseConfig): The weight to use for the MoE loss. """ + moe_expert_model_parallelism: Optional[bool] = False + """ + Whether to use model parallelism for the MoE experts. 
+ """ + @property def effective_n_kv_heads(self) -> int: if self.n_kv_heads is None: diff --git a/olmo/model.py b/olmo/model.py index 9d4e6ef92..e371952ea 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -687,14 +687,14 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): self.moe_args = MoEArgs( activation_fn=F.silu if 'glu' in config.activation_type.lower() else self.act, mlp_type='glu' if 'glu' in config.activation_type.lower() else 'mlp', - # mlp_impl='grouped', # 4x slower on H100s + # Recommended for H100s by megablocks but found it 4x slower on H100s + # mlp_impl='grouped', hidden_size=config.d_model, ffn_hidden_size=int(self.act.output_multiplier * self.hidden_size), moe_num_experts=config.moe_num_experts, # Handled by FSDP (https://github.com/databricks/megablocks/issues/57#issuecomment-1854594483) moe_weight_parallelism=False, - # Not tested for now - moe_expert_model_parallelism=False, + moe_expert_model_parallelism=config.moe_expert_model_parallelism, moe_top_k=config.moe_top_k, moe_capacity_factor=config.moe_capacity_factor, moe_loss_weight=config.moe_loss_weight, From a1bfc496c817fd44f8b21b950c2745e829a41d57 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Tue, 23 Apr 2024 01:54:52 +0000 Subject: [PATCH 11/13] Add dMoE --- olmo/config.py | 5 +++++ olmo/model.py | 7 +++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index 0ff754d51..2d56c0e72 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -460,6 +460,11 @@ class ModelConfig(BaseConfig): Whether to use model parallelism for the MoE experts. """ + moe_dropless: Optional[bool] = False + """ + Whether to use dMoE (https://arxiv.org/abs/2211.15841) + """ + @property def effective_n_kv_heads(self) -> int: if self.n_kv_heads is None: diff --git a/olmo/model.py b/olmo/model.py index e371952ea..1e91433ad 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -54,7 +54,7 @@ raise SystemExit("This script supports Python 3.8 or higher") try: - from megablocks.layers.moe import MoE + from megablocks.layers.moe import MoE, dMoE from megablocks.layers.arguments import Arguments as MoEArgs except ImportError: log.warning("megablocks not installed, MoE layers will not be available.") @@ -705,7 +705,10 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): bias=self.config.include_bias, return_bias=False, ) - self.ffn = MoE(self.moe_args) + if self.config.moe_dropless + self.ffn = dMoE(self.moe_args) + else: + self.ffn = MoE(self.moe_args) # Rotary embeddings. 
if self.config.rope: From bf091ee731db65ad9730dff265969fd005da744f Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Thu, 25 Apr 2024 20:58:23 +0000 Subject: [PATCH 12/13] Fix import --- olmo/model.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index 1e91433ad..a0b1dad66 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -53,12 +53,6 @@ else: raise SystemExit("This script supports Python 3.8 or higher") -try: - from megablocks.layers.moe import MoE, dMoE - from megablocks.layers.arguments import Arguments as MoEArgs -except ImportError: - log.warning("megablocks not installed, MoE layers will not be available.") - __all__ = [ "LayerNormBase", "LayerNorm", @@ -75,9 +69,14 @@ "OLMoGenerateOutput", ] - log = logging.getLogger(__name__) +try: + from megablocks.layers.arguments import Arguments as MoEArgs + from megablocks.layers.dmoe import dMoE + from megablocks.layers.moe import MoE +except ImportError: + log.warning("megablocks not installed, MoE layers will not be available.") def activation_checkpoint_function(cfg: ModelConfig): preserve_rng_state = ( @@ -705,7 +704,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): bias=self.config.include_bias, return_bias=False, ) - if self.config.moe_dropless + if self.config.moe_dropless: self.ffn = dMoE(self.moe_args) else: self.ffn = MoE(self.moe_args) From 678cb53f674d0a1e44001289bfb14c2dbd5c5213 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Tue, 7 May 2024 19:17:29 +0000 Subject: [PATCH 13/13] Default to gg --- olmo/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/model.py b/olmo/model.py index a0b1dad66..83d9c4da9 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -687,7 +687,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): activation_fn=F.silu if 'glu' in config.activation_type.lower() else self.act, mlp_type='glu' if 'glu' in config.activation_type.lower() else 'mlp', # Recommended for H100s by megablocks but found it 4x slower on H100s - # mlp_impl='grouped', + mlp_impl='grouped', hidden_size=config.d_model, ffn_hidden_size=int(self.act.output_multiplier * self.hidden_size), moe_num_experts=config.moe_num_experts,
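
For readers following the expert-assignment logging added in patches 07 and 08, the aggregation is easier to see outside the diff context. The snippet below is a minimal, self-contained sketch of that bookkeeping only: the tensors stand in for the per-layer tokens-per-expert counts that train_batch retrieves via megablocks' get_load_balancing_loss, and all numbers are invented purely for illustration; it is not part of the patch series.

import torch

# Stand-in for the per-micro-batch counts retrieved via megablocks'
# get_load_balancing_loss(): one tokens-per-expert vector per MoE layer.
# The values below are made up for illustration only.
n_layers, n_experts = 2, 4
micro_batches = [
    [torch.tensor([10.0, 30.0, 40.0, 20.0]), torch.tensor([25.0, 25.0, 25.0, 25.0])],
    [torch.tensor([12.0, 28.0, 38.0, 22.0]), torch.tensor([24.0, 26.0, 25.0, 25.0])],
]

# Accumulate counts across micro-batches, mirroring the expert_assignments
# buffer that train_batch keeps on CPU (patch 07).
expert_assignments = torch.zeros((n_layers, n_experts))
for per_layer_counts in micro_batches:
    expert_assignments += torch.stack(per_layer_counts, dim=0)

# Turn the accumulated counts into the metrics logged in train_step
# after patch 08: the absolute number of tokens routed to each expert
# and its percentage share within the layer.
metrics = {}
for layer_idx, layer_counts in enumerate(expert_assignments):
    total_tokens = layer_counts.sum().item()
    for expert_idx, count in enumerate(layer_counts):
        metrics[f"train/TokensPercentage/layer{layer_idx}/expert{expert_idx}"] = (
            count.item() / total_tokens * 100
        )
        metrics[f"train/TokensTotal/layer{layer_idx}/expert{expert_idx}"] = count.item()

print(metrics)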