diff --git a/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorSearcher.java b/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorSearcher.java new file mode 100644 index 0000000000000..8bed811e2fe94 --- /dev/null +++ b/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorSearcher.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.search; + +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.vector.ValueVector; + +/** + * Search for a particular element in the vector. + */ +public final class VectorSearcher { + + /** + * Search for a particular element from the key vector in the target vector by binary search. + * The target vector must be sorted. + * @param targetVector the vector from which to perform the sort. + * @param comparator the criterion for the sort. + * @param keyVector the vector containing the element to search. + * @param keyIndex the index of the search key in the key vector. + * @param the vector type. + * @return the index of a matched element if any, and -1 otherwise. + */ + public static int binarySearch( + V targetVector, VectorValueComparator comparator, V keyVector, int keyIndex) { + comparator.attachVectors(keyVector, targetVector); + + // perform binary search + int low = 0; + int high = targetVector.getValueCount() - 1; + + while (low <= high) { + int mid = (high + low) / 2; + + if (mid < 0) { + // overflow has occurred, so calculate the mid by converting to long first + mid = (int) (((long) high + (long) low) / 2L); + } + + int cmp = comparator.compare(keyIndex, mid); + if (cmp < 0) { + high = mid - 1; + } else if (cmp > 0) { + low = mid + 1; + } else { + return mid; + } + } + return -1; + } + + /** + * Search for a particular element from the key vector in the target vector by traversing the vector in sequence. + * @param targetVector the vector from which to perform the sort. + * @param comparator the criterion for element equality. + * @param keyVector the vector containing the element to search. + * @param keyIndex the index of the search key in the key vector. + * @param the vector type. + * @return the index of a matched element if any, and -1 otherwise. + */ + public static int linearSearch( + V targetVector, VectorValueComparator comparator, V keyVector, int keyIndex) { + comparator.attachVectors(keyVector, targetVector); + for (int i = 0; i < targetVector.getValueCount(); i++) { + if (comparator.compare(keyIndex, i) == 0) { + return i; + } + } + return -1; + } + + private VectorSearcher() { + + } +} diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorSearcher.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorSearcher.java new file mode 100644 index 0000000000000..f5c2912476594 --- /dev/null +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorSearcher.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.algorithm.search; + +import static org.junit.Assert.assertEquals; + +import org.apache.arrow.algorithm.sort.DefaultVectorComparators; +import org.apache.arrow.algorithm.sort.VectorValueComparator; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for {@link org.apache.arrow.algorithm.search.VectorSearcher}. + */ +public class TestVectorSearcher { + + private final int VECTOR_LENGTH = 100; + + private BufferAllocator allocator; + + @Before + public void prepare() { + allocator = new RootAllocator(1024 * 1024); + } + + @After + public void shutdown() { + allocator.close(); + } + + @Test + public void testBinarySearchInt() { + try (IntVector rawVector = new IntVector("", allocator); + IntVector negVector = new IntVector("", allocator)) { + rawVector.allocateNew(VECTOR_LENGTH); + rawVector.setValueCount(VECTOR_LENGTH); + negVector.allocateNew(1); + negVector.setValueCount(1); + + // prepare data in sorted order + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i == 0) { + rawVector.setNull(i); + } else { + rawVector.set(i, i); + } + } + negVector.set(0, -333); + + // do search + VectorValueComparator comparator = new DefaultVectorComparators.IntComparator(); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int result = VectorSearcher.binarySearch(rawVector, comparator, rawVector, i); + assertEquals(i, result); + } + + // negative case + assertEquals(-1, VectorSearcher.binarySearch(rawVector, comparator, negVector, 0)); + } + } + + @Test + public void testLinearSearchInt() { + try (IntVector rawVector = new IntVector("", allocator); + IntVector negVector = new IntVector("", allocator)) { + rawVector.allocateNew(VECTOR_LENGTH); + rawVector.setValueCount(VECTOR_LENGTH); + negVector.allocateNew(1); + negVector.setValueCount(1); + + // prepare data in sorted order + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i == 0) { + rawVector.setNull(i); + } else { + rawVector.set(i, i); + } + } + negVector.set(0, -333); + + // do search + VectorValueComparator comparator = new DefaultVectorComparators.IntComparator(); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int result = VectorSearcher.linearSearch(rawVector, comparator, rawVector, i); + assertEquals(i, result); + } + + // negative case + assertEquals(-1, VectorSearcher.linearSearch(rawVector, comparator, negVector, 0)); + } + } + + @Test + public void testBinarySearchVarChar() { + try (VarCharVector rawVector = new VarCharVector("", allocator); + VarCharVector negVector = new VarCharVector("", allocator)) { + rawVector.allocateNew(VECTOR_LENGTH * 16, VECTOR_LENGTH); + rawVector.setValueCount(VECTOR_LENGTH); + negVector.allocateNew(VECTOR_LENGTH, 1); + negVector.setValueCount(1); + + byte[] content = new byte[2]; + + // prepare data in sorted order + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i == 0) { + rawVector.setNull(i); + } else { + int q = i / 10; + int r = i % 10; + + content[0] = (byte) ('a' + q); + content[1] = (byte) r; + rawVector.set(i, content); + } + } + negVector.set(0, "abcd".getBytes()); + + // do search + VectorValueComparator comparator = new DefaultVectorComparators.VarCharComparator(); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int result = VectorSearcher.binarySearch(rawVector, comparator, rawVector, i); + assertEquals(i, result); + } + + // negative case + assertEquals(-1, VectorSearcher.binarySearch(rawVector, comparator, negVector, 0)); + } + } + + @Test + public void testLinearSearchVarChar() { + try (VarCharVector rawVector = new VarCharVector("", allocator); + VarCharVector negVector = new VarCharVector("", allocator)) { + rawVector.allocateNew(VECTOR_LENGTH * 16, VECTOR_LENGTH); + rawVector.setValueCount(VECTOR_LENGTH); + negVector.allocateNew(VECTOR_LENGTH, 1); + negVector.setValueCount(1); + + byte[] content = new byte[2]; + + // prepare data in sorted order + for (int i = 0; i < VECTOR_LENGTH; i++) { + if (i == 0) { + rawVector.setNull(i); + } else { + int q = i / 10; + int r = i % 10; + + content[0] = (byte) ('a' + q); + content[1] = (byte) r; + rawVector.set(i, content); + } + } + negVector.set(0, "abcd".getBytes()); + + // do search + VectorValueComparator comparator = new DefaultVectorComparators.VarCharComparator(); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int result = VectorSearcher.linearSearch(rawVector, comparator, rawVector, i); + assertEquals(i, result); + } + + // negative case + assertEquals(-1, VectorSearcher.linearSearch(rawVector, comparator, negVector, 0)); + } + } +}