DeepSpeed doesn't move tensors to GPU in deepspeed 0.9.3 and above #17806

Closed
calvinzhan opened this issue Jun 11, 2023 · 3 comments · Fixed by #18091
Labels: 3rd party · bug · strategy: deepspeed · ver: 2.0.x

Comments

calvinzhan commented Jun 11, 2023

Bug description

pytorch-lightning version: 2.0.2
deepspeed version: 0.9.4

While training chatglm, I got an exception because the weights and the inputs were on different devices: the weights were on the CPU while the inputs were on the GPU. Training works if deepspeed 0.9.2 is used. I noticed a difference in DeepSpeedEngine._configure_distributed_model() between 0.9.2 and 0.9.3.

deepspeed 0.9.2

    if not self.dont_change_device:
        self.module.to(self.device)

deepspeed 0.9.3 and above

    # zero.Init() handles device placement of model
    if not (self.dont_change_device or is_zero3_model):
        self.module.to(self.device)

When using stage 3, the model isn't moved to the GPU, which causes the following exception.

----------------------------------------
Traceback (most recent call last):
 File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/zhanqing/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/zhanqing/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/zhanqing/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/zhanqing/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/zhanqing/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/zhanqing/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/nfs_data/data/personal/zhanqing/projects/nlp_basedon_lightning/llm/src/training/trainer.py", line 75, in <module>
    main(lightning_cli)
  File "/nfs_data/data/personal/zhanqing/projects/nlp_basedon_lightning/llm/src/training/trainer.py", line 60, in main
    trainer.train(lightning_cli=lightning_cli)
  File "/nfs_data/data/personal/zhanqing/projects/nlp_basedon_lightning/llm/src/training/trainer.py", line 53, in train
    lightning_cli.trainer.fit(lightning_cli.model, lightning_cli.datamodule)  # afterwards self.model ends up on the CPU
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 42, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 92, in launch
    return function(*args, **kwargs)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 935, in _run
    results = self._run_stage()
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 976, in _run_stage
    self._run_sanity_check()
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1005, in _run_sanity_check
    val_loop.run()
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py", line 177, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 115, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 375, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 288, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/strategies/deepspeed.py", line 906, in validation_step
    return self.model(*args, **kwargs)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1736, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/pytorch_lightning/overrides/base.py", line 102, in forward
    return self._forward_module.validation_step(*inputs, **kwargs)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/lightning_base/lightning_module/abstract_lightning_module.py", line 54, in validation_step
    loss = self._step_internal(batch)
  File "/nfs_data/data/personal/zhanqing/projects/nlp_basedon_lightning/llm/src/training/lightning_module.py", line 136, in _step_internal
    outputs = self.model(
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/peft/peft_model.py", line 678, in forward
    return self.base_model(
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/zhanqing/.cache/huggingface/modules/transformers_modules/chatglm-6b/modeling_chatglm.py", line 1190, in forward
    transformer_outputs = self.transformer(
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/zhanqing/.cache/huggingface/modules/transformers_modules/chatglm-6b/modeling_chatglm.py", line 996, in forward
    layer_ret = layer(
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/zhanqing/.cache/huggingface/modules/transformers_modules/chatglm-6b/modeling_chatglm.py", line 627, in forward
    attention_outputs = self.attention(
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/zhanqing/.cache/huggingface/modules/transformers_modules/chatglm-6b/modeling_chatglm.py", line 460, in forward
    cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/zhanqing/.cache/huggingface/modules/transformers_modules/chatglm-6b/modeling_chatglm.py", line 203, in forward
    freqs = torch.einsum('i,j->ij', t, self.inv_freq)
  File "/home/zhanqing/.conda/envs/llm_5/lib/python3.9/site-packages/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

What version are you seeing the problem on?

v2.0

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @awaelchli

calvinzhan added the bug and needs triage labels on Jun 11, 2023
awaelchli (Contributor) commented:

@calvinzhan This was a change introduced in DeepSpeed (perhaps a breaking change, since it was done in a minor release): microsoft/DeepSpeed#3611. We had to update our tests too, see #17748.
As of deepspeed 0.9.3, all model layers/parameters must be created under the deepspeed.zero.Init() context manager, otherwise they won't get moved to the GPU. In Lightning, this translates to defining the layers in the designated configure_sharded_model hook.
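
For reference, this is roughly what "created under the zero.Init() context manager" looks like in raw DeepSpeed (a minimal sketch, not taken from this thread; the config values and layer sizes are illustrative, and it assumes a distributed environment is already initialized):

import deepspeed
import torch.nn as nn

# Minimal ZeRO stage 3 config (illustrative values only).
ds_config = {"train_micro_batch_size_per_gpu": 1, "zero_optimization": {"stage": 3}}

# Parameters allocated inside this context are partitioned by DeepSpeed and
# placed on the correct device by the engine; parameters allocated outside it
# are no longer moved, which is the 0.9.3 behavior change quoted above.
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))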

Could you please report this to DeepSpeed to make sure this is intended?

awaelchli added the 3rd party and strategy: deepspeed labels and removed the needs triage label on Jun 11, 2023
awaelchli changed the title from "Lightning doesn't work with stage3 of deepspeed 0.9.3 and above" to "DeepSpeed doesn't move tensors to GPU in deepspeed 0.9.3 and above" on Jun 11, 2023
calvinzhan (Author) commented:

@awaelchli

I created the model in the following way, since I don't have enough GPU memory and have to keep the model on the CPU. I don't know who should move the weights to the GPU later, Lightning or DeepSpeed. I did report an issue to DeepSpeed.

def configure_sharded_model(self):
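    # zero.Init(module=...) converts an already-built module; remote_device="cpu"
    # keeps the partitioned weights in (pinned) host memory instead of on the GPU.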
    deepspeed.zero.Init(module=self.model, remote_device="cpu", pin_memory=True)

awaelchli (Contributor) commented:

@calvinzhan You don't need to move the layers to the GPU yourself, nor do you need to call deepspeed.zero.Init anywhere in Lightning. You can just define the layers here:

def configure_sharded_model(self):
    self.layer = nn.Linear(...)
    # etc.

(the way you normally define your layer modules in __init__()).

You can find further examples of how to do this in: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html.
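
For concreteness, here is a minimal end-to-end sketch of this pattern with the DeepSpeed stage 3 strategy (the layer sizes, loss, optimizer, and trainer flags are illustrative, not taken from this thread):

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Keep __init__ free of large layers so nothing is allocated before
        # the DeepSpeed strategy enters its zero.Init() context.

    def configure_sharded_model(self):
        # Called by the DeepSpeed strategy inside zero.Init(), so these
        # parameters are sharded and placed on the correct device for us.
        self.model = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 2))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.model(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="deepspeed_stage_3", precision="16-mixed")
# trainer.fit(MyModel(), your_datamodule)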
