
[apex FusedAdam] crash workaround #249

Merged — stas00 merged 5 commits into main from apex-multi_tensor_applier-workaround on Feb 18, 2022

Conversation

stas00 (Contributor) commented on Feb 18, 2022

This is a workaround for the crash when using apex.optimizers.FusedAdam on A100s with 80GB.

Currently we can only use about half of the GPU memory; if I try to pack just a tad more than 40GB, this happens:

terminate called after throwing an instance of 'c10::CUDAError'
  what():  CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Exception raised from query at /opt/conda/conda-bld/pytorch_1644999004077/work/aten/src/ATen/cuda/CUDAEvent.h:95 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x4d (0x14b6eefe8ead in /gpfswork/rech/six/commun/conda/py38-pt111/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x11a (0x14b72cf9271a in /gpfswork/rech/six/commun/conda/py38-pt111/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so)
frame #2: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x50 (0x14b72cf94d00 in /gpfswork/rech/six/commun/conda/py38-pt111/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so)
frame #3: c10d::ProcessGroupNCCL::workCleanupLoop() + 0x145 (0x14b72cf95f95 in /gpfswork/rech/six/commun/conda/py38-pt111/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so)
frame #4: <unknown function> + 0xc9039 (0x14b784e05039 in /gpfswork/rech/six/commun/conda/py38-pt111/lib/python3.8/site-packages/torch/lib/../../../../libstdc++.so.6)
frame #5: <unknown function> + 0x82de (0x14b7a4fc22de in /lib64/libpthread.so.0)
frame #6: clone + 0x43 (0x14b7a4cf3e83 in /lib64/libc.so.6)

With CUDA_LAUNCH_BLOCKING=1 I get:

Traceback (most recent call last):
  File "/gpfswork/rech/six/commun/code/tr8b-104B/Megatron-DeepSpeed-tr8b-104B/pretrain_gpt.py", line 255, in <module>
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
  File "/gpfsssd/worksf/projects/rech/six/commun/code/tr8b-104B/Megatron-DeepSpeed/megatron/training.py", line 141, in pretrain
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
  File "/gpfsssd/worksf/projects/rech/six/commun/code/tr8b-104B/Megatron-DeepSpeed/megatron/training.py", line 396, in setup_model_and_optimizer
    model, optimizer, _, lr_scheduler = deepspeed.initialize(
  File "/gpfsssd/worksf/projects/rech/six/commun/code/tr8b-104B/DeepSpeed-master/deepspeed/__init__.py", line 132, in initialize
    engine = PipelineEngine(args=args,
  File "/gpfsssd/worksf/projects/rech/six/commun/code/tr8b-104B/DeepSpeed-master/deepspeed/runtime/pipe/engine.py", line 69, in __init__
    super().__init__(*super_args, **super_kwargs)
  File "/gpfsssd/worksf/projects/rech/six/commun/code/tr8b-104B/DeepSpeed-master/deepspeed/runtime/engine.py", line 293, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "/gpfsssd/worksf/projects/rech/six/commun/code/tr8b-104B/DeepSpeed-master/deepspeed/runtime/engine.py", line 1139, in _configure_optimizer
    self.optimizer = self._configure_zero_optimizer(basic_optimizer)
  File "/gpfsssd/worksf/projects/rech/six/commun/code/tr8b-104B/DeepSpeed-master/deepspeed/runtime/engine.py", line 1360, in _configure_zero_optimizer
    optimizer = DeepSpeedZeroOptimizer(
  File "/gpfsssd/worksf/projects/rech/six/commun/code/tr8b-104B/DeepSpeed-master/deepspeed/runtime/zero/stage_1_and_2.py", line 493, in __init__
    self.initialize_optimizer_states()
  File "/gpfsssd/worksf/projects/rech/six/commun/code/tr8b-104B/DeepSpeed-master/deepspeed/runtime/zero/stage_1_and_2.py", line 583, in initialize_optimizer_states
    self.optimizer.step()
  File "/gpfswork/rech/six/commun/conda/py38-pt111/lib/python3.8/site-packages/torch/optim/optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "/gpfswork/rech/six/commun/conda/py38-pt111/lib/python3.8/site-packages/apex/optimizers/fused_adam.py", line 165, in step
    multi_tensor_applier(self.multi_tensor_adam,
  File "/gpfswork/rech/six/commun/conda/py38-pt111/lib/python3.8/site-packages/apex/multi_tensor_apply/multi_tensor_apply.py", line 27, in __call__
    return op(self.chunk_size,
RuntimeError: CUDA error: an illegal memory access was encountered

There is about 35GB free out of 80GB when this happens (right before self.optimizer.step()).
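
For reference, that free-memory figure can be read on the current device with something along these lines (a minimal sketch using torch.cuda.mem_get_info, not the exact instrumentation we used):

import torch

def report_free_memory(tag=""):
    # torch.cuda.mem_get_info returns (free_bytes, total_bytes) for the current device
    free, total = torch.cuda.mem_get_info()
    gb = 1024 ** 3
    print(f"[{tag}] free={free / gb:.1f}GB / total={total / gb:.1f}GB")

# e.g. call this right before the failing step:
# report_free_memory("before optimizer.step")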

We observed that if the model is just slightly smaller it all works, so somehow multi_tensor_applier tries to duplicate the first large param group and crashes. Splitting that group into two halves seems to work around the issue, along the lines of the sketch below.
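
The idea, roughly (a hypothetical sketch, not the exact diff in this PR; the helper name and threshold are made up for illustration): split any oversized param group into two halves before handing the groups to apex.optimizers.FusedAdam, so multi_tensor_applier never sees a single huge group.

def split_large_param_groups(param_groups, max_params_per_group=512):
    # hypothetical helper: halve any group holding more tensors than the threshold
    new_groups = []
    for group in param_groups:
        params = list(group["params"])
        if len(params) <= max_params_per_group:
            new_groups.append(group)
            continue
        mid = len(params) // 2
        for half in (params[:mid], params[mid:]):
            half_group = {k: v for k, v in group.items() if k != "params"}
            half_group["params"] = half
            new_groups.append(half_group)
    return new_groups

# usage sketch:
# param_groups = split_large_param_groups(param_groups)
# optimizer = apex.optimizers.FusedAdam(param_groups, lr=lr, betas=betas, eps=eps)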

This workaround is courtesy of @samyam and @jeffra.

Reading other related tickets, I tried all of the proposals, including set_device, and none helped (or we were already doing them).
See: NVIDIA/apex#319
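
For context, the set_device proposal from that ticket amounts to pinning each rank to its GPU before any CUDA work, roughly like this (reading LOCAL_RANK from the environment is an assumption about the launcher):

import os
import torch

# pin this process to its GPU early, as suggested in NVIDIA/apex#319
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)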

There was also a core dump; the backtrace is attached. It appears to fail while freeing some resource and to hit an illegal memory access there; some devices must be getting crossed.

log-core.txt

stas00 merged commit 541b967 into main on Feb 18, 2022
stas00 deleted the apex-multi_tensor_applier-workaround branch on February 18, 2022 at 02:10
adammoody pushed a commit to adammoody/Megatron-DeepSpeed that referenced this pull request Dec 18, 2023

The Megatron dataset is also very useful for some small models or other workloads.
Adding an argument allows it to be used by external trainers.

Signed-off-by: yuanwu <[email protected]>