diff --git a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java index e7bc64df6dff9..98ec3ecd538b6 100644 --- a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java +++ b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java @@ -126,6 +126,19 @@ public int compareNotNull(int index1, int index2) { float value1 = vector1.get(index1); float value2 = vector2.get(index2); + boolean isNan1 = Float.isNaN(value1); + boolean isNan2 = Float.isNaN(value2); + if (isNan1 || isNan2) { + if (isNan1 && isNan2) { + return 0; + } else if (isNan1) { + // nan is greater than any normal value + return 1; + } else { + return -1; + } + } + float result = value1 - value2; if (result < 0f) { return -1; @@ -152,6 +165,19 @@ public int compareNotNull(int index1, int index2) { double value1 = vector1.get(index1); double value2 = vector2.get(index2); + boolean isNan1 = Double.isNaN(value1); + boolean isNan2 = Double.isNaN(value2); + if (isNan1 || isNan2) { + if (isNan1 && isNan2) { + return 0; + } else if (isNan1) { + // nan is greater than any normal value + return 1; + } else { + return -1; + } + } + double result = value1 - value2; if (result < 0) { return -1; diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthOutOfPlaceVectorSorter.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthOutOfPlaceVectorSorter.java index d8cf56ff2f310..4fc4a7ac5ea4a 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthOutOfPlaceVectorSorter.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthOutOfPlaceVectorSorter.java @@ -251,7 +251,7 @@ public void testSortFloat() { vec.set(5, 17f); vec.setNull(6); vec.set(7, 23f); - vec.set(8, 35f); + vec.set(8, Float.NaN); vec.set(9, 2f); // sort the vector @@ -276,7 +276,7 @@ public void testSortFloat() { Assert.assertEquals(12f, sortedVec.get(6), 0f); Assert.assertEquals(17f, sortedVec.get(7), 0f); Assert.assertEquals(23f, sortedVec.get(8), 0f); - Assert.assertEquals(35f, sortedVec.get(9), 0f); + Assert.assertEquals(Float.NaN, sortedVec.get(9), 0f); sortedVec.close(); } @@ -296,7 +296,7 @@ public void testSortDobule() { vec.set(4, 12); vec.set(5, 17); vec.setNull(6); - vec.set(7, 23); + vec.set(7, Double.NaN); vec.set(8, 35); vec.set(9, 2); @@ -321,8 +321,8 @@ public void testSortDobule() { Assert.assertEquals(10, sortedVec.get(5), 0); Assert.assertEquals(12, sortedVec.get(6), 0); Assert.assertEquals(17, sortedVec.get(7), 0); - Assert.assertEquals(23, sortedVec.get(8), 0); - Assert.assertEquals(35, sortedVec.get(9), 0); + Assert.assertEquals(35, sortedVec.get(8), 0); + Assert.assertEquals(Double.NaN, sortedVec.get(9), 0); sortedVec.close(); }