Skip to content

Commit

Permalink
Use a separate threadSafeFind method and change structure to allow ov…
Browse files Browse the repository at this point in the history
…erriding match method
  • Loading branch information
carlosdelest committed Apr 16, 2024
1 parent 4401d95 commit 04f9055
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,18 @@ public BytesRef get(long id, BytesRef dest) {
return bytesRefs.get(id, dest);
}

public long find(BytesRef key, int code) {
return find(key, code, spare);
}

/**
* Get the id associated with <code>key</code>
*/
public long find(BytesRef key, int code) {
private long find(BytesRef key, int code, BytesRef intermediate) {
final long slot = slot(rehash(code), mask);
for (long index = slot;; index = nextSlot(index, mask)) {
final long id = id(index);
if (id == -1L || key.bytesEquals(get(id, spare))) {
if (id == -1L || key.bytesEquals(get(id, intermediate))) {
return id;
}
}
Expand All @@ -147,6 +151,15 @@ public long find(BytesRef key) {
return find(key, key.hashCode());
}

/**
* Allows finding a key in the hash in a thread safe manner, by providing an intermediate
* BytesRef reference to storing intermediate results. As long as each thread provides
* its own intermediate instance, this method is thread safe.
*/
public long threadSafeFind(BytesRef key, BytesRef intermediate) {
return find(key, key.hashCode(), intermediate);
}

private long set(BytesRef key, int code, long id) {
assert rehash(key.hashCode()) == code;
assert size < maxSize;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ protected final boolean matches(IpFieldScript scriptContext, int docId) {
/**
* Does the value match this query?
*/
protected abstract boolean matches(BytesRef[] values, int conut);
protected abstract boolean matches(BytesRef[] values, int count);

protected static InetAddress decode(BytesRef ref) {
return InetAddressPoint.decode(BytesReference.toBytes(new BytesArray(ref)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

package org.elasticsearch.search.runtime;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.network.InetAddresses;
import org.elasticsearch.common.util.BytesRefHash;
Expand All @@ -16,26 +18,51 @@

import java.util.Objects;

public class IpScriptFieldTermsQuery extends AbstractIpScriptFieldQuery {
import static org.elasticsearch.search.runtime.AbstractIpScriptFieldQuery.decode;

public class IpScriptFieldTermsQuery extends AbstractScriptFieldQuery<IpFieldScript> {
private final BytesRefHash terms;

public IpScriptFieldTermsQuery(Script script, IpFieldScript.LeafFactory leafFactory, String fieldName, BytesRefHash terms) {
super(script, leafFactory, fieldName);
super(script, fieldName, leafFactory::newInstance);
this.terms = terms;
}

@Override
protected boolean matches(BytesRef[] values, int count) {
synchronized (terms) {
for (int i = 0; i < count; i++) {
if (terms.find(values[i]) >= 0) {
return true;
}
boolean matches(BytesRef[] values, int count, BytesRef intermediate) {
for (int i = 0; i < count; i++) {
if (terms.threadSafeFind(values[i], intermediate) >= 0) {
return true;
}
}
return false;
}

@Override
protected boolean matches(IpFieldScript scriptContext, int docId) {
throw new UnsupportedOperationException("Need to use matches(IpFieldScript, int, BytesRef) instead");
}

boolean matches(IpFieldScript scriptContext, int docId, BytesRef intermediate) {
scriptContext.runForDoc(docId);
return matches(scriptContext.values(), scriptContext.count(), intermediate);
}

protected TwoPhaseIterator createTwoPhaseIterator(IpFieldScript scriptContext, DocIdSetIterator approximation) {
return new TwoPhaseIterator(approximation) {
private final BytesRef bytesRef = new BytesRef();

@Override
public boolean matches() {
return IpScriptFieldTermsQuery.this.matches(scriptContext, approximation.docID(), bytesRef);
}

@Override
public float matchCost() {
return MATCH_COST;
}
};
}

@Override
public final String toString(String field) {
StringBuilder b = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,35 @@
import org.elasticsearch.common.network.InetAddresses;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.script.IpFieldScript;
import org.elasticsearch.script.Script;

import java.net.InetAddress;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.startsWith;
import static org.mockito.Mockito.mock;

public class IpScriptFieldTermsQueryTests extends AbstractIpScriptFieldQueryTestCase<IpScriptFieldTermsQuery> {
public class IpScriptFieldTermsQueryTests extends AbstractScriptFieldQueryTestCase<IpScriptFieldTermsQuery> {
@Override
protected IpScriptFieldTermsQuery createTestInstance() {
return createTestInstance(between(1, 100));
}

protected final IpFieldScript.LeafFactory leafFactory = mock(IpFieldScript.LeafFactory.class);

@Override
public final void testVisit() {
assertEmptyVisit();
}

protected static BytesRef encode(InetAddress addr) {
return new BytesRef(InetAddressPoint.encode(addr));
}


private IpScriptFieldTermsQuery createTestInstance(int size) {
BytesRefHash terms = new BytesRefHash(size, BigArrays.NON_RECYCLING_INSTANCE);
while (terms.size() < size) {
Expand Down Expand Up @@ -80,12 +96,13 @@ public void testMatches() {
terms.add(ip1);
terms.add(ip2);
IpScriptFieldTermsQuery query = new IpScriptFieldTermsQuery(randomScript(), leafFactory, "test", terms);
assertTrue(query.matches(new BytesRef[] { ip1 }, 1));
assertTrue(query.matches(new BytesRef[] { ip2 }, 1));
assertTrue(query.matches(new BytesRef[] { ip1, notIp }, 2));
assertTrue(query.matches(new BytesRef[] { notIp, ip1 }, 2));
assertFalse(query.matches(new BytesRef[] { notIp }, 1));
assertFalse(query.matches(new BytesRef[] { notIp, ip1 }, 1));
BytesRef intermediate = new BytesRef();
assertTrue(query.matches(new BytesRef[] { ip1 }, 1, intermediate));
assertTrue(query.matches(new BytesRef[] { ip2 }, 1, intermediate));
assertTrue(query.matches(new BytesRef[] { ip1, notIp }, 2, intermediate));
assertTrue(query.matches(new BytesRef[] { notIp, ip1 }, 2, intermediate));
assertFalse(query.matches(new BytesRef[] { notIp }, 1, intermediate));
assertFalse(query.matches(new BytesRef[] { notIp, ip1 }, 1, intermediate));
}

@Override
Expand Down

0 comments on commit 04f9055

Please sign in to comment.