diff --git a/src/contrib/sort/sort.cc b/src/contrib/sort/sort.cc index a87ce07cb602..0ccaee515acb 100644 --- a/src/contrib/sort/sort.cc +++ b/src/contrib/sort/sort.cc @@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") // Currently only supports input dtype to be float32. CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " "to be float."; -#if (__ARM_FP16_FORMAT_IEEE != 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC != 1) CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " "to be float32."; #endif @@ -100,23 +100,23 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); } if (is_ascend) { -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) if (dtype.bits == 16) { std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>); } else { #endif std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } #endif } else { -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) if (dtype.bits == 16) { std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>); } else { #endif std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } #endif } @@ -210,7 +210,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } else if (data_dtype == "float16") { if (out_dtype == "float16") { argsort<__fp16, __fp16>(input, output, axis, is_ascend);