diff --git a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java index e16b9ecdae629..2dfa0aaa7cc9f 100644 --- a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java +++ b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java @@ -17,7 +17,12 @@ package org.apache.arrow.algorithm.sort; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.holders.NullableVarCharHolder; @@ -26,6 +31,42 @@ */ public class DefaultVectorComparators { + /** + * Default comparator for bytes. + * The comparison is based on values, with null comes first. + */ + public static class ByteComparator extends VectorValueComparator { + + public ByteComparator() { + super(Byte.SIZE / 8); + } + + @Override + public int compareNotNull(int index1, int index2) { + byte value1 = vector1.get(index1); + byte value2 = vector2.get(index2); + return value1 - value2; + } + } + + /** + * Default comparator for short integers. + * The comparison is based on values, with null comes first. + */ + public static class ShortComparator extends VectorValueComparator { + + public ShortComparator() { + super(Short.SIZE / 8); + } + + @Override + public int compareNotNull(int index1, int index2) { + short value1 = vector1.get(index1); + short value2 = vector2.get(index2); + return value1 - value2; + } + } + /** * Default comparator for 32-bit integers. * The comparison is based on int values, with null comes first. @@ -44,6 +85,89 @@ public int compareNotNull(int index1, int index2) { } } + /** + * Default comparator for long integers. + * The comparison is based on values, with null comes first. + */ + public static class LongComparator extends VectorValueComparator { + + public LongComparator() { + super(Long.SIZE / 8); + } + + @Override + public int compareNotNull(int index1, int index2) { + long value1 = vector1.get(index1); + long value2 = vector2.get(index2); + + return Long.signum(value1 - value2); + } + } + + /** + * Default comparator for float type. + * The comparison is based on values, with null comes first. + */ + public static class Float4Comparator extends VectorValueComparator { + + public Float4Comparator() { + super(Float.SIZE / 8); + } + + @Override + public int compareNotNull(int index1, int index2) { + float value1 = vector1.get(index1); + float value2 = vector2.get(index2); + + boolean isNan1 = Float.isNaN(value1); + boolean isNan2 = Float.isNaN(value2); + if (isNan1 || isNan2) { + if (isNan1 && isNan2) { + return 0; + } else if (isNan1) { + // nan is greater than any normal value + return 1; + } else { + return -1; + } + } + + return (int) Math.signum(value1 - value2); + } + } + + /** + * Default comparator for double type. + * The comparison is based on values, with null comes first. + */ + public static class Float8Comparator extends VectorValueComparator { + + public Float8Comparator() { + super(Double.SIZE / 8); + } + + @Override + public int compareNotNull(int index1, int index2) { + double value1 = vector1.get(index1); + double value2 = vector2.get(index2); + + boolean isNan1 = Double.isNaN(value1); + boolean isNan2 = Double.isNaN(value2); + if (isNan1 || isNan2) { + if (isNan1 && isNan2) { + return 0; + } else if (isNan1) { + // nan is greater than any normal value + return 1; + } else { + return -1; + } + } + + return (int) Math.signum(value1 - value2); + } + } + /** * Default comparator for varchars. * The comparison is in lexicographic order, with null comes first. diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthOutOfPlaceVectorSorter.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthOutOfPlaceVectorSorter.java index 9133ab6b15c79..4fc4a7ac5ea4a 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthOutOfPlaceVectorSorter.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestFixedWidthOutOfPlaceVectorSorter.java @@ -21,8 +21,13 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -45,6 +50,100 @@ public void shutdown() { allocator.close(); } + @Test + public void testSortByte() { + try (TinyIntVector vec = new TinyIntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, 10); + vec.set(1, 8); + vec.setNull(2); + vec.set(3, 10); + vec.set(4, 12); + vec.set(5, 17); + vec.setNull(6); + vec.set(7, 23); + vec.set(8, 35); + vec.set(9, 2); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + DefaultVectorComparators.ByteComparator comparator = new DefaultVectorComparators.ByteComparator(); + + TinyIntVector sortedVec = + (TinyIntVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null); + sortedVec.allocateNew(vec.getValueCount()); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + Assert.assertEquals(vec.getValueCount(), sortedVec.getValueCount()); + + assertTrue(sortedVec.isNull(0)); + assertTrue(sortedVec.isNull(1)); + Assert.assertEquals((byte) 2, sortedVec.get(2)); + Assert.assertEquals((byte) 8, sortedVec.get(3)); + Assert.assertEquals((byte) 10, sortedVec.get(4)); + Assert.assertEquals((byte) 10, sortedVec.get(5)); + Assert.assertEquals((byte) 12, sortedVec.get(6)); + Assert.assertEquals((byte) 17, sortedVec.get(7)); + Assert.assertEquals((byte) 23, sortedVec.get(8)); + Assert.assertEquals((byte) 35, sortedVec.get(9)); + + sortedVec.close(); + } + } + + @Test + public void testSortShort() { + try (SmallIntVector vec = new SmallIntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, 10); + vec.set(1, 8); + vec.setNull(2); + vec.set(3, 10); + vec.set(4, 12); + vec.set(5, 17); + vec.setNull(6); + vec.set(7, 23); + vec.set(8, 35); + vec.set(9, 2); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + DefaultVectorComparators.ShortComparator comparator = new DefaultVectorComparators.ShortComparator(); + + SmallIntVector sortedVec = + (SmallIntVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null); + sortedVec.allocateNew(vec.getValueCount()); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + Assert.assertEquals(vec.getValueCount(), sortedVec.getValueCount()); + + assertTrue(sortedVec.isNull(0)); + assertTrue(sortedVec.isNull(1)); + Assert.assertEquals((short) 2, sortedVec.get(2)); + Assert.assertEquals((short) 8, sortedVec.get(3)); + Assert.assertEquals((short) 10, sortedVec.get(4)); + Assert.assertEquals((short) 10, sortedVec.get(5)); + Assert.assertEquals((short) 12, sortedVec.get(6)); + Assert.assertEquals((short) 17, sortedVec.get(7)); + Assert.assertEquals((short) 23, sortedVec.get(8)); + Assert.assertEquals((short) 35, sortedVec.get(9)); + + sortedVec.close(); + } + } + @Test public void testSortInt() { try (IntVector vec = new IntVector("", allocator)) { @@ -90,4 +189,142 @@ public void testSortInt() { sortedVec.close(); } } + + @Test + public void testSortLong() { + try (BigIntVector vec = new BigIntVector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, 10L); + vec.set(1, 8L); + vec.setNull(2); + vec.set(3, 10L); + vec.set(4, 12L); + vec.set(5, 17L); + vec.setNull(6); + vec.set(7, 23L); + vec.set(8, 1L << 35L); + vec.set(9, 2L); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + DefaultVectorComparators.LongComparator comparator = new DefaultVectorComparators.LongComparator(); + + BigIntVector sortedVec = (BigIntVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null); + sortedVec.allocateNew(vec.getValueCount()); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + Assert.assertEquals(vec.getValueCount(), sortedVec.getValueCount()); + + assertTrue(sortedVec.isNull(0)); + assertTrue(sortedVec.isNull(1)); + Assert.assertEquals(2L, sortedVec.get(2)); + Assert.assertEquals(8L, sortedVec.get(3)); + Assert.assertEquals(10L, sortedVec.get(4)); + Assert.assertEquals(10L, sortedVec.get(5)); + Assert.assertEquals(12L, sortedVec.get(6)); + Assert.assertEquals(17L, sortedVec.get(7)); + Assert.assertEquals(23L, sortedVec.get(8)); + Assert.assertEquals(1L << 35L, sortedVec.get(9)); + + sortedVec.close(); + } + } + + @Test + public void testSortFloat() { + try (Float4Vector vec = new Float4Vector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, 10f); + vec.set(1, 8f); + vec.setNull(2); + vec.set(3, 10f); + vec.set(4, 12f); + vec.set(5, 17f); + vec.setNull(6); + vec.set(7, 23f); + vec.set(8, Float.NaN); + vec.set(9, 2f); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + DefaultVectorComparators.Float4Comparator comparator = new DefaultVectorComparators.Float4Comparator(); + + Float4Vector sortedVec = (Float4Vector) vec.getField().getFieldType().createNewSingleVector("", allocator, null); + sortedVec.allocateNew(vec.getValueCount()); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + Assert.assertEquals(vec.getValueCount(), sortedVec.getValueCount()); + + assertTrue(sortedVec.isNull(0)); + assertTrue(sortedVec.isNull(1)); + Assert.assertEquals(2f, sortedVec.get(2), 0f); + Assert.assertEquals(8f, sortedVec.get(3), 0f); + Assert.assertEquals(10f, sortedVec.get(4), 0f); + Assert.assertEquals(10f, sortedVec.get(5), 0f); + Assert.assertEquals(12f, sortedVec.get(6), 0f); + Assert.assertEquals(17f, sortedVec.get(7), 0f); + Assert.assertEquals(23f, sortedVec.get(8), 0f); + Assert.assertEquals(Float.NaN, sortedVec.get(9), 0f); + + sortedVec.close(); + } + } + + @Test + public void testSortDobule() { + try (Float8Vector vec = new Float8Vector("", allocator)) { + vec.allocateNew(10); + vec.setValueCount(10); + + // fill data to sort + vec.set(0, 10); + vec.set(1, 8); + vec.setNull(2); + vec.set(3, 10); + vec.set(4, 12); + vec.set(5, 17); + vec.setNull(6); + vec.set(7, Double.NaN); + vec.set(8, 35); + vec.set(9, 2); + + // sort the vector + FixedWidthOutOfPlaceVectorSorter sorter = new FixedWidthOutOfPlaceVectorSorter(); + DefaultVectorComparators.Float8Comparator comparator = new DefaultVectorComparators.Float8Comparator(); + + Float8Vector sortedVec = (Float8Vector) vec.getField().getFieldType().createNewSingleVector("", allocator, null); + sortedVec.allocateNew(vec.getValueCount()); + sortedVec.setValueCount(vec.getValueCount()); + + sorter.sortOutOfPlace(vec, sortedVec, comparator); + + // verify results + Assert.assertEquals(vec.getValueCount(), sortedVec.getValueCount()); + + assertTrue(sortedVec.isNull(0)); + assertTrue(sortedVec.isNull(1)); + Assert.assertEquals(2, sortedVec.get(2), 0); + Assert.assertEquals(8, sortedVec.get(3), 0); + Assert.assertEquals(10, sortedVec.get(4), 0); + Assert.assertEquals(10, sortedVec.get(5), 0); + Assert.assertEquals(12, sortedVec.get(6), 0); + Assert.assertEquals(17, sortedVec.get(7), 0); + Assert.assertEquals(35, sortedVec.get(8), 0); + Assert.assertEquals(Double.NaN, sortedVec.get(9), 0); + + sortedVec.close(); + } + } }