From ce62afca03f427d093a15b8217a1cfc6bd480f90 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Fri, 12 Jul 2024 00:39:07 +0800 Subject: [PATCH] add device code for adam --- setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3aed289..09aca70 100644 --- a/setup.py +++ b/setup.py @@ -163,7 +163,10 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir): 'csrc/adam/adam_kernel.cu'], include_dirs=[os.path.join(this_dir, 'csrc')], extra_compile_args={'cxx': ['-O3'], - 'nvcc':['-O3', '--use_fast_math']})) + 'nvcc':['-O3', '--use_fast_math', + '-gencode', 'arch=compute_70,code=sm_70', + '-gencode', 'arch=compute_80,code=sm_80', + '-gencode', 'arch=compute_90,code=sm_90']})) ext_modules.append( CUDAExtension(name='unicore_fused_softmax_dropout',