
1. Reduce compile time by 22%+ 2. Fix compile linking error on Ubuntu 22.04 gcc/g++ 11.4 with Cuda 12.4 #171

Merged
merged 1 commit into from
Mar 11, 2024

Conversation

Qubitium
Contributor

@Qubitium Qubitium commented Mar 11, 2024

Test Env:

OS/Container: Ubuntu 22.04
GCC/G++: 11.4.0 (Ubuntu 11.4.0-1ubuntu1~22.04)
CPU: Zen3 9334 (2x)
CPU cores/threads: container limited to 80 instead of the full 96
RAM: 2TB
Disk: ZFS array

Torch: 2.2.0
Python: 3.11.5 (miniconda)
CUDA: 12.4

This PR resolves two issues: 1) compile linking failed in my env, and 2) compile time is reduced by ~22%. I expect the difference to be larger if the underlying I/O storage is a slow spinning disk or SSD.

  1. Resolve a compilation compatibility issue in some envs (Ubuntu 22.04, gcc/g++ 11.4) that causes final-stage ld/linking errors. Most likely, for some strange reason, the -std=c++17 flag is not propagated correctly. I am unsure why force-setting -std=c++17 resolves the issue when nvcc already defaults to -std=c++17. Based on my monitoring of the compilation processes on the main branch, nvcc already defaults to -std=c++17, so there should be no harm in forcing this flag.

Linking error stack trace:

/root/flashinfer/python/build/temp.linux-x86_64-cpython-311/csrc/batch_decode.o: in function `cudaError flashinfer::BatchDecodeWithPaddedKVCacheDispatched<8u, 64u, (flashinfer::QKVLayout)1, (flashinfer::PosEncodingMode)1, __half, __half>(__half*, __half*, __half*, __half*, __half*, float*, unsigned int, unsigned int, unsigned int, float, float, float, CUstream_st*) [clone .isra.0]':
tmpxft_0006a162_00000000-6_batch_decode.compute_89.cudafe1.cpp:(.text+0x4a0): relocation truncated to fit: R_X86_64_REX_GOTPCRELX against symbol `void flashinfer::BatchDecodeWithPaddedKVCacheKernel<(flashinfer::QKVLayout)1, (flashinfer::PosEncodingMode)1, 2u, 8u, 8u, 8u, 2u, __half, __half>(__half*, __half*, __half*, __half*, float*, flashinfer::tensor_info_t<(flashinfer::QKVLayout)1, 8u, (8u)*(8u)>, float, float, float)' defined in .text._ZN10flashinfer34BatchDecodeWithPaddedKVCacheKernelILNS_9QKVLayoutE1ELNS_15PosEncodingModeE1ELj2ELj8ELj8ELj8ELj2E6__halfS3_EEvPT6_S5_S5_PT7_PfNS_13tensor_info_tIXT_EXT4_EXmlT3_T2_EEEfff[_ZN10flashinfer34BatchDecodeWithPaddedKVCacheKernelILNS_9QKVLayoutE1ELNS_15PosEncodingModeE1ELj2ELj8ELj8ELj8ELj2E6__halfS3_EEvPT6_S5_S5_PT7_PfNS_13tensor_info_tIXT_EXT4_EXmlT3_T2_EEEfff] section in /root/flashinfer/python/build/temp.linux-x86_64-cpython-311/csrc/batch_decode.o
tmpxft_0006a162_00000000-6_batch_decode.compute_89.cudafe1.cpp:(.text+0x6a3): relocation truncated to fit: R_X86_64_REX_GOTPCRELX against symbol `std::cerr@@GLIBCXX_3.4' defined in .bss section in /root/miniconda3/lib/libstdc++.so
tmpxft_0006a162_00000000-6_batch_decode.compute_89.cudafe1.cpp:(.text+0x7d8): relocation truncated to fit: R_X86_64_REX_GOTPCRELX against symbol `std::ctype<char>::do_widen(char) const' defined in .text._ZNKSt5ctypeIcE8do_widenEc[_ZNKSt5ctypeIcE8do_widenEc] section in /root/flashinfer/python/build/temp.linux-x86_64-cpython-311/csrc/batch_decode.o
tmpxft_0006a162_00000000-6_batch_decode.compute_89.cudafe1.cpp:(.text+0x7f3): relocation truncated to fit: R_X86_64_REX_GOTPCRELX against symbol `std::cerr@@GLIBCXX_3.4' defined in .bss section in /root/miniconda3/lib/libstdc++.so
/root/flashinfer/python/build/temp.linux-x86_64-cpython-311/csrc/batch_decode.o: in function `cudaError flashinfer::BatchDecodeWithPaddedKVCacheDispatched<8u, 256u, (flashinfer::QKVLayout)0, (flashinfer::PosEncodingMode)0, __nv_fp8_e4m3, __half>(__nv_fp8_e4m3*, __nv_fp8_e4m3*, __nv_fp8_e4m3*, __half*, __half*, float*, unsigned int, unsigned int, unsigned int, float, float, float, CUstream_st*) [clone .isra.0]':
tmpxft_0006a162_00000000-6_batch_decode.compute_89.cudafe1.cpp:(.text+0x930): relocation truncated to fit: R_X86_64_REX_GOTPCRELX against symbol `void flashinfer::BatchDecodeWithPaddedKVCacheKernel<(flashinfer::QKVLayout)0, (flashinfer::PosEncodingMode)0, 2u, 16u, 16u, 8u, 1u, __nv_fp8_e4m3, __half>(__nv_fp8_e4m3*, __nv_fp8_e4m3*, __nv_fp8_e4m3*, __half*, float*, flashinfer::tensor_info_t<(flashinfer::QKVLayout)0, 8u, (16u)*(16u)>, float, float, float)' defined in .text._ZN10flashinfer34BatchDecodeWithPaddedKVCacheKernelILNS_9QKVLayoutE0ELNS_15PosEncodingModeE0ELj2ELj16ELj16ELj8ELj1E13__nv_fp8_e4m36__halfEEvPT6_S6_S6_PT7_PfNS_13tensor_info_tIXT_EXT4_EXmlT3_T2_EEEfff[_ZN10flashinfer34BatchDecodeWithPaddedKVCacheKernelILNS_9QKVLayoutE0ELNS_15PosEncodingModeE0ELj2ELj16ELj16ELj8ELj1E13__nv_fp8_e4m36__halfEEvPT6_S6_S6_PT7_PfNS_13tensor_info_tIXT_EXT4_EXmlT3_T2_EEEfff] section in /root/flashinfer/python/build/temp.linux-x86_64-cpython-311/csrc/batch_decode.o
tmpxft_0006a162_00000000-6_batch_decode.compute_89.cudafe1.cpp:(.text+0xb33): relocation truncated to fit: R_X86_64_REX_GOTPCRELX against symbol `std::cerr@@GLIBCXX_3.4' defined in .bss section in /root/miniconda3/lib/libstdc++.so
tmpxft_0006a162_00000000-6_batch_decode.compute_89.cudafe1.cpp:(.text+0xc68): additional relocation overflows omitted from the output
build/lib.linux-x86_64-cpython-311/flashinfer/_kernels.cpython-311-x86_64-linux-gnu.so: PC-relative offset overflow in PLT entry for `_ZN10flashinfer38BatchPrefillWithPagedKVCacheDispatchedILNS_11PageStorageE0ELNS_9QKVLayoutE1ELj1ELj1ELj4ELj256ELNS_15PosEncodingModeE0ELb1ELb1E6__halfS4_iEE9cudaErrorPT8_PT10_S9_S9_S9_NS_10paged_kv_tIXT_EXT0_ES6_S8_EEPT9_PfSE_jfffP11CUstream_st'
collect2: error: ld returned 1 exit status
error: command '/usr/bin/g++' failed with exit code 1
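The fix can be sketched as follows. This is illustrative only (the helper name `cxx17_compile_args` and the flag lists are assumptions, not flashinfer's actual setup.py code): the idea is to pass -std=c++17 explicitly to both the host compiler and nvcc via a torch CUDAExtension-style extra_compile_args dict.

```python
# Sketch only: force -std=c++17 for both the host compiler (g++) and
# nvcc, as described above. Helper name and flags are illustrative,
# not flashinfer's actual build code.

def cxx17_compile_args(extra_cxx=None, extra_nvcc=None):
    """Build an extra_compile_args dict (torch CUDAExtension style)
    that pins C++17 explicitly, even though recent nvcc defaults to it."""
    return {
        "cxx": ["-std=c++17"] + list(extra_cxx or []),
        "nvcc": ["-std=c++17"] + list(extra_nvcc or []),
    }

# Hypothetical usage inside setup.py:
# ext = CUDAExtension(
#     name="flashinfer._kernels",
#     sources=[...],
#     extra_compile_args=cxx17_compile_args(extra_nvcc=["-O3"]),
# )
```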
  2. Reduce compile time: A) Reduce thread contention caused by ninja naively using all cores; instead, use half the cores but more nvcc threads. B) Only compile for the base archs: sm80 for Ampere, sm89 for Ada, and sm90 for Hopper. Skip compilation for minor CUDA archs such as Ampere/sm86.

Testing was performed on a Zen3 2x9334 system with 48 cores (96 threads), but for the test I limited the lxd container to a max of 80 cores/threads. An earlier version of this PR was tested with the full 96 cores/threads and the diff % was the same. Between each test I removed the python/build and python/csrc/generated dirs.

Main:

real    43m31.594s
user    1441m59.558s
sys     79m51.832s

PR:

real    33m47.865s
user    1247m17.189s
sys     70m19.907s

The diff is ~22%, but I expect the value to be even larger if the underlying I/O system is a slow spinning disk or SSD, since thread contention would kill small I/O during compilation.

The causes of slow compilation are as follows:

  1. The primary cause is thread contention (resource over-subscription). Ninja using all cores by default is not optimal. Fix: use half the cores, but increase the number of threads nvcc can spawn from 1 to 8.

  2. Also reduce unnecessary compilation of minor archs. For example, Ampere has sm80 and sm86; the only difference is hardware resources such as SM count and cache size. I did not find any documentation from Nvidia indicating that nvcc actually compiles differently for the same arch with a different cache size or SM count. Just compile for the base archs: sm80 for Ampere, sm89 for Ada, sm90 for Hopper.

ref: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html
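The two fixes above can be sketched like this. Function names are illustrative (not flashinfer's actual build code); the sketch assumes a torch/ninja-based build where MAX_JOBS caps ninja's parallelism and nvcc's `--threads` flag controls per-invocation internal threads:

```python
import os

# Sketch of the two compile-time fixes described above; helper names
# are illustrative, not flashinfer's actual build code.

def build_parallelism(total_cores, nvcc_threads=8):
    """Use half the cores for ninja jobs, but let each nvcc invocation
    spawn several internal threads (--threads) instead of one."""
    jobs = max(1, total_cores // 2)
    return jobs, ["--threads", str(nvcc_threads)]

def base_arch_gencode():
    """Emit -gencode flags only for the base archs: sm80 (Ampere),
    sm89 (Ada), sm90 (Hopper). Minor archs like sm86 are skipped."""
    flags = []
    for arch in (80, 89, 90):
        flags += ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]
    return flags

jobs, thread_flags = build_parallelism(80)
# torch's ninja-based builder honors MAX_JOBS for its job count:
os.environ["MAX_JOBS"] = str(jobs)
```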

@yzh119

…king errors: Ubuntu 22.04 gcc/g++ 11.4

2. Reduce compile time: A) Reduce thread contention due to ninja naively using all cores. Instead, use half cores but more threads. B) Only compile for the base archs. sm80 for ampere, sm89 for ada, and sm90 for hopper. Skip compilation for minor cuda archs such as ampere/86.
@yzh119
Collaborator

yzh119 commented Mar 11, 2024

I did not find any documentation from Nvidia that nvcc actually compile differently due to same arch but different cache size, sm cores. Just compile for the base archs.

The throughput per sm doubles from sm80 to sm86 (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions), so nvcc might behave a little bit differently because of the difference.

I agree with keeping only sm80/sm89/sm90, considering sm86 is not a mainstream CUDA arch.

@yzh119 yzh119 merged commit 2657813 into flashinfer-ai:main Mar 11, 2024
@yzh119
Collaborator

yzh119 commented Mar 11, 2024

Thank you for your contribution!

@Qubitium
Contributor Author

@yzh119 See if you can stop and restart the CI runners. I see the current run is already failing with very similar ld/linking errors, which this PR may fix.

https://github.com/flashinfer-ai/flashinfer/actions/runs/8226582215

@yzh119 yzh119 mentioned this pull request Mar 11, 2024
yzh119 added a commit that referenced this pull request Mar 11, 2024
1. Do not generate prefill kernels for `page_size=8`
2. Build with `-Xfatbin=-compress-all` to reduce binary size.

Followup of #171. @Qubitium, the CUDA architectures to be compiled can be controlled by the environment variable `TORCH_CUDA_ARCH_LIST`, so I removed the gencode/archs specified in the compile args.
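For example (the arch list shown mirrors the base archs kept in this PR and is illustrative), `TORCH_CUDA_ARCH_LIST` is read by torch's extension builder when it generates gencode flags, so setting it before building restricts the compiled archs without editing compile args:

```python
import os

# TORCH_CUDA_ARCH_LIST is consulted by torch.utils.cpp_extension when
# it generates -gencode flags; set it before the build starts to
# restrict which archs get compiled.
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.9;9.0"

# Equivalent shell form before `pip install`:
#   TORCH_CUDA_ARCH_LIST="8.0;8.9;9.0" pip install -e .
```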