From ea2c60f0302469cec4801e5d1cf99b81594e043a Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 18 Dec 2024 16:35:37 -0500 Subject: [PATCH] Update test --- ...KnnVectorQueryRewriteInterceptorTests.java | 49 +++++++++++-------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java index 454f88121a20f..46ca771b11d96 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java @@ -17,12 +17,14 @@ import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.search.vectors.QueryVectorBuilder; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; -import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor; +import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor; import org.junit.After; import org.junit.Before; @@ -51,13 +53,14 @@ public void cleanup() { threadPool.close(); } - public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() throws IOException { + public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); - QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY); + QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY); + QueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null); QueryBuilder rewritten = original.rewrite(context); assertTrue( "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]", @@ -68,20 +71,20 @@ public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() thr NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder; assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path()); QueryBuilder innerQuery = nestedQueryBuilder.query(); - assertTrue(innerQuery instanceof SparseVectorQueryBuilder); - SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery; - assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName()); - assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId()); - assertEquals(QUERY, sparseVectorQueryBuilder.getQuery()); + assertTrue(innerQuery instanceof KnnVectorQueryBuilder); + KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) innerQuery; + assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), knnVectorQueryBuilder.getFieldName()); + assertEquals(queryVectorBuilder, knnVectorQueryBuilder.queryVectorBuilder()); } - public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException { + public void testKnnWithQueryBuilderWithoutInferenceIdIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); - QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY); + QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY); + QueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null); QueryBuilder rewritten = original.rewrite(context); assertTrue( "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]", @@ -92,20 +95,24 @@ public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsIntercepted NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder; assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path()); QueryBuilder innerQuery = nestedQueryBuilder.query(); - assertTrue(innerQuery instanceof SparseVectorQueryBuilder); - SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery; - assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName()); - assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId()); - assertEquals(QUERY, sparseVectorQueryBuilder.getQuery()); + assertTrue(innerQuery instanceof KnnVectorQueryBuilder); + KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) innerQuery; + assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), knnVectorQueryBuilder.getFieldName()); + assertTrue(knnVectorQueryBuilder.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder); + TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder = (TextEmbeddingQueryVectorBuilder) knnVectorQueryBuilder + .queryVectorBuilder(); + assertEquals(QUERY, textEmbeddingQueryVectorBuilder.getModelText()); + assertEquals(INFERENCE_ID, textEmbeddingQueryVectorBuilder.getModelId()); } - public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException { + public void testKnnVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException { QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields - QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY); + QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY); + QueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null); QueryBuilder rewritten = original.rewrite(context); assertTrue( - "Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]", - rewritten instanceof SparseVectorQueryBuilder + "Expected query to remain knn but was [" + rewritten.getClass().getName() + "]", + rewritten instanceof KnnVectorQueryBuilder ); assertEquals(original, rewritten); } @@ -132,6 +139,6 @@ private QueryRewriteContext createQueryRewriteContext(Map