Multi-node distributed pretraining issue with Yuan-2.0-2.1B: FileNotFoundError: [Errno 2] No such file or directory:'home/Yuan-2.0-main/megatron/fused_kernels/build/lock' #86

Open
BoyCO3 opened this issue Jan 9, 2024 · 5 comments


BoyCO3 commented Jan 9, 2024

The code runs fine on a single node, but in a multi-node run the master hangs at this step:
setting number of micro-batches to constant 96

building YuanTokenizer tokenizer ...
padded vocab (size: 134953) with 87 dummy tokens (new size: 135040)
initializing torch distributed ...
initialized tensor model parallel with size 2
initialized pipeline model parallel with size 2
setting random seeds to 1234 ...
compiling dataset index builder ...
make: Entering directory '/haoranzheng/Yuan-2.0-main/megatron/data'
make: Nothing to be done for 'default'.
make: Leaving directory '/haoranzheng/Yuan-2.0-main/megatron/data'

done with dataset index builder. Compilation time: 0.050 seconds
WARNING: constraints for invoking optimized fused softmax kernel are not met. We default back to unfused kernel invocations.
compiling and loading fused kernels ...
Detected CUDA files, patching ldflags
Emitting ninja build file /haoranzheng/Yuan-2.0-main/megatron/fused_kernels/build/build.ninja...
Building extension module scaled_masked_softmax_cuda...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/3] c++ -MMD -MF scaled_masked_softmax.o.d -DTORCH_EXTENSION_NAME=scaled_masked_softmax_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="gcc" -DPYBIND11_STDLIB="libstdcpp" -DPYBIND11_BUILD_ABI="cxxabi1016" -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.10/dist-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /usr/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -O3 -c /haoranzheng/Yuan-2.0-main/megatron/fused_kernels/scaled_masked_softmax.cpp -o scaled_masked_softmax.o
[2/3] /usr/local/cuda/bin/nvcc -DTORCH_EXTENSION_NAME=scaled_masked_softmax_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="gcc" -DPYBIND11_STDLIB="libstdcpp" -DPYBIND11_BUILD_ABI="cxxabi1016" -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.10/dist-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /usr/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -O3 -gencode arch=compute_70,code=sm_70 --use_fast_math -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda -gencode arch=compute_80,code=sm_80 -std=c++17 -c /haoranzheng/Yuan-2.0-main/megatron/fused_kernels/scaled_masked_softmax_cuda.cu -o scaled_masked_softmax_cuda.cuda.o
[3/3] c++ scaled_masked_softmax.o scaled_masked_softmax_cuda.cuda.o -shared -L/usr/local/lib/python3.10/dist-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart -o scaled_masked_softmax_cuda.so
Loading extension module scaled_masked_softmax_cuda...
Detected CUDA files, patching ldflags
Emitting ninja build file /haoranzheng/Yuan-2.0-main/megatron/fused_kernels/build/build.ninja...
Building extension module scaled_softmax_cuda...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/3] c++ -MMD -MF scaled_softmax.o.d -DTORCH_EXTENSION_NAME=scaled_softmax_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="gcc" -DPYBIND11_STDLIB="libstdcpp" -DPYBIND11_BUILD_ABI="cxxabi1016" -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.10/dist-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /usr/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -O3 -c /haoranzheng/Yuan-2.0-main/megatron/fused_kernels/scaled_softmax.cpp -o scaled_softmax.o
[2/3] /usr/local/cuda/bin/nvcc -DTORCH_EXTENSION_NAME=scaled_softmax_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="gcc" -DPYBIND11_STDLIB="libstdcpp" -DPYBIND11_BUILD_ABI="cxxabi1016" -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.10/dist-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /usr/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -O3 -gencode arch=compute_70,code=sm_70 --use_fast_math -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda -gencode arch=compute_80,code=sm_80 -std=c++17 -c /haoranzheng/Yuan-2.0-main/megatron/fused_kernels/scaled_softmax_cuda.cu -o scaled_softmax_cuda.cuda.o
[3/3] c++ scaled_softmax.o scaled_softmax_cuda.cuda.o -shared -L/usr/local/lib/python3.10/dist-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart -o scaled_softmax_cuda.so
Loading extension module scaled_softmax_cuda...

Another worker fails with this error:

setting tensorboard ...
Traceback (most recent call last):
  File "/haoranzheng/Yuan-2.0-main/pretrain_yuan.py", line 124, in <module>
    pretrain(train_valid_test_datasets_provider,
  File "/haoranzheng/Yuan-2.0-main/megatron/training.py", line 90, in pretrain
    initialize_megatron(extra_args_provider=extra_args_provider,
  File "/haoranzheng/Yuan-2.0-main/megatron/initialize.py", line 82, in initialize_megatron
    _compile_dependencies()
  File "/haoranzheng/Yuan-2.0-main/megatron/initialize.py", line 134, in _compile_dependencies
    fused_kernels.load(args)
  File "/haoranzheng/Yuan-2.0-main/megatron/fused_kernels/__init__.py", line 74, in load
    scaled_softmax_cuda = _cpp_extention_load_helper(
  File "/haoranzheng/Yuan-2.0-main/megatron/fused_kernels/__init__.py", line 37, in _cpp_extention_load_helper
    return cpp_extension.load(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 1283, in load
    return _jit_compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 1520, in _jit_compile
    baton.release()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/file_baton.py", line 49, in release
    os.remove(self.lock_file_path)
FileNotFoundError: [Errno 2] No such file or directory: '/haoranzheng/Yuan-2.0-main/megatron/fused_kernels/build/lock'
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 101 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 102 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 103 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 3 (pid: 104) of binary: /usr/bin/python
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.1.0a0+4136153', 'console_scripts', 'torchrun')())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 797, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 788, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
pretrain_yuan.py FAILED


Failures:
<NO_OTHER_FAILURES>

Root Cause (first observed failure):
[0]:
time : 2024-01-05_13:42:59
host : pytorch-2e4abf84-worker-0
rank : 7 (local_rank: 3)
exitcode : 1 (pid: 104)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

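Judging from the error, the C++ extension build is what fails: there is no lock file under home/Yuan-2.0-main/megatron/fused_kernels/build. I searched for related reports; some suggest deleting the build directory and recompiling, but I tried that and it did not help.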
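The traceback ends in `os.remove(self.lock_file_path)` inside `torch/utils/file_baton.py`. The stdlib sketch below assumes that lock behaves like PyTorch's `FileBaton`: the owner creates `build/lock` with `O_CREAT | O_EXCL` and deletes it on release. It shows how `[Errno 2]` can arise if two ranks both believe they own the lock: the first release deletes the file and the second `os.remove` finds nothing. `MiniFileBaton` and `simulate_double_release` are hypothetical names. That two ranks on different nodes can both "acquire" the lock on a shared filesystem is an assumption that cannot be reproduced on a local disk, so the sketch forces that state.

```python
import errno
import os
import tempfile


class MiniFileBaton:
    """Hypothetical stand-in modeled on torch.utils.file_baton.FileBaton."""

    def __init__(self, lock_file_path):
        self.lock_file_path = lock_file_path
        self.fd = None

    def try_acquire(self):
        # O_CREAT | O_EXCL: creation fails if the lock file already exists.
        try:
            self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL)
            return True
        except FileExistsError:
            return False

    def release(self):
        # Close our handle (if we opened one), then delete the lock file;
        # this is the os.remove call that fails in the traceback.
        if self.fd is not None:
            os.close(self.fd)
            self.fd = None
        os.remove(self.lock_file_path)


def simulate_double_release(build_dir):
    """Return (refused_locally, error) for two ranks sharing build_dir/lock."""
    lock = os.path.join(build_dir, "lock")
    rank_a, rank_b = MiniFileBaton(lock), MiniFileBaton(lock)
    if not rank_a.try_acquire():
        raise RuntimeError("lock unexpectedly held")
    # On a local disk exclusive creation is atomic, so rank B is refused here.
    refused_locally = not rank_b.try_acquire()
    # Assumption: on the shared filesystem B's create also "succeeded" on
    # another node, so B also builds and later releases the same lock file.
    rank_a.release()  # A finishes first and deletes build/lock
    try:
        rank_b.release()  # B's delete now finds nothing to remove
    except FileNotFoundError as exc:
        return refused_locally, exc
    return refused_locally, None


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as build_dir:
        refused, err = simulate_double_release(build_dir)
        print(refused, type(err).__name__, err.errno == errno.ENOENT)
        # prints: True FileNotFoundError True
```

If this is the mechanism, deleting the build directory would not help, because the lock file is recreated on every launch and the same race can recur.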

BoyCO3 changed the title to "Multi-node distributed pretraining issue with Yuan-2.0-2.1B: FileNotFoundError: [Errno 2] No such file or directory:'home/Yuan-2.0-main/megatron/fused_kernels/build/lock'" on Jan 9, 2024.
Shawn-IEITSystems (Collaborator):

@zhaoxudong01-ieisystem

zhaoxudong01 (Collaborator) commented:

Could you confirm whether the code is located in a directory shared across all the nodes?


jalaxy33 commented Jan 11, 2024

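I ran into the same problem: the single-node multi-GPU script runs fine, but multi-node multi-GPU training hits this bug, which is caused by the lock used while compiling the C++ extensions. I have solved it: in the training script, add the --no-masked-softmax-fusion argument to torchrun pretrain_yuan.py. This prevents the masked-softmax-fusion kernels from being compiled at startup, and removing this feature does not affect the training results. For example: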

torchrun $DISTRIBUTED_ARGS pretrain_yuan.py \
    $GPT_ARGS \
    $DATA_ARGS \
    $OUTPUT_ARGS \
    $LOG_ARGS \
    --distributed-backend nccl \
    --save $CHECKPOINT_PATH \
    --load $CHECKPOINT_PATH \
    --no-masked-softmax-fusion   # <--- add this argument
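A minimal sketch of why this flag sidesteps the crash, assuming it works as in upstream Megatron-LM: `--no-masked-softmax-fusion` is an `argparse` `store_false` option writing to `args.masked_softmax_fusion`, and `fused_kernels.load(args)` only JIT-compiles the softmax extensions when that attribute is true. `build_parser` and `kernels_to_compile` are hypothetical; the latter returns the two extension names seen in the log above instead of building them. If those are the only JIT builds, passing the flag means no `build/lock` is ever taken. The log also printed "WARNING: constraints for invoking optimized fused softmax kernel are not met", which suggests the fused kernel would not have been invoked in this configuration anyway.

```python
import argparse


def build_parser():
    # Hypothetical one-flag subset of the training arguments. Assumption:
    # like Megatron-LM, the flag stores False into `masked_softmax_fusion`.
    parser = argparse.ArgumentParser(prog="pretrain_yuan.py")
    parser.add_argument(
        "--no-masked-softmax-fusion",
        action="store_false",
        dest="masked_softmax_fusion",
        help="Disable the fused scale/mask/softmax CUDA kernels.",
    )
    return parser


def kernels_to_compile(args):
    """Hypothetical stand-in for fused_kernels.load(args).

    Returns the names of the JIT extensions that would be built (the two
    seen in the log) instead of calling torch.utils.cpp_extension.load.
    """
    if not args.masked_softmax_fusion:
        return []  # flag given: nothing is compiled
    return ["scaled_masked_softmax_cuda", "scaled_softmax_cuda"]


if __name__ == "__main__":
    parser = build_parser()
    print(kernels_to_compile(parser.parse_args([])))
    print(kernels_to_compile(parser.parse_args(["--no-masked-softmax-fusion"])))
```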

P.S. This is actually a Megatron-LM bug. Someone reported the same issue there earlier, and it looks like the NVIDIA developers have since fixed it. In this reply, they mentioned:

Glad you figured it out. An update to Megatron coming very soon will remove the need for the fused_kernel compilation, since we'll be switching to using the Apex kernels. Closing this issue.

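I compared the code snippet where the bug occurs and found that they have indeed removed the code related to fused_kernel compilation (compare here). I suggest that Yuan-2.0 update its current Megatron-related code.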

BoyCO3 (Author) commented Jan 12, 2024

Quoting zhaoxudong01: "Could you confirm whether the code is located in a directory shared across all the nodes?"

Yes, it is indeed in a directory shared across the nodes.

BoyCO3 (Author) commented Jan 12, 2024


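(Replying to jalaxy33's --no-masked-softmax-fusion suggestion above.) That solved the problem, thank you very much!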
