From 770ae4151d3e46a9fcd8246690aa20b266f69d57 Mon Sep 17 00:00:00 2001 From: Manasvini B S Date: Wed, 18 Sep 2024 09:57:23 -0700 Subject: [PATCH] Adding test coverage Signed-off-by: Manasvini B S --- .../query/planner/physical/node/Paginate.java | 7 +- .../node/pointInTime/PointInTime.java | 4 +- .../client/OpenSearchNodeClient.java | 3 +- .../client/OpenSearchRestClient.java | 3 +- .../request/OpenSearchQueryRequest.java | 93 +++-- .../client/OpenSearchNodeClientTest.java | 69 +++- .../client/OpenSearchRestClientTest.java | 62 ++++ .../request/OpenSearchQueryRequestTest.java | 323 +++++++++++++++++- .../request/OpenSearchRequestBuilderTest.java | 292 +++++++++++++++- .../storage/scan/OpenSearchIndexScanTest.java | 50 ++- 10 files changed, 813 insertions(+), 93 deletions(-) diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/Paginate.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/Paginate.java index d9bf12d916..5bf31bb691 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/Paginate.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/Paginate.java @@ -22,22 +22,17 @@ public abstract class Paginate extends BatchPhysicalOperator { - /** Request to submit to OpenSearch to scroll over */ + /** Request to submit to OpenSearch to scan over */ protected final TableInJoinRequestBuilder request; - /** Page size to scroll over index */ protected final int pageSize; - /** Client connection to ElasticSearch */ protected Client client; - /** Currently undergoing scan */ protected SearchResponse searchResponse; - /** Time out */ protected Integer timeout; - /** Resource monitor manager */ protected ResourceManager resourceMgr; public Paginate(TableInJoinRequestBuilder request, int pageSize) { diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/pointInTime/PointInTime.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/pointInTime/PointInTime.java index d91df4b273..9ddbde2d29 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/pointInTime/PointInTime.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/pointInTime/PointInTime.java @@ -36,6 +36,8 @@ protected void loadFirstBatch() { pit = new PointInTimeHandlerImpl(client, request.getOriginalSelect().getIndexArr()); pit.create(); pitId = pit.getPitId(); + + LOG.info("Loading first batch of response using Point In Time"); searchResponse = request .getRequestBuilder() @@ -44,7 +46,6 @@ protected void loadFirstBatch() { .setTimeout(TimeValue.timeValueSeconds(timeout)) .setPointInTime(new PointInTimeBuilder(pitId)) .get(); - LOG.info("Loading first batch of response using Point In Time"); } @Override @@ -57,6 +58,7 @@ protected void loadNextBatch() { .getHits() .getHits()[searchResponse.getHits().getHits().length - 1] .getSortValues(); + LOG.info("Loading next batch of response using Point In Time. - " + pitId); searchResponse = request diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java index 071ee2fd45..7a9487ef6a 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java @@ -30,7 +30,6 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.IndexSettings; import org.opensearch.sql.opensearch.mapping.IndexMapping; -import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -179,7 +178,7 @@ public void cleanup(OpenSearchRequest request) { "Failed to clean up resources for search request " + request, e); } }); - } else if (request instanceof OpenSearchQueryRequest) { + } else { request.clean( pitId -> { DeletePitRequest deletePitRequest = new DeletePitRequest(pitId); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java index e6ccf283a5..5cb6a69918 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java @@ -31,7 +31,6 @@ import org.opensearch.cluster.metadata.AliasMetadata; import org.opensearch.common.settings.Settings; import org.opensearch.sql.opensearch.mapping.IndexMapping; -import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -190,7 +189,7 @@ public void cleanup(OpenSearchRequest request) { "Failed to clean up resources for search request " + request, e); } }); - } else if (request instanceof OpenSearchQueryRequest) { + } else { request.clean( pitId -> { DeletePitRequest deletePitRequest = new DeletePitRequest(pitId); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java index 0bc71e6d46..fff252f3b4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java @@ -25,6 +25,7 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchModule; import org.opensearch.search.builder.PointInTimeBuilder; @@ -117,6 +118,41 @@ public OpenSearchQueryRequest( this.pitId = pitId; } + /** + * Constructs OpenSearchQueryRequest from serialized representation. + * + * @param in stream to read data from. + * @param engine OpenSearchSqlEngine to get node-specific context. + * @throws IOException thrown if reading from input {@code in} fails. + */ + public OpenSearchQueryRequest(StreamInput in, OpenSearchStorageEngine engine) throws IOException { + // Deserialize the SearchSourceBuilder from the string representation + String sourceBuilderString = in.readString(); + + NamedXContentRegistry xContentRegistry = + new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + XContentParser parser = + XContentType.JSON + .xContent() + .createParser(xContentRegistry, IGNORE_DEPRECATIONS, sourceBuilderString); + this.sourceBuilder = SearchSourceBuilder.fromXContent(parser); + + cursorKeepAlive = in.readTimeValue(); + pitId = in.readString(); + includes = in.readStringList(); + indexName = new IndexName(in); + + int length = in.readVInt(); + this.searchAfter = new Object[length]; + for (int i = 0; i < length; i++) { + this.searchAfter[i] = in.readGenericValue(); + } + + OpenSearchIndex index = (OpenSearchIndex) engine.getTable(null, indexName.toString()); + exprValueFactory = new OpenSearchExprValueFactory(index.getFieldOpenSearchTypes()); + } + @Override public OpenSearchResponse search( Function searchAction, @@ -161,14 +197,9 @@ public OpenSearchResponse searchWithPIT(Function needClean = openSearchResponse.isEmpty(); searchDone = openSearchResponse.isEmpty(); - if (!needClean - && this.searchResponse.getHits().getHits() != null - && this.searchResponse.getHits().getHits().length > 0) { - searchAfter = - this.searchResponse - .getHits() - .getHits()[this.searchResponse.getHits().getHits().length - 1] - .getSortValues(); + SearchHit[] searchHits = this.searchResponse.getHits().getHits(); + if (searchHits != null && searchHits.length > 0) { + searchAfter = searchHits[searchHits.length - 1].getSortValues(); this.sourceBuilder.searchAfter(searchAfter); } } @@ -179,10 +210,9 @@ public OpenSearchResponse searchWithPIT(Function public void clean(Consumer cleanAction) { try { // clean on the last page only, to prevent deleting the PitId in the middle of paging. - if (needClean && this.pitId != null) { + if (this.pitId != null && needClean) { cleanAction.accept(this.pitId); searchDone = true; - this.pitId = null; } } finally { this.pitId = null; @@ -209,9 +239,11 @@ public void writeTo(StreamOutput out) throws IOException { indexName.writeTo(out); // Serialize the searchAfter array - out.writeVInt(searchAfter.length); - for (Object obj : searchAfter) { - out.writeGenericValue(obj); + if (searchAfter != null) { + out.writeVInt(searchAfter.length); + for (Object obj : searchAfter) { + out.writeGenericValue(obj); + } } } else { // OpenSearch Query request without PIT for single page requests @@ -219,39 +251,4 @@ public void writeTo(StreamOutput out) throws IOException { "OpenSearchQueryRequest serialization is not implemented."); } } - - /** - * Constructs OpenSearchQueryRequest from serialized representation. - * - * @param in stream to read data from. - * @param engine OpenSearchSqlEngine to get node-specific context. - * @throws IOException thrown if reading from input {@code in} fails. - */ - public OpenSearchQueryRequest(StreamInput in, OpenSearchStorageEngine engine) throws IOException { - // Deserialize the SearchSourceBuilder from the string representation - String sourceBuilderString = in.readString(); - - NamedXContentRegistry xContentRegistry = - new NamedXContentRegistry( - new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); - XContentParser parser = - XContentType.JSON - .xContent() - .createParser(xContentRegistry, IGNORE_DEPRECATIONS, sourceBuilderString); - this.sourceBuilder = SearchSourceBuilder.fromXContent(parser); - - cursorKeepAlive = in.readTimeValue(); - pitId = in.readString(); - includes = in.readStringList(); - indexName = new IndexName(in); - - int length = in.readVInt(); - this.searchAfter = new Object[length]; - for (int i = 0; i < length; i++) { - this.searchAfter[i] = in.readGenericValue(); - } - - OpenSearchIndex index = (OpenSearchIndex) engine.getTable(null, indexName.toString()); - exprValueFactory = new OpenSearchExprValueFactory(index.getFieldOpenSearchTypes()); - } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java index 9da6e05e92..ba0fb85422 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java @@ -13,8 +13,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Answers.RETURNS_DEEP_STUBS; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -31,6 +30,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; import lombok.SneakyThrows; import org.apache.commons.lang3.reflect.FieldUtils; @@ -51,15 +51,16 @@ import org.opensearch.action.admin.indices.get.GetIndexResponse; import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; -import org.opensearch.action.search.ClearScrollRequestBuilder; -import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.*; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.AliasMetadata; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.common.action.ActionFuture; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; @@ -74,6 +75,7 @@ import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.mapping.IndexMapping; +import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -393,6 +395,65 @@ void cleanup_rethrows_exception() { assertThrows(IllegalStateException.class, () -> client.cleanup(request)); } + @Test + @SneakyThrows + void cleanup_pit_request() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder(), + factory, + List.of(), + TimeValue.timeValueMinutes(1L), + "samplePitId"); + // Enforce cleaning by setting a private field. + FieldUtils.writeField(request, "needClean", true, true); + client.cleanup(request); + verify(nodeClient).execute(any(), any()); + } + + @Test + @SneakyThrows + void cleanup_pit_request_throw_exception() { + DeletePitRequest deletePitRequest = new DeletePitRequest("samplePitId"); + ActionFuture actionFuture = mock(ActionFuture.class); + when(actionFuture.get()).thenThrow(new ExecutionException("Execution failed", new Throwable())); + when(nodeClient.execute(eq(DeletePitAction.INSTANCE), any(DeletePitRequest.class))) + .thenReturn(actionFuture); + assertThrows(RuntimeException.class, () -> client.deletePit(deletePitRequest)); + } + + @Test + @SneakyThrows + void create_pit() { + CreatePitRequest createPitRequest = + new CreatePitRequest(TimeValue.timeValueMinutes(5), false, Strings.EMPTY_ARRAY); + ActionFuture actionFuture = mock(ActionFuture.class); + CreatePitResponse createPitResponse = mock(CreatePitResponse.class); + when(createPitResponse.getId()).thenReturn("samplePitId"); + when(actionFuture.get()).thenReturn(createPitResponse); + when(nodeClient.execute(eq(CreatePitAction.INSTANCE), any(CreatePitRequest.class))) + .thenReturn(actionFuture); + + String pitId = client.createPit(createPitRequest); + assertEquals("samplePitId", pitId); + + verify(nodeClient).execute(CreatePitAction.INSTANCE, createPitRequest); + verify(actionFuture).get(); + } + + @Test + @SneakyThrows + void create_pit_request_throw_exception() { + CreatePitRequest createPitRequest = + new CreatePitRequest(TimeValue.timeValueMinutes(5), false, Strings.EMPTY_ARRAY); + ActionFuture actionFuture = mock(ActionFuture.class); + when(actionFuture.get()).thenThrow(new ExecutionException("Execution failed", new Throwable())); + when(nodeClient.execute(eq(CreatePitAction.INSTANCE), any(CreatePitRequest.class))) + .thenReturn(actionFuture); + assertThrows(RuntimeException.class, () -> client.createPit(createPitRequest)); + } + @Test void get_indices() { AliasMetadata aliasMetadata = mock(AliasMetadata.class); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java index b83313de07..82d9e74422 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java @@ -43,6 +43,8 @@ import org.opensearch.action.admin.cluster.settings.ClusterGetSettingsResponse; import org.opensearch.action.admin.indices.settings.get.GetSettingsRequest; import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; +import org.opensearch.action.search.CreatePitRequest; +import org.opensearch.action.search.CreatePitResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.RequestOptions; import org.opensearch.client.RestHighLevelClient; @@ -56,6 +58,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; @@ -69,6 +72,7 @@ import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.mapping.IndexMapping; +import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -411,6 +415,64 @@ void cleanup_with_IOException() { assertThrows(IllegalStateException.class, () -> client.cleanup(request)); } + @Test + @SneakyThrows + void cleanup_pit_request() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder(), + factory, + List.of(), + TimeValue.timeValueMinutes(1L), + "samplePitId"); + // Enforce cleaning by setting a private field. + FieldUtils.writeField(request, "needClean", true, true); + client.cleanup(request); + verify(restClient).deletePit(any(), any()); + } + + @Test + @SneakyThrows + void cleanup_pit_request_throw_exception() { + when(restClient.deletePit(any(), any())).thenThrow(new IOException()); + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder(), + factory, + List.of(), + TimeValue.timeValueMinutes(1L), + "samplePitId"); + // Enforce cleaning by setting a private field. + FieldUtils.writeField(request, "needClean", true, true); + assertThrows(RuntimeException.class, () -> client.cleanup(request)); + } + + @Test + @SneakyThrows + void create_pit() { + CreatePitRequest createPitRequest = + new CreatePitRequest(TimeValue.timeValueMinutes(5), false, Strings.EMPTY_ARRAY); + CreatePitResponse createPitResponse = mock(CreatePitResponse.class); + when(createPitResponse.getId()).thenReturn("samplePitId"); + when(restClient.createPit(any(CreatePitRequest.class), any())).thenReturn(createPitResponse); + + String pitId = client.createPit(createPitRequest); + assertEquals("samplePitId", pitId); + + verify(restClient).createPit(createPitRequest, RequestOptions.DEFAULT); + } + + @Test + @SneakyThrows + void create_pit_request_throw_exception() { + CreatePitRequest createPitRequest = + new CreatePitRequest(TimeValue.timeValueMinutes(5), false, Strings.EMPTY_ARRAY); + when(restClient.createPit(any(), any())).thenThrow(new IOException()); + assertThrows(RuntimeException.class, () -> client.createPit(createPitRequest)); + } + @Test void get_indices() throws IOException { when(restClient.indices().get(any(GetIndexRequest.class), any(RequestOptions.class))) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java index d2bc5b0641..89b51207b5 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java @@ -5,37 +5,40 @@ package org.opensearch.sql.opensearch.request; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.opensearch.sql.opensearch.request.OpenSearchRequest.DEFAULT_QUERY_TIMEOUT; +import java.io.IOException; +import java.lang.reflect.Field; import java.util.List; import java.util.function.Consumer; import java.util.function.Function; +import lombok.SneakyThrows; import org.apache.lucene.search.TotalHits; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.PointInTimeBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; @ExtendWith(MockitoExtension.class) public class OpenSearchQueryRequestTest { @@ -64,6 +67,93 @@ public class OpenSearchQueryRequestTest { private final OpenSearchQueryRequest remoteRequest = new OpenSearchQueryRequest("ccs:test", 200, factory, List.of()); + @Mock private StreamOutput streamOutput; + @Mock private StreamInput streamInput; + @Mock private OpenSearchStorageEngine engine; + @Mock private PointInTimeBuilder pointInTimeBuilder; + + @InjectMocks private OpenSearchQueryRequest serializationRequest; + + private SearchSourceBuilder sourceBuilderForSerializer; + + @BeforeEach + void setup() { + sourceBuilderForSerializer = new SearchSourceBuilder(); + sourceBuilderForSerializer.pointInTimeBuilder(pointInTimeBuilder); + sourceBuilderForSerializer.timeout(TimeValue.timeValueSeconds(30)); + } + + @SneakyThrows + @Test + void testWriteTo() throws IOException { + when(pointInTimeBuilder.getId()).thenReturn("samplePITId"); + sourceBuilderForSerializer.searchAfter(new Object[] {"value1", 123}); + List includes = List.of("field1", "field2"); + serializationRequest = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilderForSerializer, + factory, + includes, + new TimeValue(1000), + "samplePITId"); + + Field searchAfterField = OpenSearchQueryRequest.class.getDeclaredField("searchAfter"); + searchAfterField.setAccessible(true); + searchAfterField.set(serializationRequest, new Object[] {"value1", 123}); + + serializationRequest.writeTo(streamOutput); + + String expectedJson = "{\"timeout\":\"30s\",\"search_after\":[\"value1\",123]}"; + verify(streamOutput).writeString(expectedJson); + verify(streamOutput).writeTimeValue(TimeValue.timeValueSeconds(30)); + verify(streamOutput).writeString("samplePITId"); + verify(streamOutput).writeStringCollection(includes); + + verify(streamOutput).writeVInt(2); + verify(streamOutput).writeGenericValue("value1"); + verify(streamOutput).writeGenericValue(123); + } + + @Test + void testWriteToWithoutSearchAfter() + throws IOException, NoSuchFieldException, IllegalAccessException { + when(pointInTimeBuilder.getId()).thenReturn("samplePITId"); + + List includes = List.of("field1", "field2"); + serializationRequest = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilderForSerializer, + factory, + includes, + new TimeValue(1000), + "samplePITId"); + + serializationRequest.writeTo(streamOutput); + verify(streamOutput).writeString("{\"timeout\":\"30s\"}"); + verify(streamOutput).writeTimeValue(TimeValue.timeValueSeconds(30)); + verify(streamOutput).writeString("samplePITId"); + verify(streamOutput).writeStringCollection(includes); + verify(streamOutput, never()).writeVInt(anyInt()); + verify(streamOutput, never()).writeGenericValue(any()); + } + + @Test + void testWriteToWithoutPIT() { + serializationRequest = new OpenSearchQueryRequest("test", 200, factory, List.of()); + + UnsupportedOperationException exception = + assertThrows( + UnsupportedOperationException.class, + () -> { + request.writeTo(streamOutput); + }); + + assertEquals( + "OpenSearchQueryRequest serialization is not implemented.", exception.getMessage()); + } + @Test void search() { OpenSearchQueryRequest request = @@ -81,6 +171,145 @@ void search() { verify(searchAction, times(1)).apply(any()); } + @Test + void search_with_pit() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "samplePid"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + when(searchHit.getSortValues()).thenReturn(new String[] {"sortedValue"}); + when(sourceBuilder.sorts()).thenReturn(null); + + OpenSearchResponse openSearchResponse = request.searchWithPIT(searchAction); + assertFalse(openSearchResponse.isEmpty()); + verify(searchAction, times(1)).apply(any()); + + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchResponse.getAggregations()).thenReturn(null); + when(searchHits.getHits()).thenReturn(null); + openSearchResponse = request.searchWithPIT(searchAction); + assertTrue(openSearchResponse.isEmpty()); + verify(searchAction, times(2)).apply(any()); + + openSearchResponse = request.searchWithPIT(searchAction); + assertTrue(openSearchResponse.isEmpty()); + } + + @Test + void search_with_pit_hits_null() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "samplePid"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + when(sourceBuilder.sorts()).thenReturn(null); + + OpenSearchResponse openSearchResponse = request.searchWithPIT(searchAction); + assertFalse(openSearchResponse.isEmpty()); + } + + @Test + void search_with_pit_hits_empty() { + SearchResponse searchResponse = mock(SearchResponse.class); + SearchHits searchHits = mock(SearchHits.class); + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "samplePid"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {}); + when(sourceBuilder.sorts()).thenReturn(null); + + OpenSearchResponse openSearchResponse = request.searchWithPIT(searchAction); + assertTrue(openSearchResponse.isEmpty()); + } + + @Test + void search_with_pit_null() { + SearchResponse searchResponse = mock(SearchResponse.class); + SearchHits searchHits = mock(SearchHits.class); + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "sample"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + + OpenSearchResponse openSearchResponse = request.search(searchAction, scrollAction); + assertFalse(openSearchResponse.isEmpty()); + } + + @Test + void has_another_batch() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "sample"); + assertFalse(request.hasAnotherBatch()); + } + + @Test + void has_another_batch_pid_null() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + null); + assertFalse(request.hasAnotherBatch()); + } + + @Test + void has_another_batch_need_clean() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "samplePid"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + OpenSearchResponse openSearchResponse = request.searchWithPIT(searchAction); + assertTrue(request.hasAnotherBatch()); + } + @Test void search_withoutContext() { OpenSearchQueryRequest request = @@ -121,6 +350,68 @@ void clean() { verify(cleanAction, never()).accept(any()); } + @Test + void testCleanConditionTrue() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "samplePid"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(null); + OpenSearchResponse openSearchResponse = request.searchWithPIT(searchAction); + + request.clean(cleanAction); + + verify(cleanAction, times(1)).accept("samplePid"); + assertTrue(request.isSearchDone()); + assertNull(request.getPitId()); + } + + @Test + void testCleanConditionFalse_needCleanFalse() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "samplePid"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + OpenSearchResponse openSearchResponse = request.searchWithPIT(searchAction); + + request.clean(cleanAction); + verify(cleanAction, never()).accept(anyString()); + assertFalse(request.isSearchDone()); + assertNull(request.getPitId()); + } + + @Test + void testCleanConditionFalse_pidNull() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + null); + + request.clean(cleanAction); + verify(cleanAction, never()).accept(anyString()); + assertFalse(request.isSearchDone()); + assertNull(request.getPitId()); + } + @Test void searchRequest() { request.getSourceBuilder().query(QueryBuilders.termQuery("name", "John")); @@ -159,6 +450,20 @@ void writeTo_unsupported() { UnsupportedOperationException.class, () -> request.writeTo(mock(StreamOutput.class))); } + @Test + void constructor_serialized() throws IOException { + StreamInput stream = mock(StreamInput.class); + OpenSearchStorageEngine engine = mock(OpenSearchStorageEngine.class); + when(stream.readString()).thenReturn("{}"); + when(stream.readStringArray()).thenReturn(new String[] {"sample"}); + OpenSearchIndex index = mock(OpenSearchIndex.class); + when(engine.getTable(null, "sample")).thenReturn(index); + when(stream.readVInt()).thenReturn(2); + when(stream.readGenericValue()).thenReturn("sampleSearchAfter"); + OpenSearchQueryRequest request = new OpenSearchQueryRequest(stream, engine); + assertNotNull(request); + } + private void assertSearchRequest(SearchRequest expected, OpenSearchQueryRequest request) { Function querySearch = searchRequest -> { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java index fa0248782d..bf87840b60 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java @@ -5,11 +5,10 @@ package org.opensearch.sql.opensearch.request; -import static org.junit.Assert.assertThrows; +import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.*; -import static org.opensearch.index.query.QueryBuilders.matchAllQuery; -import static org.opensearch.index.query.QueryBuilders.nestedQuery; +import static org.opensearch.index.query.QueryBuilders.*; import static org.opensearch.search.sort.FieldSortBuilder.DOC_FIELD_NAME; import static org.opensearch.search.sort.SortOrder.ASC; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; @@ -23,21 +22,16 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.join.ScoreMode; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.DisplayNameGeneration; -import org.junit.jupiter.api.DisplayNameGenerator; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.search.CreatePitRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; import org.opensearch.common.unit.TimeValue; -import org.opensearch.index.query.InnerHitBuilder; -import org.opensearch.index.query.NestedQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.*; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.AggregationBuilder; @@ -45,9 +39,13 @@ import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.ScoreSortBuilder; import org.opensearch.search.sort.SortBuilders; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; @@ -110,12 +108,188 @@ void build_query_request() { requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); } + @Test + void build_query_request_push_down_size() { + Integer limit = 200; + Integer offset = 0; + requestBuilder.pushDownLimit(limit, offset); + requestBuilder.pushDownTrackedScore(true); + + assertNotNull( + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void build_PIT_request_with_correct_size() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(true); + when(client.createPit(any(CreatePitRequest.class))).thenReturn("samplePITId"); + Integer limit = 0; + Integer offset = 0; + requestBuilder.pushDownLimit(limit, offset); + requestBuilder.pushDownPageSize(2); + + assertEquals( + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder().from(offset).size(2).timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of(), + TimeValue.timeValueMinutes(1), + "samplePITId"), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithPit_pageSizeNull_sizeGreaterThanMaxResultWindow() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(true); + when(client.createPit(any(CreatePitRequest.class))).thenReturn("samplePITId"); + Integer limit = 600; + Integer offset = 0; + int requestedTotalSize = 600; + requestBuilder = new OpenSearchRequestBuilder(requestedTotalSize, exprValueFactory, settings); + requestBuilder.pushDownLimit(limit, offset); + + assertEquals( + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder() + .from(offset) + .size(MAX_RESULT_WINDOW - offset) + .timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of(), + TimeValue.timeValueMinutes(1), + "samplePITId"), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithPit_pageSizeNull_sizeLessThanMaxResultWindow() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(true); + Integer limit = 400; + Integer offset = 0; + int requestedTotalSize = 400; + requestBuilder = new OpenSearchRequestBuilder(requestedTotalSize, exprValueFactory, settings); + requestBuilder.pushDownLimit(limit, offset); + + assertEquals( + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder() + .from(offset) + .size(requestedTotalSize) + .timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of()), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithPit_pageSizeNotNull_startFromZero() { + int pageSize = 200; + int offset = 0; + int limit = 400; + requestBuilder.pushDownPageSize(pageSize); + requestBuilder.pushDownLimit(limit, offset); + when(client.createPit(any(CreatePitRequest.class))).thenReturn("samplePITId"); + + assertEquals( + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder().from(offset).size(pageSize).timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of(), + TimeValue.timeValueMinutes(1), + "samplePITId"), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithPit_pageSizeNotNull_startFromNonZero() { + int pageSize = 200; + int offset = 100; + int limit = 400; + requestBuilder.pushDownPageSize(pageSize); + requestBuilder.pushDownLimit(limit, offset); + assertThrows( + UnsupportedOperationException.class, + () -> { + requestBuilder.build(indexName, 500, TimeValue.timeValueMinutes(1), client); + }); + } + @Test void build_scroll_request_with_correct_size() { when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(false); Integer limit = 800; Integer offset = 10; requestBuilder.pushDownLimit(limit, offset); + requestBuilder.getSourceBuilder().fetchSource("a", "b"); + + assertEquals( + new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), + TimeValue.timeValueMinutes(1), + new SearchSourceBuilder() + .from(offset) + .size(MAX_RESULT_WINDOW - offset) + .timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of()), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithScroll_pageSizeNull_sizeGreaterThanMaxResultWindow() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(false); + Integer limit = 600; + Integer offset = 0; + int requestedTotalSize = 600; + requestBuilder = new OpenSearchRequestBuilder(requestedTotalSize, exprValueFactory, settings); + requestBuilder.pushDownLimit(limit, offset); + + assertEquals( + new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), + TimeValue.timeValueMinutes(1), + new SearchSourceBuilder() + .from(offset) + .size(MAX_RESULT_WINDOW - offset) + .timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of()), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithScroll_pageSizeNull_sizeLessThanMaxResultWindow() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(false); + Integer limit = 400; + Integer offset = 0; + int requestedTotalSize = 400; + requestBuilder = new OpenSearchRequestBuilder(requestedTotalSize, exprValueFactory, settings); + requestBuilder.pushDownLimit(limit, offset); + + assertEquals( + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder() + .from(offset) + .size(requestedTotalSize) + .timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of()), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithScroll_pageSizeNotNull_startFromZero() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(false); + int pageSize = 200; + int offset = 0; + int limit = 400; + requestBuilder.pushDownPageSize(pageSize); + requestBuilder.pushDownLimit(limit, offset); assertEquals( new OpenSearchScrollRequest( @@ -130,6 +304,21 @@ void build_scroll_request_with_correct_size() { requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); } + @Test + void buildRequestWithScroll_pageSizeNotNull_startFromNonZero() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(false); + int pageSize = 200; + int offset = 100; + int limit = 400; + requestBuilder.pushDownPageSize(pageSize); + requestBuilder.pushDownLimit(limit, offset); + assertThrows( + UnsupportedOperationException.class, + () -> { + requestBuilder.build(indexName, 500, TimeValue.timeValueMinutes(1), client); + }); + } + @Test void test_push_down_query() { QueryBuilder query = QueryBuilders.termQuery("intA", 1); @@ -211,6 +400,51 @@ void test_push_down_query_and_sort() { requestBuilder); } + @Test + void test_push_down_query_not_null() { + SearchSourceBuilder sourceBuilder = requestBuilder.getSourceBuilder(); + sourceBuilder.query(QueryBuilders.termQuery("name", "John")); + sourceBuilder.sort(DOC_FIELD_NAME, SortOrder.ASC); + + QueryBuilder query = QueryBuilders.termQuery("intA", 1); + requestBuilder.pushDownFilter(query); + + BoolQueryBuilder expectedQuery = + QueryBuilders.boolQuery().filter(QueryBuilders.termQuery("name", "John")).filter(query); + + SearchSourceBuilder expectedSourceBuilder = + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .query(expectedQuery) + .sort(DOC_FIELD_NAME, SortOrder.ASC); + + assertSearchSourceBuilder(expectedSourceBuilder, requestBuilder); + } + + @Test + void test_push_down_query_with_bool_filter() { + BoolQueryBuilder initialBoolQuery = + QueryBuilders.boolQuery().filter(QueryBuilders.termQuery("name", "John")); + + SearchSourceBuilder sourceBuilder = requestBuilder.getSourceBuilder(); + sourceBuilder.query(initialBoolQuery); + + QueryBuilder newQuery = QueryBuilders.termQuery("intA", 1); + requestBuilder.pushDownFilter(newQuery); + initialBoolQuery.filter(newQuery); + SearchSourceBuilder expectedSourceBuilder = + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .query(initialBoolQuery) + .sort(DOC_FIELD_NAME, SortOrder.ASC); + + assertSearchSourceBuilder(expectedSourceBuilder, requestBuilder); + } + void assertSearchSourceBuilder( SearchSourceBuilder expected, OpenSearchRequestBuilder requestBuilder) throws UnsupportedOperationException { @@ -385,7 +619,7 @@ void test_push_down_nested() { assertSearchSourceBuilder( new SearchSourceBuilder() - .query(QueryBuilders.boolQuery().filter(QueryBuilders.boolQuery().must(nestedQuery))) + .query(boolQuery().filter(boolQuery().must(nestedQuery))) .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT), @@ -419,7 +653,7 @@ void test_push_down_multiple_nested_with_same_path() { true, new String[] {"message.info", "message.from"}, null))); assertSearchSourceBuilder( new SearchSourceBuilder() - .query(QueryBuilders.boolQuery().filter(QueryBuilders.boolQuery().must(nestedQuery))) + .query(boolQuery().filter(boolQuery().must(nestedQuery))) .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT), @@ -452,9 +686,9 @@ void test_push_down_nested_with_filter() { assertSearchSourceBuilder( new SearchSourceBuilder() .query( - QueryBuilders.boolQuery() + boolQuery() .filter( - QueryBuilders.boolQuery() + boolQuery() .must(QueryBuilders.rangeQuery("myNum").gt(3)) .must(nestedQuery))) .from(DEFAULT_OFFSET) @@ -491,7 +725,7 @@ void testPushDownNestedWithNestedFilter() { assertSearchSourceBuilder( new SearchSourceBuilder() - .query(QueryBuilders.boolQuery().filter(QueryBuilders.boolQuery().must(filterQuery))) + .query(boolQuery().filter(boolQuery().must(filterQuery))) .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT), @@ -515,6 +749,32 @@ void push_down_highlight_with_repeating_fields() { assertEquals("Duplicate field name in highlight", exception.getMessage()); } + @Test + void test_push_down_highlight_with_pre_tags() { + requestBuilder.pushDownHighlight( + "name", Map.of("pre_tags", new Literal("pre1", DataType.STRING))); + + SearchSourceBuilder sourceBuilder = requestBuilder.getSourceBuilder(); + assertNotNull(sourceBuilder.highlighter()); + assertEquals(1, sourceBuilder.highlighter().fields().size()); + HighlightBuilder.Field field = sourceBuilder.highlighter().fields().get(0); + assertEquals("name", field.name()); + assertEquals("pre1", field.preTags()[0]); + } + + @Test + void test_push_down_highlight_with_post_tags() { + requestBuilder.pushDownHighlight( + "name", Map.of("post_tags", new Literal("post1", DataType.STRING))); + + SearchSourceBuilder sourceBuilder = requestBuilder.getSourceBuilder(); + assertNotNull(sourceBuilder.highlighter()); + assertEquals(1, sourceBuilder.highlighter().fields().size()); + HighlightBuilder.Field field = sourceBuilder.highlighter().fields().get(0); + assertEquals("name", field.name()); + assertEquals("post1", field.postTags()[0]); + } + @Test void push_down_page_size() { requestBuilder.pushDownPageSize(3); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java index bd22e083ad..e680c6b3a6 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.opensearch.storage.scan; +import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -16,11 +17,7 @@ import static org.opensearch.search.sort.SortOrder.ASC; import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; +import java.io.*; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -148,6 +145,49 @@ void serialize(Integer numberOfIncludes) { } } + @SneakyThrows + @ParameterizedTest + @ValueSource(ints = {0, 150}) + void serialize_PIT(Integer numberOfIncludes) { + var searchSourceBuilder = new SearchSourceBuilder().size(4); + + var factory = mock(OpenSearchExprValueFactory.class); + var engine = mock(OpenSearchStorageEngine.class); + var index = mock(OpenSearchIndex.class); + when(engine.getClient()).thenReturn(client); + when(engine.getTable(any(), any())).thenReturn(index); + Map map = mock(Map.class); + when(map.get(any(String.class))).thenReturn("true"); + when(client.meta()).thenReturn(map); + var includes = + Stream.iterate(1, i -> i + 1) + .limit(numberOfIncludes) + .map(i -> "column" + i) + .collect(Collectors.toList()); + var request = + new OpenSearchQueryRequest( + INDEX_NAME, searchSourceBuilder, factory, includes, CURSOR_KEEP_ALIVE, "samplePitId"); + // make a response, so OpenSearchResponse::isEmpty would return true and unset needClean + var response = mock(SearchResponse.class); + when(response.getAggregations()).thenReturn(mock()); + var hits = mock(SearchHits.class); + when(response.getHits()).thenReturn(hits); + SearchHit hit = mock(SearchHit.class); + when(hit.getSortValues()).thenReturn(new String[] {"sample1"}); + when(hits.getHits()).thenReturn(new SearchHit[] {hit}); + request.search((req) -> response, null); + + try (var indexScan = new OpenSearchIndexScan(client, QUERY_SIZE, request)) { + var planSerializer = new PlanSerializer(engine); + var cursor = planSerializer.convertToCursor(indexScan); + var newPlan = planSerializer.convertToPlan(cursor.toString()); + assertNotNull(newPlan); + + verify(client).meta(); + verify(map).get(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER.getKeyValue()); + } + } + @SneakyThrows @Test void throws_io_exception_if_too_short() {