Skip to content

Commit

Permalink
add device code for adam (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke authored Jul 11, 2024
1 parent 9091555 commit ba5962e
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,10 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
'csrc/adam/adam_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'],
'nvcc':['-O3', '--use_fast_math']}))
'nvcc':['-O3', '--use_fast_math',
'-gencode', 'arch=compute_70,code=sm_70',
'-gencode', 'arch=compute_80,code=sm_80',
'-gencode', 'arch=compute_90,code=sm_90']}))

ext_modules.append(
CUDAExtension(name='unicore_fused_softmax_dropout',
Expand Down

0 comments on commit ba5962e

Please sign in to comment.