Add API to set a module as a leaf node when recursively setting Z3 hooks #4966

Merged
merged 9 commits into master from tohtana/no_break_for_param_fetch
Jan 19, 2024

Conversation

tohtana
Contributor

@tohtana tohtana commented Jan 17, 2024

ZeRO3 does not work with MoE models because the order of executing modules can change at every forward/backward pass (#4094, #4808).

This PR adds an API to stop breaking down a module for parameter fetching. The following shows an example of the usage:

import torch
import deepspeed
import deepspeed.comm as dist
from transformers.deepspeed import HfDeepSpeedConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

model_id = "mistralai/Mixtral-8x7B-v0.1"
ds_config = {
      "bf16": {
          "enabled": True,
      },
      "zero_optimization": {
          "stage": 3,
      },
      "train_micro_batch_size_per_gpu": 1,
  }

hfdsc = HfDeepSpeedConfig(ds_config)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
model.eval()

ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()
model = ds_engine.module

inputs = tokenizer.encode("DeepSpeed is", return_tensors="pt").to("cuda")
outputs = model.generate(inputs, max_new_tokens=200)
output_str = tokenizer.decode(outputs[0])
if dist.get_rank() == 0:
  print(f"output: {output_str}")

By passing module classes to set_z3_leaf_modules, the DeepSpeed engine stops breaking down instances of those modules.

In this example, MixtralSparseMoeBlock has multiple experts as its submodules. Using set_z3_leaf_modules, the DeepSpeed engine fetches the parameters of all the submodules when pre-fetching the parameters of MixtralSparseMoeBlock.
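
Conceptually, the leaf flag just short-circuits the recursive hook registration. Below is a minimal sketch of that idea; it is illustrative only, and the attribute name and helper functions are assumptions for this sketch, not DeepSpeed's actual internals:

```python
import torch.nn as nn

Z3_LEAF_FLAG = "_z3_leaf_flag"  # hypothetical attribute name used only in this sketch


def set_leaf_modules(model: nn.Module, leaf_module_classes):
    # Flag every instance of the given classes as a "leaf" for hook setup.
    for module in model.modules():
        if isinstance(module, tuple(leaf_module_classes)):
            setattr(module, Z3_LEAF_FLAG, True)


def register_z3_hooks_recursively(module: nn.Module, register_fn):
    register_fn(module)  # register fetch/release hooks for this module
    if getattr(module, Z3_LEAF_FLAG, False):
        # Leaf module: its parameters (e.g. all MoE experts) are handled as one
        # unit and no hooks are set on its children, so a varying expert
        # execution order cannot desynchronize the ranks.
        return
    for child in module.children():
        register_z3_hooks_recursively(child, register_fn)
```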

@tohtana tohtana changed the title Add API to stop breaking down modules for parameter fetching Add API to set a module as a leaf node when recursively setting Z3 hooks Jan 18, 2024
@tohtana tohtana marked this pull request as ready for review January 18, 2024 16:55
@tohtana tohtana added this pull request to the merge queue Jan 19, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jan 19, 2024
@tohtana tohtana added this pull request to the merge queue Jan 19, 2024
Merged via the queue into master with commit 96c5a87 Jan 19, 2024
12 checks passed
@xs1997zju

@tohtana Thanks for the PR. However, when I train Mixtral with ZeRO-3 and set_z3_leaf_modules(model, [MixtralSparseMoeBlock]), it hangs with an NCCL timeout during backward. Any clues?

@tohtana
Contributor Author

tohtana commented Jan 23, 2024

@xs1997zju Thank you for sharing! Can you show a simple repro?

@xs1997zju

OK, here is my script:

import argparse

import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, get_linear_schedule_with_warmup, set_seed

from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from torch.utils.data import Dataset

from accelerate.utils import DummyOptim, DummyScheduler, set_seed

import math

from accelerate.utils import DeepSpeedPlugin, FullyShardedDataParallelPlugin

from transformers import get_scheduler


from deepspeed.utils import set_z3_leaf_modules  # for Mixtral

from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from transformers.integrations import is_deepspeed_zero3_enabled


MAX_GPU_BATCH_SIZE = 4


class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 4096, vocab_size: int = 100, tokenizer=None):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(2, vocab_size, (num_samples, max_length))
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }


def training_function(args):
    # Initialize accelerator
    deepPlugin = DeepSpeedPlugin(hf_ds_config='./ds_config_dump.json', zero3_init_flag=True)
    accelerator = Accelerator(mixed_precision='bf16', deepspeed_plugin=deepPlugin, gradient_accumulation_steps=1)
    
    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
    lr = 2e-5
    num_epochs = 2
    seed = 42
    batch_size = 1
    warmup_ratio = 0.03
    
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    dataset = RandomDataset(num_samples=1000, tokenizer=tokenizer)
    train_dataloader = DataLoader(
        dataset, shuffle=True, collate_fn=None, batch_size=batch_size, drop_last=True
    )
    
    if accelerator.is_main_process:
        print(f'before prepare dataloader len: {len(train_dataloader)}')
    
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
    max_train_steps = num_epochs * num_update_steps_per_epoch
    
    
    config = AutoConfig.from_pretrained(args.model_path)  # 
    model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            use_flash_attention_2=True,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=(not is_deepspeed_zero3_enabled())
        )
    
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    model.enable_input_require_grads()
    model.config.use_cache = False  # turn off when gradient checkpointing is enabled
    print("Gradient checkpointing enabled.")
    
    
    
    set_z3_leaf_modules(model, [MixtralSparseMoeBlock])  # z3_leaf
    model.train() #
    
    optimizer_cls = (
        torch.optim.AdamW
        if accelerator.state.deepspeed_plugin is None
            or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
            else DummyOptim
    )
    
    optimizer = optimizer_cls(params=model.parameters(), lr=lr)
    
    
    if (
        accelerator.state.deepspeed_plugin is None
        or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
    ):
        lr_scheduler = get_scheduler(
            name='linear',
            optimizer=optimizer,
            num_warmup_steps=math.ceil(max_train_steps * warmup_ratio),
            num_training_steps=max_train_steps,
        )
    else:
        lr_scheduler = DummyScheduler(
            optimizer, total_num_steps=max_train_steps, warmup_num_steps=math.ceil(max_train_steps * warmup_ratio)
        )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    

    # Now we train the model
    for epoch in range(num_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                model.train()
                outputs = model(**batch)
                loss = outputs.loss
                
                print(f" epoch: {epoch}, step: {step} loss: {loss}")
                
                accelerator.backward(loss)
                
                print(f"finish backward")
                
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                
                print(f"finish optimizer step")
            


def main():
    parser = argparse.ArgumentParser(description="Simple example of training script.")
    parser.add_argument(
        "--model_path",
        type=str,
        default="path_to_mixtral-8x7b",)
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16", "fp8"],
        help="Whether to use mixed precision. Choose"
        "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
        "and an Nvidia Ampere GPU.",
    )
    args = parser.parse_args()
    training_function(args)


if __name__ == "__main__":
    main()

@xs1997zju

And the DeepSpeed config is:

{
    "bf16": {
        "enabled": true
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 1e-6,
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 1,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": false
}

@tohtana
Contributor Author

tohtana commented Jan 23, 2024

@xs1997zju Sorry, I missed one of your messages. I used your script but changed the config to reduce the model size.

    config = AutoConfig.from_pretrained(model_id)
    config.num_hidden_layers = 2
    model = MixtralForCausalLM(config)

Then I ran this with 2xA100. In this case, it keeps running for around 20 steps.

@awzhgw

awzhgw commented Jan 24, 2024

@xs1997zju Sorry, I missed one of your messages. I used your script but changed the config to reduce the model size.

    config = AutoConfig.from_pretrained(model_id)
    config.num_hidden_layers = 2
    model = MixtralForCausalLM(config)

Then I ran this with 2xA100. In this case, it keeps running for around 20 steps.

Why change num_hidden_layers? What will happen if it is not changed?

@tohtana
Contributor Author

tohtana commented Jan 24, 2024

@awzhgw It is for debugging purposes. I wanted to debug with fewer GPUs and a faster launch, as long as the issue still reproduces.

@awzhgw

awzhgw commented Jan 24, 2024

@awzhgw It is for debugging purposes. I wanted to debug with fewer GPUs and a faster launch, as long as the issue still reproduces.

@tohtana So, when I train my model, config.num_hidden_layers cannot be changed, right?

@tohtana
Contributor Author

tohtana commented Jan 24, 2024

@tohtana So, when I train my model, config.num_hidden_layers cannot be changed, right?

If you use the pretrained model, you can't.

@awzhgw

awzhgw commented Jan 24, 2024

@tohtana

When I add this code:

config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
config.num_hidden_layers = 2
model = LlavaMixtralForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    cache_dir=training_args.cache_dir,
    **bnb_model_from_pretrained_args
)
deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

But when training reaches around step 270, it hangs, and after waiting 30 minutes, NCCL times out.

My DeepSpeed version is 0.13.1.

My DeepSpeed config is:

{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "gather_16bit_weights_on_model_save": true
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "steps_per_print": 1e5,
  "wall_clock_breakdown": false
}

How can I resolve it?

@awzhgw

awzhgw commented Jan 24, 2024

Thank you @xs1997zju, I could reproduce the issue. I will have a look. It seemed to hang at a certain step, not the first step.

@tohtana Would you mind sharing your reproduction script, environment, and machine configuration? I use two A800 nodes with 8 GPUs each, and it hangs at the first step of the loss backward.

Same for me; it hangs after around step 270, then NCCL times out.

Invalidate trace cache @ step 1: expected module 25, but got module 323

How can I resolve it?

@awzhgw

awzhgw commented Jan 24, 2024

@tohtana Thanks for the PR. However, when I train Mixtral with ZeRO-3 and set_z3_leaf_modules(model, [MixtralSparseMoeBlock]), it hangs with an NCCL timeout during backward. Any clues?

I am hitting the same issue. How can I resolve it? I use DeepSpeed 0.13.1.

@tohtana
Contributor Author

tohtana commented Jan 25, 2024

Same for me; it hangs after around step 270, then NCCL times out.

@awzhgw @xs1997zju I opened #5008 to address the issue. Please feel free to try it.

@duanzhenyu001

@tohtana Thanks for the PR. However, when I train Mixtral with ZeRO-3 and set_z3_leaf_modules(model, [MixtralSparseMoeBlock]), it hangs with an NCCL timeout during backward. Any clues?

I am hitting the same issue. How can I resolve it? I use DeepSpeed 0.13.1.

I tried it and hit this issue too. A trick to work around the problem is to add a special token for each expert at the beginning of input_ids. Here's my code:
class MixtralSparseMoeBlock(nn.Module):
    """
    This implementation is strictly equivalent to standard MoE with full
    capacity (no dropped tokens). It's faster since it formulates MoE
    operations in terms of block-sparse operations to accommodate imbalanced
    assignments of tokens to experts, whereas standard MoE either (1) drops
    tokens at the cost of reduced performance or (2) sets the capacity factor
    to the number of experts and thus wastes computation and memory on padding.
    """

    def __init__(self, config: MOEConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        self.one_const_expert = config.one_const_expert
        self.independent_weight_loss = config.independent_weight_loss

        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList([LlamaMLP(config) for _ in range(self.num_experts)])
        self.const_expert = None

    def forward(self, hidden_states: torch.Tensor):
        # Run every expert on its dedicated prefix token so that all experts
        # participate in every forward/backward pass.
        prefix_hidden_states = hidden_states[:, :self.num_experts, :]
        prefix_out = torch.zeros_like(prefix_hidden_states, dtype=prefix_hidden_states.dtype,
                                      device=prefix_hidden_states.device)

        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            prefix_out[:, expert_idx, :] = expert_layer(prefix_hidden_states[:, expert_idx, :])

        hidden_states = hidden_states[:, self.num_experts:, :]
        # ... (standard MoE routing over the remaining tokens) ...
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        final_hidden_states = torch.concat([prefix_out, final_hidden_states], dim=1)
        return final_hidden_states
I see tohtana has a new commit; I'll try it later.

@Sniper970119

(Quoting the training script posted above by @xs1997zju.)

I ran this script and got an error like this:

/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: The size of tensor a (0) must match the size of tensor b (14336) at non-singleton dimension 1

I only changed L50 and L143 for the model init. Could I ask what this error is referring to?

@awzhgw

awzhgw commented Jan 25, 2024

@tohtana Thanks for the PR. However, when I train Mixtral with ZeRO-3 and set_z3_leaf_modules(model, [MixtralSparseMoeBlock]), it hangs with an NCCL timeout during backward. Any clues?

I am hitting the same issue. How can I resolve it?

@awzhgw

awzhgw commented Jan 25, 2024

@tohtana

I tested four scenarios:

If I use the zero2 or zero2_offload training script and set config.num_hidden_layers to 2, the training process runs normally. However, if I use the zero2 training script and don't set config.num_hidden_layers, the model loading process hits an out-of-memory (OOM) error.

On the other hand, if I use the zero3 or zero3_offload training script and set config.num_hidden_layers to 2, the training process hangs after 270 steps. Similarly, if I use the zero3 training script and don't set config.num_hidden_layers, the training process also hangs after 270 steps. In that state, GPU power drops to around 100 W until the NCCL timeout kills the process.

Can you help me? What should I do?

@tohtana tohtana deleted the tohtana/no_break_for_param_fetch branch January 26, 2024 01:17
mrwyattii pushed a commit that referenced this pull request Feb 2, 2024
ZeRO3 sets hooks on parameters to run reduce-scatter. This is often
problematic for MoE models. Our data parallel processes may activate
different sets of experts, but the hook is not fired unless the expert
is activated at a forward pass. The reduce-scatter is called only on
some processes in this case.

This PR delays reduce-scatter for ZeRO3 leaf modules (Refer to #4966) to
address the issue.
We no longer set reduce-scatter hooks on parameters of the leaf modules.
Instead, we launch reduce-scatter on all parameters belonging to the
leaf module when exiting the module during the backward pass.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
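
As a rough illustration of the change described in that commit message (not DeepSpeed's actual implementation: the helper name below is made up, and all_reduce stands in for ZeRO-3's partitioned reduce-scatter), a full-backward hook on the leaf module can reduce every parameter's gradient once backward leaves the module, so all ranks participate even when they activated different experts:

```python
# Illustrative sketch only; not DeepSpeed's ZeRO-3 code.
import torch
import torch.distributed as dist
from torch import nn


def reduce_grads_on_module_exit(leaf_module: nn.Module, group=None):
    """Reduce all gradients of a leaf module when backward exits the module,
    instead of relying on per-parameter hooks that may not fire on every rank."""

    def _reduce_all_grads(module, grad_input, grad_output):
        # Every rank reaches this point when backward leaves the module, even
        # if it routed tokens to a different subset of experts.
        for p in module.parameters(recurse=True):
            grad = p.grad if p.grad is not None else torch.zeros_like(p)
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)
            p.grad = grad

    leaf_module.register_full_backward_hook(_reduce_all_grads)
```
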
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024
…oks (microsoft#4966)

(The commit message body repeats the PR description above.)
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024
(The commit message repeats the reduce-scatter commit message quoted above.)
@Hannibal046

For Mixtral MoE with ZeRO-3, a simple workaround would be enlarging the total batch size (increasing the gradient accumulation steps).
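
For reference, a minimal sketch of that workaround as a DeepSpeed config written as a Python dict, in the style of the config in the PR description; the accumulation value here is only illustrative:

```python
ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 3},
    "train_micro_batch_size_per_gpu": 1,
    # More accumulation steps -> larger effective batch per optimizer step,
    # which makes it likelier that every expert is activated on every rank.
    "gradient_accumulation_steps": 8,  # assumed value for illustration
}
```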

@al093

al093 commented Mar 7, 2024

I think the hacks done to allow MoE to be executed are still not enough.

In my use case, I am evaluating Mixtral MoE and I consistently see that after the warning `Invalidate trace cache @ step 2808: expected module 3447, but got module 3469`, evaluation just hangs and eventually an NCCL timeout occurs.
I am using ZeRO-3 with parameter offloading.

Running loglikelihood requests
Running loglikelihood requests:   7%|▋         | 21/320 [01:52<22:43,  4.56s/it]
Running loglikelihood requests:  14%|█▍        | 45/320 [03:32<19:10,  4.18s/it]
Running loglikelihood requests:  15%|█▌        | 49/320 [03:48<18:50,  4.17s/it]
Running loglikelihood requests:  23%|██▎       | 73/320 [05:28<17:05,  4.15s/it]
Invalidate trace cache @ step 2808: expected module 3447, but got module 3469
Running loglikelihood requests:  24%|██▍       | 77/320 [05:47<17:29,  4.32s/it]
...

@JinpilChoi

JinpilChoi commented Mar 8, 2024

For Mixtral MoE with ZeRO-3, a simple workaround would be enlarging the total batch size (increasing the gradient accumulation steps).

I had the same problem, but I fixed it by enlarging the micro batch size from 1 to 4. How does batch size affect it?

@stas00
Collaborator

stas00 commented Mar 12, 2024

My colleague has confirmed that:

from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **policy_kwargs)
deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

allows normal training.

It works fine with MBS=1 or higher.

@tohtana, first - thank you for the fix. How would users know about it?

@chensnathan

chensnathan commented Mar 26, 2024

For Mixtral MoE with ZeRO-3, a simple workaround would be enlarging the total batch size (increasing the gradient accumulation steps).

I had the same problem, but I fixed it by enlarging the micro batch size from 1 to 4. How does batch size affect it?

@JinpilChoi
In my experience, when the micro batch size is too small, it cannot ensure that all experts participate in the computation or that they all receive gradients. A workaround for this issue is to increase the batch size.

An alternative solution is to involve all experts in the computation. For instance, consider the implementation in transformers; we can modify the expert loop as follows:

        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                # continue
                
                # NOTE: make all experts have gradients!!!
                # NOTE: important for training with small batch size
                zero_idx, zero_top_x = torch.where(expert_mask[expert_idx]==0)
                first_zero_idx = zero_idx[:1]
                first_zero_top_x = zero_top_x[:1]
                
                zero_top_x_list = first_zero_top_x.tolist()
                zero_idx_list = first_zero_idx.tolist()
                
                current_state = hidden_states[None, zero_top_x_list].reshape(-1, hidden_dim)
                current_hidden_states = expert_layer(current_state) * routing_weights[zero_top_x_list, zero_idx_list, None]
                # multiply by 0 to avoid real gradient
                final_hidden_states.index_add_(0, first_zero_top_x, current_hidden_states.to(hidden_states.dtype) * 0.)
            else:

                # in torch it is faster to index using lists than torch tensors
                top_x_list = top_x.tolist()
                idx_list = idx.tolist()
    
                # Index the correct hidden states and compute the expert hidden state for
                # the current expert. We need to make sure to multiply the output hidden
                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
                current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
                current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
    
                # However `index_add_` only support torch tensors for indexing so we'll use
                # the `top_x` tensor here.
                final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024
(The commit message repeats the reduce-scatter commit message quoted above.)