-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ARROW-5832: [Java] Support search operations for vector data
Support searching for a particular data item in a sorted/unsorted vector. Author: liyafan82 <[email protected]> Closes #4788 from liyafan82/fly_0703_search and squashes the following commits: 769113e <liyafan82> Add negative test cases 34be18e <liyafan82> Support search operations for vector data
- Loading branch information
1 parent
e78ea91
commit 03360e1
Showing
2 changed files
with
277 additions
and
0 deletions.
There are no files selected for viewing
89 changes: 89 additions & 0 deletions
89
java/algorithm/src/main/java/org/apache/arrow/algorithm/search/VectorSearcher.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.arrow.algorithm.search; | ||
|
||
import org.apache.arrow.algorithm.sort.VectorValueComparator; | ||
import org.apache.arrow.vector.ValueVector; | ||
|
||
/** | ||
* Search for a particular element in the vector. | ||
*/ | ||
public final class VectorSearcher { | ||
|
||
/** | ||
* Search for a particular element from the key vector in the target vector by binary search. | ||
* The target vector must be sorted. | ||
* @param targetVector the vector from which to perform the sort. | ||
* @param comparator the criterion for the sort. | ||
* @param keyVector the vector containing the element to search. | ||
* @param keyIndex the index of the search key in the key vector. | ||
* @param <V> the vector type. | ||
* @return the index of a matched element if any, and -1 otherwise. | ||
*/ | ||
public static <V extends ValueVector> int binarySearch( | ||
V targetVector, VectorValueComparator<V> comparator, V keyVector, int keyIndex) { | ||
comparator.attachVectors(keyVector, targetVector); | ||
|
||
// perform binary search | ||
int low = 0; | ||
int high = targetVector.getValueCount() - 1; | ||
|
||
while (low <= high) { | ||
int mid = (high + low) / 2; | ||
|
||
if (mid < 0) { | ||
// overflow has occurred, so calculate the mid by converting to long first | ||
mid = (int) (((long) high + (long) low) / 2L); | ||
} | ||
|
||
int cmp = comparator.compare(keyIndex, mid); | ||
if (cmp < 0) { | ||
high = mid - 1; | ||
} else if (cmp > 0) { | ||
low = mid + 1; | ||
} else { | ||
return mid; | ||
} | ||
} | ||
return -1; | ||
} | ||
|
||
/** | ||
* Search for a particular element from the key vector in the target vector by traversing the vector in sequence. | ||
* @param targetVector the vector from which to perform the sort. | ||
* @param comparator the criterion for element equality. | ||
* @param keyVector the vector containing the element to search. | ||
* @param keyIndex the index of the search key in the key vector. | ||
* @param <V> the vector type. | ||
* @return the index of a matched element if any, and -1 otherwise. | ||
*/ | ||
public static <V extends ValueVector> int linearSearch( | ||
V targetVector, VectorValueComparator<V> comparator, V keyVector, int keyIndex) { | ||
comparator.attachVectors(keyVector, targetVector); | ||
for (int i = 0; i < targetVector.getValueCount(); i++) { | ||
if (comparator.compare(keyIndex, i) == 0) { | ||
return i; | ||
} | ||
} | ||
return -1; | ||
} | ||
|
||
private VectorSearcher() { | ||
|
||
} | ||
} |
188 changes: 188 additions & 0 deletions
188
java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestVectorSearcher.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.arrow.algorithm.search; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
|
||
import org.apache.arrow.algorithm.sort.DefaultVectorComparators; | ||
import org.apache.arrow.algorithm.sort.VectorValueComparator; | ||
import org.apache.arrow.memory.BufferAllocator; | ||
import org.apache.arrow.memory.RootAllocator; | ||
import org.apache.arrow.vector.IntVector; | ||
import org.apache.arrow.vector.VarCharVector; | ||
import org.junit.After; | ||
import org.junit.Before; | ||
import org.junit.Test; | ||
|
||
/** | ||
* Test cases for {@link org.apache.arrow.algorithm.search.VectorSearcher}. | ||
*/ | ||
public class TestVectorSearcher { | ||
|
||
private final int VECTOR_LENGTH = 100; | ||
|
||
private BufferAllocator allocator; | ||
|
||
@Before | ||
public void prepare() { | ||
allocator = new RootAllocator(1024 * 1024); | ||
} | ||
|
||
@After | ||
public void shutdown() { | ||
allocator.close(); | ||
} | ||
|
||
@Test | ||
public void testBinarySearchInt() { | ||
try (IntVector rawVector = new IntVector("", allocator); | ||
IntVector negVector = new IntVector("", allocator)) { | ||
rawVector.allocateNew(VECTOR_LENGTH); | ||
rawVector.setValueCount(VECTOR_LENGTH); | ||
negVector.allocateNew(1); | ||
negVector.setValueCount(1); | ||
|
||
// prepare data in sorted order | ||
for (int i = 0; i < VECTOR_LENGTH; i++) { | ||
if (i == 0) { | ||
rawVector.setNull(i); | ||
} else { | ||
rawVector.set(i, i); | ||
} | ||
} | ||
negVector.set(0, -333); | ||
|
||
// do search | ||
VectorValueComparator<IntVector> comparator = new DefaultVectorComparators.IntComparator(); | ||
for (int i = 0; i < VECTOR_LENGTH; i++) { | ||
int result = VectorSearcher.binarySearch(rawVector, comparator, rawVector, i); | ||
assertEquals(i, result); | ||
} | ||
|
||
// negative case | ||
assertEquals(-1, VectorSearcher.binarySearch(rawVector, comparator, negVector, 0)); | ||
} | ||
} | ||
|
||
@Test | ||
public void testLinearSearchInt() { | ||
try (IntVector rawVector = new IntVector("", allocator); | ||
IntVector negVector = new IntVector("", allocator)) { | ||
rawVector.allocateNew(VECTOR_LENGTH); | ||
rawVector.setValueCount(VECTOR_LENGTH); | ||
negVector.allocateNew(1); | ||
negVector.setValueCount(1); | ||
|
||
// prepare data in sorted order | ||
for (int i = 0; i < VECTOR_LENGTH; i++) { | ||
if (i == 0) { | ||
rawVector.setNull(i); | ||
} else { | ||
rawVector.set(i, i); | ||
} | ||
} | ||
negVector.set(0, -333); | ||
|
||
// do search | ||
VectorValueComparator<IntVector> comparator = new DefaultVectorComparators.IntComparator(); | ||
for (int i = 0; i < VECTOR_LENGTH; i++) { | ||
int result = VectorSearcher.linearSearch(rawVector, comparator, rawVector, i); | ||
assertEquals(i, result); | ||
} | ||
|
||
// negative case | ||
assertEquals(-1, VectorSearcher.linearSearch(rawVector, comparator, negVector, 0)); | ||
} | ||
} | ||
|
||
@Test | ||
public void testBinarySearchVarChar() { | ||
try (VarCharVector rawVector = new VarCharVector("", allocator); | ||
VarCharVector negVector = new VarCharVector("", allocator)) { | ||
rawVector.allocateNew(VECTOR_LENGTH * 16, VECTOR_LENGTH); | ||
rawVector.setValueCount(VECTOR_LENGTH); | ||
negVector.allocateNew(VECTOR_LENGTH, 1); | ||
negVector.setValueCount(1); | ||
|
||
byte[] content = new byte[2]; | ||
|
||
// prepare data in sorted order | ||
for (int i = 0; i < VECTOR_LENGTH; i++) { | ||
if (i == 0) { | ||
rawVector.setNull(i); | ||
} else { | ||
int q = i / 10; | ||
int r = i % 10; | ||
|
||
content[0] = (byte) ('a' + q); | ||
content[1] = (byte) r; | ||
rawVector.set(i, content); | ||
} | ||
} | ||
negVector.set(0, "abcd".getBytes()); | ||
|
||
// do search | ||
VectorValueComparator<VarCharVector> comparator = new DefaultVectorComparators.VarCharComparator(); | ||
for (int i = 0; i < VECTOR_LENGTH; i++) { | ||
int result = VectorSearcher.binarySearch(rawVector, comparator, rawVector, i); | ||
assertEquals(i, result); | ||
} | ||
|
||
// negative case | ||
assertEquals(-1, VectorSearcher.binarySearch(rawVector, comparator, negVector, 0)); | ||
} | ||
} | ||
|
||
@Test | ||
public void testLinearSearchVarChar() { | ||
try (VarCharVector rawVector = new VarCharVector("", allocator); | ||
VarCharVector negVector = new VarCharVector("", allocator)) { | ||
rawVector.allocateNew(VECTOR_LENGTH * 16, VECTOR_LENGTH); | ||
rawVector.setValueCount(VECTOR_LENGTH); | ||
negVector.allocateNew(VECTOR_LENGTH, 1); | ||
negVector.setValueCount(1); | ||
|
||
byte[] content = new byte[2]; | ||
|
||
// prepare data in sorted order | ||
for (int i = 0; i < VECTOR_LENGTH; i++) { | ||
if (i == 0) { | ||
rawVector.setNull(i); | ||
} else { | ||
int q = i / 10; | ||
int r = i % 10; | ||
|
||
content[0] = (byte) ('a' + q); | ||
content[1] = (byte) r; | ||
rawVector.set(i, content); | ||
} | ||
} | ||
negVector.set(0, "abcd".getBytes()); | ||
|
||
// do search | ||
VectorValueComparator<VarCharVector> comparator = new DefaultVectorComparators.VarCharComparator(); | ||
for (int i = 0; i < VECTOR_LENGTH; i++) { | ||
int result = VectorSearcher.linearSearch(rawVector, comparator, rawVector, i); | ||
assertEquals(i, result); | ||
} | ||
|
||
// negative case | ||
assertEquals(-1, VectorSearcher.linearSearch(rawVector, comparator, negVector, 0)); | ||
} | ||
} | ||
} |