diff --git a/src/spconv/indice.cu b/src/spconv/indice.cu index f2c8494..e10205b 100644 --- a/src/spconv/indice.cu +++ b/src/spconv/indice.cu @@ -105,6 +105,7 @@ int create_conv_indice_pair_p2_cuda( auto kernelVolume = indiceNum.size(0); if (numActIn == 0) return 0; + bool failed = false; tv::dispatch_torch(indicesIn.scalar_type(), [&](auto IndexValue) { using Index = TV_DECLTYPE(IndexValue); using IndexGrid = int32_t; @@ -131,7 +132,8 @@ int create_conv_indice_pair_p2_cuda( cudaFree(d_values); TV_CHECK_CUDA_ERR_V2("cudaFree failed"); if (!res) { - return -1; // use -1 to tell outside use CPU implementation + failed = true; + return; } assignIndiceOutKernel <<(indicesIn.scalar_type(), [&](auto IndexValue) { using Index = TV_DECLTYPE(IndexValue); using IndexGrid = int32_t; @@ -245,7 +252,8 @@ int create_submconv_indice_pair_cuda( cudaFree(d_keyvalues); TV_CHECK_CUDA_ERR_V2("cudaFree failed"); if (!res) { - return -1; // use -1 to tell outside use CPU implementation + failed = true; + return; } auto tableSize = table.get_table_size(); auto tableData = table.data(); @@ -349,6 +357,10 @@ int create_submconv_indice_pair_cuda( } }); }); + if (failed){ + return -1; + } + return numActIn; }