[BUG] Compile source code error for ROCM6.0 #2646

Closed
lhcalibur opened this issue Jan 29, 2024 · 12 comments · Fixed by #2648

Comments

@lhcalibur

lhcalibur commented Jan 29, 2024

The error may be caused by commit #2279.
ROCM Version: 6.0
@zhaoyang-star @zhuohan123

creating build/lib.linux-x86_64-3.10
creating build/lib.linux-x86_64-3.10/vllm
x86_64-linux-gnu-g++ -shared -Wl,-O1 -Wl,-Bsymbolic-functions -Wl,-Bsymbolic-functions -g -fwrapv -O2 -Wl,-Bsymbolic-functions -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/activation_kernels.o /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/attention/attention_kernels.o /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/cache_kernels.o /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/hip_utils_kernels.o /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/layernorm_kernels.o /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/pos_encoding_kernels.o /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/pybind.o /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/quantization/gptq/q_gemm.o /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/quantization/squeezellm/quant_hip_kernel.o -L/projects/vllm_master/venv_rocm-6.0/lib/python3.10/site-packages/torch/lib -L/opt/rocm-6.0.0/lib -L/opt/rocm-6.0.0/hip/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lamdhip64 -lc10_hip -ltorch_hip -o build/lib.linux-x86_64-3.10/vllm/_C.cpython-310-x86_64-linux-gnu.so
/usr/bin/ld: /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/cache_kernels.o: in function `__float2bfloat16(float)': cache_kernels.hip:(.text+0x0): multiple definition of `__float2bfloat16(float)'; /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/attention/attention_kernels.o:attention_kernels.hip:(.text+0x0): first defined here
/usr/bin/ld: /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/cache_kernels.o: in function `__bfloat1622float2(__hip_bfloat162)': cache_kernels.hip:(.text+0x40): multiple definition of `__bfloat1622float2(__hip_bfloat162)'; /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/attention/attention_kernels.o:attention_kernels.hip:(.text+0x40): first defined here
/usr/bin/ld: /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/cache_kernels.o: in function `__double2bfloat16(double)': cache_kernels.hip:(.text+0x60): multiple definition of `__double2bfloat16(double)'; /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/attention/attention_kernels.o:attention_kernels.hip:(.text+0x60): first defined here
/usr/bin/ld: /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/cache_kernels.o: in function `__float22bfloat162_rn(HIP_vector_type<float, 2u>)': cache_kernels.hip:(.text+0xa0): multiple definition of `__float22bfloat162_rn(HIP_vector_type<float, 2u>)'; /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/attention/attention_kernels.o:attention_kernels.hip:(.text+0xa0): first defined here
/usr/bin/ld: /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/cache_kernels.o: in function `__high2float(__hip_bfloat162)': cache_kernels.hip:(.text+0x110): multiple definition of `__high2float(__hip_bfloat162)'; /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/attention/attention_kernels.o:attention_kernels.hip:(.text+0x110): first defined here
/usr/bin/ld: /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/cache_kernels.o: in function `__low2float(__hip_bfloat162)': cache_kernels.hip:(.text+0x120): multiple definition of `__low2float(__hip_bfloat162)'; /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/attention/attention_kernels.o:attention_kernels.hip:(.text+0x120): first defined here
collect2: error: ld returned 1 exit status
error: command '/usr/bin/x86_64-linux-gnu-g++' failed with exit code 1

@zhaoyang-star
Contributor

@lhcalibur Thanks for your feedback. Could you help me verify it by applying PR #2648?
To fully support AMD GPUs, I think we should add AMD GPUs to CI. @simon-mo

@lhcalibur
Author

@zhaoyang-star

It still has an error:

FAILED: /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/cache_kernels.o
/opt/rocm-6.0.0/bin/hipcc -I/projects/vllm_master/venv_rocm-6.0/lib/python3.10/site-packages/torch/include -I/projects/vllm_master/venv_rocm-6.0/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/projects/vllm_master/venv_rocm-6.0/lib/python3.10/site-packages/torch/include/TH -I/projects/vllm_master/venv_rocm-6.0/lib/python3.10/site-packages/torch/include/THC -I/projects/vllm_master/venv_rocm-6.0/lib/python3.10/site-packages/torch/include/THH -I/opt/rocm-6.0.0/include -I/projects/vllm_master/venv_rocm-6.0/include -I/usr/include/python3.10 -c -c /projects/vllm_master/csrc/cache_kernels.hip -o /projects/vllm_master/build/temp.linux-x86_64-3.10/csrc/cache_kernels.o -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O2 -std=c++17 -DUSE_ROCM -U__HIP_NO_HALF_CONVERSIONS__ -U__HIP_NO_HALF_OPERATORS__ -D_GLIBCXX_USE_CXX11_ABI=0 --offload-arch=gfx90a -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -fno-gpu-rdc
/projects/vllm_master/csrc/cache_kernels.hip:50:5: warning: ignoring return value of function declared with 'nodiscard' attribute [-Wunused-result]
hipMemcpyAsync(
^~~~~~~~~~~~~~
/projects/vllm_master/csrc/cache_kernels.hip:252:30: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:212:54: note: expanded from macro 'CALL_RESHAPE_AND_CACHE'
hipLaunchKernelGGL(( vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE>), dim3(grid), dim3(block), 0, stream,
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:74: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:9: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
/projects/vllm_master/csrc/cache_kernels.hip:252:45: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:212:60: note: expanded from macro 'CALL_RESHAPE_AND_CACHE'
hipLaunchKernelGGL(( vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE>), dim3(grid), dim3(block), 0, stream,
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:74: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:9: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
/projects/vllm_master/csrc/cache_kernels.hip:252:30: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:213:22: note: expanded from macro 'CALL_RESHAPE_AND_CACHE'
reinterpret_cast<KV_T*>(key.data_ptr()),
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:87: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:78: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
/projects/vllm_master/csrc/cache_kernels.hip:252:30: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:214:22: note: expanded from macro 'CALL_RESHAPE_AND_CACHE'
reinterpret_cast<KV_T*>(value.data_ptr()),
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:87: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:78: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
/projects/vllm_master/csrc/cache_kernels.hip:252:45: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:215:22: note: expanded from macro 'CALL_RESHAPE_AND_CACHE'
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:87: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:78: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
/projects/vllm_master/csrc/cache_kernels.hip:252:45: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:216:22: note: expanded from macro 'CALL_RESHAPE_AND_CACHE'
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:87: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:78: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
/projects/vllm_master/csrc/cache_kernels.hip:260:30: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:212:54: note: expanded from macro 'CALL_RESHAPE_AND_CACHE'
hipLaunchKernelGGL(( vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE>), dim3(grid), dim3(block), 0, stream,
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:74: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:9: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
/projects/vllm_master/csrc/cache_kernels.hip:260:30: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:213:22: note: expanded from macro 'CALL_RESHAPE_AND_CACHE'
reinterpret_cast<KV_T*>(key.data_ptr()),
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:87: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:78: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
/projects/vllm_master/csrc/cache_kernels.hip:260:30: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:214:22: note: expanded from macro 'CALL_RESHAPE_AND_CACHE'
reinterpret_cast<KV_T*>(value.data_ptr()),
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:87: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:78: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
/projects/vllm_master/csrc/cache_kernels.hip:470:36: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:449:59: note: expanded from macro 'CALL_CONVERT_FP8_E5M2'
hipLaunchKernelGGL(( vllm::convert_fp8_e5m2_kernel<Tout, Tin>), dim3(grid), dim3(block), 0, stream,
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:74: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:9: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
/projects/vllm_master/csrc/cache_kernels.hip:470:36: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:450:22: note: expanded from macro 'CALL_CONVERT_FP8_E5M2'
reinterpret_cast<Tin*>(src_cache.data_ptr()),
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:87: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:78: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
/projects/vllm_master/csrc/cache_kernels.hip:476:27: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:449:53: note: expanded from macro 'CALL_CONVERT_FP8_E5M2'
hipLaunchKernelGGL(( vllm::convert_fp8_e5m2_kernel<Tout, Tin>), dim3(grid), dim3(block), 0, stream,
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:74: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:9: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
/projects/vllm_master/csrc/cache_kernels.hip:476:27: error: unknown type name '__nv_bfloat16'; did you mean 'hip_bfloat16'?
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
^~~~~~~~~~~~~
hip_bfloat16
/projects/vllm_master/csrc/cache_kernels.hip:451:22: note: expanded from macro 'CALL_CONVERT_FP8_E5M2'
reinterpret_cast<Tout*>(dst_cache.data_ptr()),
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:244:87: note: expanded from macro 'hipLaunchKernelGGL'
#define hipLaunchKernelGGL(kernelName, ...) hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__)
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_runtime.h:241:78: note: expanded from macro 'hipLaunchKernelGGLInternal'
kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>(__VA_ARGS__);
^
/opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
struct hip_bfloat16
^
1 warning and 13 errors generated when compiling for gfx90a.

I tried including the header file "attention/dtype_bfloat16.cuh" in "csrc/cache_kernels.cu", but that causes the same multiple definition error reported in this issue. Any suggestion as to why this happens? Maybe I can help try to fix it. Thanks.

@zhaoyang-star
Contributor

@lhcalibur The error is caused by __nv_bfloat16, which is undefined on ROCm. I added the following lines to cache_kernels.cu. Please verify it after pulling the latest commit on #2648:

#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
  typedef __hip_bfloat16 __nv_bfloat16;
#endif

@cloudhan
Contributor

Not working.

@lhcalibur
Author

lhcalibur commented Jan 30, 2024

@zhaoyang-star
Just including the header file <hip/hip_bf16.h> in cache_kernels.cu is enough to cause this multiple definition error, which is weird.

I have reproduced this error on the release tag v0.2.7:
`diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 9f17353..c088627 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -9,6 +9,7 @@
#include
#include
#include
+#include <hip/hip_bf16.h>

void swap_blocks(
torch::Tensor& src,`

@lhcalibur
Author

@zhaoyang-star
I found that the code below works. Is this right? I borrowed it from https://github.com/pytorch/pytorch/blob/15702a8027bdfb1d6da0faec96d833e25a3c72ba/aten/src/ATen/cuda/cub.cuh#L119 and https://github.com/pytorch/pytorch/blob/15702a8027bdfb1d6da0faec96d833e25a3c72ba/c10/util/BFloat16.h#L11

@cloudhan Can you help to verify? Thanks!

This modification is based on your latest pull request #2648:

diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index ceb7347..036fb36 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -14,8 +14,7 @@
 #include <vector>

 #ifdef USE_ROCM
-  #include <hip/hip_bf16.h>
-  typedef __hip_bfloat16 __nv_bfloat16;
+  typedef hip_bfloat16 __nv_bfloat16;
 #endif

 void swap_blocks(

@cloudhan
Contributor

No luck...

@lhcalibur
Author

@cloudhan
Can you check whether "hip/hip_bf16.h" or "cuda_bf16.h" is still included anywhere? Also, use hip_bfloat16 instead of __hip_bfloat16.

@hanq-moreh

I'm having the same issue when building the Docker image with ROCm 6.0.

@lhcalibur
Author

This happens because the functions in <hip/hip_bf16.h> are defined without static inline, which causes a multiple definition error when the header is included in more than one .cu/.hip file.

#ROCm/HIP#3403

The error will be fixed in a future ROCm release, but as of ROCm 6.0.2 it is not fixed yet, so we may need to find a way to bypass it. @zhaoyang-star
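
To illustrate the header problem described in this comment, here is a minimal C++ sketch. It is not the real <hip/hip_bf16.h> code: `demo_bf16`, `demo_float2bf16` and `demo_bf162float` are hypothetical stand-ins, and the conversion simply truncates, where a real implementation would typically round. It only shows why a plain function definition in a header collides at link time while a `static inline` one does not.

```cpp
// Sketch of the header pattern discussed above. All names are hypothetical
// stand-ins and are NOT part of the real HIP headers.
#include <cassert>
#include <cstdint>
#include <cstring>

// Stand-in for a 16-bit bfloat16 storage type.
struct demo_bf16 {
    std::uint16_t bits;
};

// Problematic form (kept as a comment): a non-inline definition placed in a
// header is emitted as an external symbol by every translation unit that
// includes it, so linking two such object files fails with
// "multiple definition of ...".
//
//   demo_bf16 demo_float2bf16(float f) { /* ... */ }

// Safer form: `static inline` gives each translation unit its own internal
// copy, so the linker never sees two external definitions of the same symbol.
static inline demo_bf16 demo_float2bf16(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));  // copy the float's bit pattern safely
    // Keep the upper 16 bits (truncation only, for illustration).
    return demo_bf16{static_cast<std::uint16_t>(u >> 16)};
}

// Widen the 16 stored bits back into the upper half of a 32-bit float.
static inline float demo_bf162float(demo_bf16 h) {
    std::uint32_t u = static_cast<std::uint32_t>(h.bits) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}
```

With the `static inline` form, this snippet can be included from several source files that are linked together without the duplicate-symbol error shown in the first comment. A plain `inline` (without `static`) would also avoid the collision.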

@guangzlu

guangzlu commented Feb 5, 2024

This happens because the functions in <hip/hip_bf16.h> are defined without static inline, which causes a multiple definition error when the header is included in more than one .cu/.hip file.

#ROCm/HIP#3403

The error will be fixed in a future ROCm release, but as of ROCm 6.0.2 it is not fixed yet, so we may need to find a way to bypass it. @zhaoyang-star

Hi @lhcalibur, is there any way to bypass this?

@thesues
Contributor

thesues commented Apr 20, 2024

I had the same issue on ROCm 5.7. Just apply this patch as a workaround for ROCm: https://github.com/vllm-project/vllm/pull/2790/files
