Skip to content

Commit

Permalink
Update test
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Dec 18, 2024
1 parent ed2ab65 commit ea2c60f
Showing 1 changed file with 28 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import org.elasticsearch.index.Index;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor;
import org.junit.After;
import org.junit.Before;

Expand Down Expand Up @@ -51,13 +53,14 @@ public void cleanup() {
threadPool.close();
}

public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() throws IOException {
public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOException {
Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
FIELD_NAME,
new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME })
);
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY);
QueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
Expand All @@ -68,20 +71,20 @@ public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() thr
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
QueryBuilder innerQuery = nestedQueryBuilder.query();
assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
assertTrue(innerQuery instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) innerQuery;
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), knnVectorQueryBuilder.getFieldName());
assertEquals(queryVectorBuilder, knnVectorQueryBuilder.queryVectorBuilder());
}

public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
public void testKnnWithQueryBuilderWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
FIELD_NAME,
new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME })
);
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY);
QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY);
QueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
Expand All @@ -92,20 +95,24 @@ public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsIntercepted
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
QueryBuilder innerQuery = nestedQueryBuilder.query();
assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
assertTrue(innerQuery instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) innerQuery;
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), knnVectorQueryBuilder.getFieldName());
assertTrue(knnVectorQueryBuilder.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder);
TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder = (TextEmbeddingQueryVectorBuilder) knnVectorQueryBuilder
.queryVectorBuilder();
assertEquals(QUERY, textEmbeddingQueryVectorBuilder.getModelText());
assertEquals(INFERENCE_ID, textEmbeddingQueryVectorBuilder.getModelId());
}

public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
public void testKnnVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY);
QueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof SparseVectorQueryBuilder
"Expected query to remain knn but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof KnnVectorQueryBuilder
);
assertEquals(original, rewritten);
}
Expand All @@ -132,6 +139,6 @@ private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceField
}

private QueryRewriteInterceptor createRewriteInterceptor() {
return new SemanticSparseVectorQueryRewriteInterceptor();
return new SemanticKnnVectorQueryRewriteInterceptor();
}
}

0 comments on commit ea2c60f

Please sign in to comment.