diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 012ca2b6ad088..2effc6428448a 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -88,6 +88,8 @@ if(WITH_MUSA) "gpu/rnn_grad_kernel.cu.cc" "gpu/rnn_kernel.cu.cc" "gpu/slogdeterminant_grad_kernel.cu" + "gpu/softmax_grad_kernel.cu" + "gpu/softmax_kernel.cu" "gpu/solve_grad_kernel.cu" "gpu/solve_kernel.cu" "gpu/spectral_norm_grad_kernel.cu" diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index 7b1e010064259..cffb97d84050c 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -20,7 +20,8 @@ endif() if(WITH_MUSA) list(REMOVE_ITEM func_cu_srcs "cross_entropy.cu" - "gru_compute.cu") + "gru_compute.cu" + "softmax.cu") endif() collect_srcs(kernels_srcs SRCS ${func_cc_srcs} ${func_cu_srcs}) diff --git a/paddle/phi/kernels/funcs/softmax.cu b/paddle/phi/kernels/funcs/softmax.cu index 9e7cf84273b04..11ee9c23fa2ff 100644 --- a/paddle/phi/kernels/funcs/softmax.cu +++ b/paddle/phi/kernels/funcs/softmax.cu @@ -21,8 +21,6 @@ limitations under the License. */ namespace phi { namespace funcs { -// TODO(@caizhi): enable it -#if 0 using ScopedTensorDescriptor = phi::backends::gpu::ScopedTensorDescriptor; using DataLayout = phi::backends::gpu::DataLayout; template @@ -61,8 +59,6 @@ void SoftmaxCUDNNFunctor::operator()( context.template Alloc(Y), MIOPEN_SOFTMAX_ACCURATE, MIOPEN_SOFTMAX_MODE_INSTANCE)); -#elif defined(PADDLE_WITH_MUSA) - // TODO #else cudnnTensorDescriptor_t cudnn_x_desc = xDesc.descriptor(layout, cudnn_tensor_dims); @@ -148,18 +144,16 @@ template class SoftmaxCUDNNFunctor; template class SoftmaxCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; -// TODO(@caizhi): enable it -//#if CUDNN_VERSION_MIN(8, 1, 0) -//template class SoftmaxCUDNNFunctor; -//template class SoftmaxGradCUDNNFunctor; -//#endif +#if CUDNN_VERSION_MIN(8, 1, 0) +template class SoftmaxCUDNNFunctor; +template class SoftmaxGradCUDNNFunctor; +#endif // MIOPEN do not support double #ifndef PADDLE_WITH_HIP template class SoftmaxCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; #endif -#endif template class SoftmaxFunctor; template class SoftmaxFunctor; template class SoftmaxFunctor;