From 769eb99b0330400660f32db127ef64c5ae7f1e4a Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 12:37:38 -0700 Subject: [PATCH] [Backport 2.18] [Manual Backport] Backport #2981 to 2.x (#3111) * Adding PIT for pagination queries in new SQL engine code paths * Fix legacy code using scroll API instead of PIT for batch physical operator * Fix local debugger issue * Refactor integ-tests data for PIT and fix unit tests * Address feedback comments * Adding test coverage --------- (cherry picked from commit f6ca54cce5d8a4e8284c9822798077e72c84a66f) Signed-off-by: Manasvini B S Signed-off-by: Simeon Widdis Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] Co-authored-by: Manasvini B Suryanarayana --- .../org/opensearch/sql/legacy/CursorIT.java | 2 +- .../org/opensearch/sql/ppl/StandaloneIT.java | 7 +- .../sql/sql/PaginationFilterIT.java | 35 +- .../sql/sql/StandalonePaginationIT.java | 1 + .../ppl/explain_filter_agg_push.json | 2 +- .../ppl/explain_filter_push.json | 2 +- .../ppl/explain_limit_push.json | 2 +- .../expectedOutput/ppl/explain_output.json | 2 +- .../expectedOutput/ppl/explain_sort_push.json | 2 +- .../executor/format/SelectResultSet.java | 2 +- .../query/planner/logical/node/TableScan.java | 7 + .../query/planner/physical/node/Paginate.java | 144 ++++++++ .../node/{scroll => }/SearchHitRow.java | 6 +- .../node/pointInTime/PointInTime.java | 73 ++++ .../planner/physical/node/scroll/Scroll.java | 157 +-------- .../node/scroll/SearchHitRowTest.java | 1 + .../unittest/planner/QueryPlannerTest.java | 2 + .../opensearch/client/OpenSearchClient.java | 17 + .../client/OpenSearchNodeClient.java | 59 +++- .../client/OpenSearchRestClient.java | 56 ++- .../request/OpenSearchQueryRequest.java | 170 ++++++++- .../request/OpenSearchRequestBuilder.java | 71 +++- .../opensearch/storage/OpenSearchIndex.java | 4 +- .../storage/scan/OpenSearchIndexScan.java | 16 +- .../client/OpenSearchNodeClientTest.java | 69 +++- .../client/OpenSearchRestClientTest.java | 62 ++++ .../OpenSearchExecutionEngineTest.java | 6 +- .../OpenSearchExecutionProtectorTest.java | 11 +- .../request/OpenSearchQueryRequestTest.java | 323 ++++++++++++++++- .../request/OpenSearchRequestBuilderTest.java | 324 ++++++++++++++++-- .../storage/OpenSearchIndexTest.java | 18 +- .../OpenSearchIndexScanPaginationTest.java | 11 +- .../storage/scan/OpenSearchIndexScanTest.java | 108 ++++-- 33 files changed, 1470 insertions(+), 302 deletions(-) create mode 100644 legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/Paginate.java rename legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/{scroll => }/SearchHitRow.java (97%) create mode 100644 legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/pointInTime/PointInTime.java diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/CursorIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/CursorIT.java index 4ee99198df..90f72cec40 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/CursorIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/CursorIT.java @@ -182,7 +182,7 @@ public void validTotalResultWithAndWithoutPaginationOrderBy() throws IOException String selectQuery = StringUtils.format( "SELECT firstname, state FROM %s ORDER BY balance DESC ", TEST_INDEX_ACCOUNT); - verifyWithAndWithoutPaginationResponse(selectQuery + " LIMIT 2000", selectQuery, 26, false); + verifyWithAndWithoutPaginationResponse(selectQuery + " LIMIT 2000", selectQuery, 25, false); } @Test diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java index f81e1b6615..66f85b0754 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java @@ -149,8 +149,11 @@ public void onFailure(Exception e) { private Settings defaultSettings() { return new Settings() { - private final Map defaultSettings = - new ImmutableMap.Builder().put(Key.QUERY_SIZE_LIMIT, 200).build(); + private final Map defaultSettings = + new ImmutableMap.Builder() + .put(Key.QUERY_SIZE_LIMIT, 200) + .put(Key.SQL_PAGINATION_API_SEARCH_AFTER, true) + .build(); @Override public T getSettingValue(Key key) { diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationFilterIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationFilterIT.java index 038596cf57..9a945ec86f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationFilterIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationFilterIT.java @@ -34,25 +34,30 @@ public class PaginationFilterIT extends SQLIntegTestCase { */ private static final Map STATEMENT_TO_NUM_OF_PAGES = Map.of( - "SELECT * FROM " + TestsConstants.TEST_INDEX_ACCOUNT, 1000, + "SELECT * FROM " + TestsConstants.TEST_INDEX_ACCOUNT, + 1000, "SELECT * FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " WHERE match(address, 'street')", - 385, + 385, "SELECT * FROM " - + TestsConstants.TEST_INDEX_ACCOUNT - + " WHERE match(address, 'street') AND match(city, 'Ola')", - 1, + + TestsConstants.TEST_INDEX_ACCOUNT + + " WHERE match(address, 'street') AND match(city, 'Ola')", + 1, "SELECT firstname, lastname, highlight(address) FROM " - + TestsConstants.TEST_INDEX_ACCOUNT - + " WHERE match(address, 'street') AND match(state, 'OH')", - 5, + + TestsConstants.TEST_INDEX_ACCOUNT + + " WHERE match(address, 'street') AND match(state, 'OH')", + 5, "SELECT firstname, lastname, highlight('*') FROM " - + TestsConstants.TEST_INDEX_ACCOUNT - + " WHERE match(address, 'street') AND match(state, 'OH')", - 5, - "SELECT * FROM " + TestsConstants.TEST_INDEX_BEER + " WHERE true", 60, - "SELECT * FROM " + TestsConstants.TEST_INDEX_BEER + " WHERE Id=10", 1, - "SELECT * FROM " + TestsConstants.TEST_INDEX_BEER + " WHERE Id + 5=15", 1, - "SELECT * FROM " + TestsConstants.TEST_INDEX_BANK, 7); + + TestsConstants.TEST_INDEX_ACCOUNT + + " WHERE match(address, 'street') AND match(state, 'OH')", + 5, + "SELECT * FROM " + TestsConstants.TEST_INDEX_BEER + " WHERE true", + 60, + "SELECT * FROM " + TestsConstants.TEST_INDEX_BEER + " WHERE Id=10", + 1, + "SELECT * FROM " + TestsConstants.TEST_INDEX_BEER + " WHERE Id + 5=15", + 1, + "SELECT * FROM " + TestsConstants.TEST_INDEX_BANK, + 7); private final String sqlStatement; diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java index e884734c96..698e185abb 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java @@ -166,6 +166,7 @@ private Settings defaultSettings() { new ImmutableMap.Builder() .put(Key.QUERY_SIZE_LIMIT, 200) .put(Key.SQL_CURSOR_KEEP_ALIVE, TimeValue.timeValueMinutes(1)) + .put(Key.SQL_PAGINATION_API_SEARCH_AFTER, true) .build(); @Override diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_agg_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_agg_push.json index 568b397f07..8035822357 100644 --- a/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_agg_push.json +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_agg_push.json @@ -8,7 +8,7 @@ { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone\u003dfalse)" + "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, needClean\u003dtrue, searchDone\u003dfalse, pitId\u003dnull, cursorKeepAlive\u003dnull, searchAfter\u003dnull, searchResponse\u003dnull)" }, "children": [] } diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_push.json index 0e7087aa1f..3e92a17b97 100644 --- a/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_push.json +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_push.json @@ -8,7 +8,7 @@ { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"bool\":{\"filter\":[{\"bool\":{\"filter\":[{\"range\":{\"balance\":{\"from\":10000,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},{\"range\":{\"age\":{\"from\":null,\"to\":40,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}},{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, searchDone\u003dfalse)" + "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"bool\":{\"filter\":[{\"bool\":{\"filter\":[{\"range\":{\"balance\":{\"from\":10000,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},{\"range\":{\"age\":{\"from\":null,\"to\":40,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}},{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, needClean\u003dtrue, searchDone\u003dfalse, pitId\u003dnull, cursorKeepAlive\u003dnull, searchAfter\u003dnull, searchResponse\u003dnull)" }, "children": [] } diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_limit_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_limit_push.json index 51a627ea4d..0a0b58f17d 100644 --- a/integ-test/src/test/resources/expectedOutput/ppl/explain_limit_push.json +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_limit_push.json @@ -16,7 +16,7 @@ { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":5,\"timeout\":\"1m\"}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":5,\"timeout\":\"1m\"}, needClean\u003dtrue, searchDone\u003dfalse, pitId\u003dnull, cursorKeepAlive\u003dnull, searchAfter\u003dnull, searchResponse\u003dnull)" }, "children": [] } diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_output.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_output.json index 8d45714283..bd7310810e 100644 --- a/integ-test/src/test/resources/expectedOutput/ppl/explain_output.json +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_output.json @@ -31,7 +31,7 @@ { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone\u003dfalse)" + "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, needClean\u003dtrue, searchDone\u003dfalse, pitId\u003dnull, cursorKeepAlive\u003dnull, searchAfter\u003dnull, searchResponse\u003dnull)" }, "children": [] } diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_sort_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_sort_push.json index af2a57e536..e2630e24f9 100644 --- a/integ-test/src/test/resources/expectedOutput/ppl/explain_sort_push.json +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_sort_push.json @@ -8,7 +8,7 @@ { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"age\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, searchDone\u003dfalse)" + "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"age\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, needClean\u003dtrue, searchDone\u003dfalse, pitId\u003dnull, cursorKeepAlive\u003dnull, searchAfter\u003dnull, searchResponse\u003dnull)" }, "children": [] } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java index 6fd9c73375..2b9bdc349a 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java @@ -573,7 +573,7 @@ private void populateDefaultCursor(DefaultCursor cursor) { Integer limit = cursor.getLimit(); long rowsLeft = rowsLeft(cursor.getFetchSize(), cursor.getLimit()); if (rowsLeft <= 0) { - // close the cursor + // Delete Point In Time ID if (LocalClusterState.state().getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)) { String pitId = cursor.getPitId(); PointInTimeHandler pit = new PointInTimeHandlerImpl(client, pitId); diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/node/TableScan.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/node/TableScan.java index 16af199ed7..59e6f27216 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/node/TableScan.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/logical/node/TableScan.java @@ -5,11 +5,15 @@ package org.opensearch.sql.legacy.query.planner.logical.node; +import static org.opensearch.sql.common.setting.Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER; + import java.util.Map; +import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.query.join.TableInJoinRequestBuilder; import org.opensearch.sql.legacy.query.planner.core.PlanNode; import org.opensearch.sql.legacy.query.planner.logical.LogicalOperator; import org.opensearch.sql.legacy.query.planner.physical.PhysicalOperator; +import org.opensearch.sql.legacy.query.planner.physical.node.pointInTime.PointInTime; import org.opensearch.sql.legacy.query.planner.physical.node.scroll.Scroll; /** Table scan */ @@ -33,6 +37,9 @@ public PlanNode[] children() { @Override public PhysicalOperator[] toPhysical(Map> optimalOps) { + if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { + return new PhysicalOperator[] {new PointInTime(request, pageSize)}; + } return new PhysicalOperator[] {new Scroll(request, pageSize)}; } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/Paginate.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/Paginate.java new file mode 100644 index 0000000000..5bf31bb691 --- /dev/null +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/Paginate.java @@ -0,0 +1,144 @@ +package org.opensearch.sql.legacy.query.planner.physical.node; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Objects; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.Strings; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.sql.legacy.domain.Where; +import org.opensearch.sql.legacy.exception.SqlParseException; +import org.opensearch.sql.legacy.query.join.TableInJoinRequestBuilder; +import org.opensearch.sql.legacy.query.maker.QueryMaker; +import org.opensearch.sql.legacy.query.planner.core.ExecuteParams; +import org.opensearch.sql.legacy.query.planner.core.PlanNode; +import org.opensearch.sql.legacy.query.planner.physical.Row; +import org.opensearch.sql.legacy.query.planner.physical.estimation.Cost; +import org.opensearch.sql.legacy.query.planner.resource.ResourceManager; + +public abstract class Paginate extends BatchPhysicalOperator { + + /** Request to submit to OpenSearch to scan over */ + protected final TableInJoinRequestBuilder request; + + protected final int pageSize; + + protected Client client; + + protected SearchResponse searchResponse; + + protected Integer timeout; + + protected ResourceManager resourceMgr; + + public Paginate(TableInJoinRequestBuilder request, int pageSize) { + this.request = request; + this.pageSize = pageSize; + } + + @Override + public PlanNode[] children() { + return new PlanNode[0]; + } + + @Override + public Cost estimate() { + return new Cost(); + } + + @Override + public void open(ExecuteParams params) throws Exception { + super.open(params); + client = params.get(ExecuteParams.ExecuteParamType.CLIENT); + timeout = params.get(ExecuteParams.ExecuteParamType.TIMEOUT); + resourceMgr = params.get(ExecuteParams.ExecuteParamType.RESOURCE_MANAGER); + + Object filter = params.get(ExecuteParams.ExecuteParamType.EXTRA_QUERY_FILTER); + if (filter instanceof BoolQueryBuilder) { + request + .getRequestBuilder() + .setQuery(generateNewQueryWithExtraFilter((BoolQueryBuilder) filter)); + + if (LOG.isDebugEnabled()) { + LOG.debug( + "Received extra query filter, re-build query: {}", + Strings.toString( + XContentType.JSON, request.getRequestBuilder().request().source(), true, true)); + } + } + } + + @Override + protected Collection> prefetch() { + Objects.requireNonNull(client, "Client connection is not ready"); + Objects.requireNonNull(resourceMgr, "ResourceManager is not set"); + Objects.requireNonNull(timeout, "Time out is not set"); + + if (searchResponse == null) { + loadFirstBatch(); + updateMetaResult(); + } else { + loadNextBatch(); + } + return wrapRowForCurrentBatch(); + } + + protected abstract void loadFirstBatch(); + + protected abstract void loadNextBatch(); + + /** + * Extra filter pushed down from upstream. Re-parse WHERE clause with extra filter because + * OpenSearch RequestBuilder doesn't allow QueryBuilder inside be changed after added. + */ + protected QueryBuilder generateNewQueryWithExtraFilter(BoolQueryBuilder filter) + throws SqlParseException { + Where where = request.getOriginalSelect().getWhere(); + BoolQueryBuilder newQuery; + if (where != null) { + newQuery = QueryMaker.explain(where, false); + newQuery.must(filter); + } else { + newQuery = filter; + } + return newQuery; + } + + protected void updateMetaResult() { + resourceMgr.getMetaResult().addTotalNumOfShards(searchResponse.getTotalShards()); + resourceMgr.getMetaResult().addSuccessfulShards(searchResponse.getSuccessfulShards()); + resourceMgr.getMetaResult().addFailedShards(searchResponse.getFailedShards()); + resourceMgr.getMetaResult().updateTimeOut(searchResponse.isTimedOut()); + } + + @SuppressWarnings("unchecked") + protected Collection> wrapRowForCurrentBatch() { + SearchHit[] hits = searchResponse.getHits().getHits(); + Row[] rows = new Row[hits.length]; + for (int i = 0; i < hits.length; i++) { + rows[i] = new SearchHitRow(hits[i], request.getAlias()); + } + return Arrays.asList(rows); + } + + @Override + public String toString() { + return getClass().getSimpleName() + " [ " + describeTable() + ", pageSize=" + pageSize + " ]"; + } + + protected String describeTable() { + return request.getOriginalSelect().getFrom().get(0).getIndex() + " as " + request.getAlias(); + } + + /********************************************* + * Getters for Explain + *********************************************/ + + public String getRequest() { + return Strings.toString(XContentType.JSON, request.getRequestBuilder().request().source()); + } +} diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRow.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/SearchHitRow.java similarity index 97% rename from legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRow.java rename to legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/SearchHitRow.java index a859a7c88c..a69c912da5 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRow.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/SearchHitRow.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.legacy.query.planner.physical.node.scroll; +package org.opensearch.sql.legacy.query.planner.physical.node; import com.google.common.base.Strings; import java.util.HashMap; @@ -32,7 +32,7 @@ * retain() in Project | {"firstName": "Allen", "age": 30 } | "" | retain("e.name.first", "e.age") * ---------------------------------------------------------------------------------------------------------------------- */ -class SearchHitRow implements Row { +public class SearchHitRow implements Row { /** Native OpenSearch data object for each row */ private final SearchHit hit; @@ -43,7 +43,7 @@ class SearchHitRow implements Row { /** Table alias owned the row. Empty if this row comes from combination of two other rows */ private final String tableAlias; - SearchHitRow(SearchHit hit, String tableAlias) { + public SearchHitRow(SearchHit hit, String tableAlias) { this.hit = hit; this.source = hit.getSourceAsMap(); this.tableAlias = tableAlias; diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/pointInTime/PointInTime.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/pointInTime/PointInTime.java new file mode 100644 index 0000000000..9ddbde2d29 --- /dev/null +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/pointInTime/PointInTime.java @@ -0,0 +1,73 @@ +package org.opensearch.sql.legacy.query.planner.physical.node.pointInTime; + +import org.opensearch.common.unit.TimeValue; +import org.opensearch.search.builder.PointInTimeBuilder; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.sql.legacy.pit.PointInTimeHandlerImpl; +import org.opensearch.sql.legacy.query.join.TableInJoinRequestBuilder; +import org.opensearch.sql.legacy.query.planner.physical.node.Paginate; + +/** OpenSearch Search API with Point in time as physical implementation of TableScan */ +public class PointInTime extends Paginate { + + private String pitId; + private PointInTimeHandlerImpl pit; + + public PointInTime(TableInJoinRequestBuilder request, int pageSize) { + super(request, pageSize); + } + + @Override + public void close() { + if (searchResponse != null) { + LOG.debug("Closing Point In Time (PIT) context"); + // Delete the Point In Time context + pit.delete(); + searchResponse = null; + } else { + LOG.debug("PIT context is already closed or was never opened"); + } + } + + @Override + protected void loadFirstBatch() { + // Create PIT and set to request object + pit = new PointInTimeHandlerImpl(client, request.getOriginalSelect().getIndexArr()); + pit.create(); + pitId = pit.getPitId(); + + LOG.info("Loading first batch of response using Point In Time"); + searchResponse = + request + .getRequestBuilder() + .addSort(FieldSortBuilder.DOC_FIELD_NAME, SortOrder.ASC) + .setSize(pageSize) + .setTimeout(TimeValue.timeValueSeconds(timeout)) + .setPointInTime(new PointInTimeBuilder(pitId)) + .get(); + } + + @Override + protected void loadNextBatch() { + // Add PIT with search after to fetch next batch of data + if (searchResponse.getHits().getHits() != null + && searchResponse.getHits().getHits().length > 0) { + Object[] sortValues = + searchResponse + .getHits() + .getHits()[searchResponse.getHits().getHits().length - 1] + .getSortValues(); + + LOG.info("Loading next batch of response using Point In Time. - " + pitId); + searchResponse = + request + .getRequestBuilder() + .setSize(pageSize) + .setTimeout(TimeValue.timeValueSeconds(timeout)) + .setPointInTime(new PointInTimeBuilder(pitId)) + .searchAfter(sortValues) + .get(); + } + } +} diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/Scroll.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/Scroll.java index 40e9860886..5019e9cde8 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/Scroll.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/Scroll.java @@ -5,138 +5,38 @@ package org.opensearch.sql.legacy.query.planner.physical.node.scroll; -import java.util.Arrays; -import java.util.Collection; -import java.util.Objects; import org.opensearch.action.search.ClearScrollResponse; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.client.Client; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.common.Strings; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.search.SearchHit; import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.SortOrder; -import org.opensearch.sql.legacy.domain.Where; -import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.query.join.TableInJoinRequestBuilder; -import org.opensearch.sql.legacy.query.maker.QueryMaker; -import org.opensearch.sql.legacy.query.planner.core.ExecuteParams; -import org.opensearch.sql.legacy.query.planner.core.PlanNode; -import org.opensearch.sql.legacy.query.planner.physical.Row; -import org.opensearch.sql.legacy.query.planner.physical.estimation.Cost; -import org.opensearch.sql.legacy.query.planner.physical.node.BatchPhysicalOperator; -import org.opensearch.sql.legacy.query.planner.resource.ResourceManager; +import org.opensearch.sql.legacy.query.planner.physical.node.Paginate; /** OpenSearch Scroll API as physical implementation of TableScan */ -public class Scroll extends BatchPhysicalOperator { - - /** Request to submit to OpenSearch to scroll over */ - private final TableInJoinRequestBuilder request; - - /** Page size to scroll over index */ - private final int pageSize; - - /** Client connection to ElasticSearch */ - private Client client; - - /** Currently undergoing Scroll */ - private SearchResponse scrollResponse; - - /** Time out */ - private Integer timeout; - - /** Resource monitor manager */ - private ResourceManager resourceMgr; +public class Scroll extends Paginate { public Scroll(TableInJoinRequestBuilder request, int pageSize) { - this.request = request; - this.pageSize = pageSize; - } - - @Override - public PlanNode[] children() { - return new PlanNode[0]; - } - - @Override - public Cost estimate() { - return new Cost(); - } - - @Override - public void open(ExecuteParams params) throws Exception { - super.open(params); - client = params.get(ExecuteParams.ExecuteParamType.CLIENT); - timeout = params.get(ExecuteParams.ExecuteParamType.TIMEOUT); - resourceMgr = params.get(ExecuteParams.ExecuteParamType.RESOURCE_MANAGER); - - Object filter = params.get(ExecuteParams.ExecuteParamType.EXTRA_QUERY_FILTER); - if (filter instanceof BoolQueryBuilder) { - request - .getRequestBuilder() - .setQuery(generateNewQueryWithExtraFilter((BoolQueryBuilder) filter)); - - if (LOG.isDebugEnabled()) { - LOG.debug( - "Received extra query filter, re-build query: {}", - Strings.toString( - XContentType.JSON, request.getRequestBuilder().request().source(), true, true)); - } - } + super(request, pageSize); } @Override public void close() { - if (scrollResponse != null) { + if (searchResponse != null) { LOG.debug("Closing all scroll resources"); ClearScrollResponse clearScrollResponse = - client.prepareClearScroll().addScrollId(scrollResponse.getScrollId()).get(); + client.prepareClearScroll().addScrollId(searchResponse.getScrollId()).get(); if (!clearScrollResponse.isSucceeded()) { LOG.warn("Failed to close scroll: {}", clearScrollResponse.status()); } - scrollResponse = null; + searchResponse = null; } else { LOG.debug("Scroll already be closed"); } } @Override - protected Collection> prefetch() { - Objects.requireNonNull(client, "Client connection is not ready"); - Objects.requireNonNull(resourceMgr, "ResourceManager is not set"); - Objects.requireNonNull(timeout, "Time out is not set"); - - if (scrollResponse == null) { - loadFirstBatch(); - updateMetaResult(); - } else { - loadNextBatchByScrollId(); - } - return wrapRowForCurrentBatch(); - } - - /** - * Extra filter pushed down from upstream. Re-parse WHERE clause with extra filter because - * OpenSearch RequestBuilder doesn't allow QueryBuilder inside be changed after added. - */ - private QueryBuilder generateNewQueryWithExtraFilter(BoolQueryBuilder filter) - throws SqlParseException { - Where where = request.getOriginalSelect().getWhere(); - BoolQueryBuilder newQuery; - if (where != null) { - newQuery = QueryMaker.explain(where, false); - newQuery.must(filter); - } else { - newQuery = filter; - } - return newQuery; - } - - private void loadFirstBatch() { - scrollResponse = + protected void loadFirstBatch() { + searchResponse = request .getRequestBuilder() .addSort(FieldSortBuilder.DOC_FIELD_NAME, SortOrder.ASC) @@ -145,45 +45,12 @@ private void loadFirstBatch() { .get(); } - private void updateMetaResult() { - resourceMgr.getMetaResult().addTotalNumOfShards(scrollResponse.getTotalShards()); - resourceMgr.getMetaResult().addSuccessfulShards(scrollResponse.getSuccessfulShards()); - resourceMgr.getMetaResult().addFailedShards(scrollResponse.getFailedShards()); - resourceMgr.getMetaResult().updateTimeOut(scrollResponse.isTimedOut()); - } - - private void loadNextBatchByScrollId() { - scrollResponse = + @Override + protected void loadNextBatch() { + searchResponse = client - .prepareSearchScroll(scrollResponse.getScrollId()) + .prepareSearchScroll(searchResponse.getScrollId()) .setScroll(TimeValue.timeValueSeconds(timeout)) .get(); } - - @SuppressWarnings("unchecked") - private Collection> wrapRowForCurrentBatch() { - SearchHit[] hits = scrollResponse.getHits().getHits(); - Row[] rows = new Row[hits.length]; - for (int i = 0; i < hits.length; i++) { - rows[i] = new SearchHitRow(hits[i], request.getAlias()); - } - return Arrays.asList(rows); - } - - @Override - public String toString() { - return "Scroll [ " + describeTable() + ", pageSize=" + pageSize + " ]"; - } - - private String describeTable() { - return request.getOriginalSelect().getFrom().get(0).getIndex() + " as " + request.getAlias(); - } - - /********************************************* - * Getters for Explain - *********************************************/ - - public String getRequest() { - return Strings.toString(XContentType.JSON, request.getRequestBuilder().request().source()); - } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRowTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRowTest.java index dd0fc626c0..f7d2030b0c 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRowTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/SearchHitRowTest.java @@ -12,6 +12,7 @@ import org.junit.Test; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.search.SearchHit; +import org.opensearch.sql.legacy.query.planner.physical.node.SearchHitRow; public class SearchHitRowTest { diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerTest.java index 521b225893..6ff907ba30 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerTest.java @@ -42,6 +42,7 @@ import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.legacy.domain.JoinSelect; import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.exception.SqlParseException; @@ -104,6 +105,7 @@ public void init() { // to mock. // In this case, default value in Setting will be returned all the time. doReturn(emptyList()).when(settings).getSettings(); + doReturn(false).when(settings).getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER); LocalClusterState.state().setPluginSettings(settings); ActionFuture mockFuture = mock(ActionFuture.class); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchClient.java index 0a9cc67993..cdc3d4462f 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchClient.java @@ -7,6 +7,8 @@ import java.util.List; import java.util.Map; +import org.opensearch.action.search.CreatePitRequest; +import org.opensearch.action.search.DeletePitRequest; import org.opensearch.client.node.NodeClient; import org.opensearch.sql.opensearch.mapping.IndexMapping; import org.opensearch.sql.opensearch.request.OpenSearchRequest; @@ -89,4 +91,19 @@ public interface OpenSearchClient { void schedule(Runnable task); NodeClient getNodeClient(); + + /** + * Create PIT for given indices + * + * @param createPitRequest Create Point In Time request + * @return PitId + */ + String createPit(CreatePitRequest createPitRequest); + + /** + * Delete PIT + * + * @param deletePitRequest Delete Point In Time request + */ + void deletePit(DeletePitRequest deletePitRequest); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java index 993e092534..7a9487ef6a 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java @@ -11,6 +11,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutionException; import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -21,13 +22,16 @@ import org.opensearch.action.admin.indices.get.GetIndexResponse; import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; +import org.opensearch.action.search.*; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.AliasMetadata; +import org.opensearch.common.action.ActionFuture; import org.opensearch.common.settings.Settings; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.IndexSettings; import org.opensearch.sql.opensearch.mapping.IndexMapping; import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; /** OpenSearch connection by node client. */ @@ -155,20 +159,32 @@ public List indices() { */ @Override public Map meta() { - return ImmutableMap.of(META_CLUSTER_NAME, client.settings().get("cluster.name", "opensearch")); + return ImmutableMap.of( + META_CLUSTER_NAME, + client.settings().get("cluster.name", "opensearch"), + "plugins.sql.pagination.api", + client.settings().get("plugins.sql.pagination.api", "true")); } @Override public void cleanup(OpenSearchRequest request) { - request.clean( - scrollId -> { - try { - client.prepareClearScroll().addScrollId(scrollId).get(); - } catch (Exception e) { - throw new IllegalStateException( - "Failed to clean up resources for search request " + request, e); - } - }); + if (request instanceof OpenSearchScrollRequest) { + request.clean( + scrollId -> { + try { + client.prepareClearScroll().addScrollId(scrollId).get(); + } catch (Exception e) { + throw new IllegalStateException( + "Failed to clean up resources for search request " + request, e); + } + }); + } else { + request.clean( + pitId -> { + DeletePitRequest deletePitRequest = new DeletePitRequest(pitId); + deletePit(deletePitRequest); + }); + } } @Override @@ -181,4 +197,27 @@ public void schedule(Runnable task) { public NodeClient getNodeClient() { return client; } + + @Override + public String createPit(CreatePitRequest createPitRequest) { + ActionFuture execute = + this.client.execute(CreatePitAction.INSTANCE, createPitRequest); + try { + CreatePitResponse pitResponse = execute.get(); + return pitResponse.getId(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException("Error occurred while creating PIT for new engine SQL query", e); + } + } + + @Override + public void deletePit(DeletePitRequest deletePitRequest) { + ActionFuture execute = + this.client.execute(DeletePitAction.INSTANCE, deletePitRequest); + try { + DeletePitResponse deletePitResponse = execute.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException("Error occurred while deleting PIT.", e); + } + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java index b6106982a7..5cb6a69918 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java @@ -19,7 +19,7 @@ import org.opensearch.action.admin.cluster.settings.ClusterGetSettingsRequest; import org.opensearch.action.admin.indices.settings.get.GetSettingsRequest; import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; -import org.opensearch.action.search.ClearScrollRequest; +import org.opensearch.action.search.*; import org.opensearch.client.RequestOptions; import org.opensearch.client.RestHighLevelClient; import org.opensearch.client.indices.CreateIndexRequest; @@ -32,6 +32,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.sql.opensearch.mapping.IndexMapping; import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; /** @@ -166,6 +167,8 @@ public Map meta() { final Settings defaultSettings = client.cluster().getSettings(request, RequestOptions.DEFAULT).getDefaultSettings(); builder.put(META_CLUSTER_NAME, defaultSettings.get("cluster.name", "opensearch")); + builder.put( + "plugins.sql.pagination.api", defaultSettings.get("plugins.sql.pagination.api", "true")); return builder.build(); } catch (IOException e) { throw new IllegalStateException("Failed to get cluster meta info", e); @@ -174,17 +177,25 @@ public Map meta() { @Override public void cleanup(OpenSearchRequest request) { - request.clean( - scrollId -> { - try { - ClearScrollRequest clearRequest = new ClearScrollRequest(); - clearRequest.addScrollId(scrollId); - client.clearScroll(clearRequest, RequestOptions.DEFAULT); - } catch (IOException e) { - throw new IllegalStateException( - "Failed to clean up resources for search request " + request, e); - } - }); + if (request instanceof OpenSearchScrollRequest) { + request.clean( + scrollId -> { + try { + ClearScrollRequest clearRequest = new ClearScrollRequest(); + clearRequest.addScrollId(scrollId); + client.clearScroll(clearRequest, RequestOptions.DEFAULT); + } catch (IOException e) { + throw new IllegalStateException( + "Failed to clean up resources for search request " + request, e); + } + }); + } else { + request.clean( + pitId -> { + DeletePitRequest deletePitRequest = new DeletePitRequest(pitId); + deletePit(deletePitRequest); + }); + } } @Override @@ -196,4 +207,25 @@ public void schedule(Runnable task) { public NodeClient getNodeClient() { throw new UnsupportedOperationException("Unsupported method."); } + + @Override + public String createPit(CreatePitRequest createPitRequest) { + try { + CreatePitResponse createPitResponse = + client.createPit(createPitRequest, RequestOptions.DEFAULT); + return createPitResponse.getId(); + } catch (IOException e) { + throw new RuntimeException("Error occurred while creating PIT for new engine SQL query", e); + } + } + + @Override + public void deletePit(DeletePitRequest deletePitRequest) { + try { + DeletePitResponse deletePitResponse = + client.deletePit(deletePitRequest, RequestOptions.DEFAULT); + } catch (IOException e) { + throw new RuntimeException("Error occurred while creating PIT for new engine SQL query", e); + } + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java index 6cf7fe49c2..fff252f3b4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java @@ -5,21 +5,35 @@ package org.opensearch.sql.opensearch.request; +import static org.opensearch.core.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS; +import static org.opensearch.search.sort.FieldSortBuilder.DOC_FIELD_NAME; +import static org.opensearch.search.sort.SortOrder.ASC; + import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.function.Consumer; import java.util.function.Function; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.action.search.*; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchModule; +import org.opensearch.search.builder.PointInTimeBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; /** * OpenSearch search request. This has to be stateful because it needs to: @@ -36,7 +50,7 @@ public class OpenSearchQueryRequest implements OpenSearchRequest { private final IndexName indexName; /** Search request source builder. */ - private final SearchSourceBuilder sourceBuilder; + private SearchSourceBuilder sourceBuilder; /** OpenSearchExprValueFactory. */ @EqualsAndHashCode.Exclude @ToString.Exclude @@ -45,9 +59,19 @@ public class OpenSearchQueryRequest implements OpenSearchRequest { /** List of includes expected in the response. */ @EqualsAndHashCode.Exclude @ToString.Exclude private final List includes; + @EqualsAndHashCode.Exclude private boolean needClean = true; + /** Indicate the search already done. */ private boolean searchDone = false; + private String pitId; + + private TimeValue cursorKeepAlive; + + private Object[] searchAfter; + + private SearchResponse searchResponse = null; + /** Constructor of OpenSearchQueryRequest. */ public OpenSearchQueryRequest( String indexName, int size, OpenSearchExprValueFactory factory, List includes) { @@ -78,35 +102,153 @@ public OpenSearchQueryRequest( this.includes = includes; } + /** Constructor of OpenSearchQueryRequest with PIT support. */ + public OpenSearchQueryRequest( + IndexName indexName, + SearchSourceBuilder sourceBuilder, + OpenSearchExprValueFactory factory, + List includes, + TimeValue cursorKeepAlive, + String pitId) { + this.indexName = indexName; + this.sourceBuilder = sourceBuilder; + this.exprValueFactory = factory; + this.includes = includes; + this.cursorKeepAlive = cursorKeepAlive; + this.pitId = pitId; + } + + /** + * Constructs OpenSearchQueryRequest from serialized representation. + * + * @param in stream to read data from. + * @param engine OpenSearchSqlEngine to get node-specific context. + * @throws IOException thrown if reading from input {@code in} fails. + */ + public OpenSearchQueryRequest(StreamInput in, OpenSearchStorageEngine engine) throws IOException { + // Deserialize the SearchSourceBuilder from the string representation + String sourceBuilderString = in.readString(); + + NamedXContentRegistry xContentRegistry = + new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + XContentParser parser = + XContentType.JSON + .xContent() + .createParser(xContentRegistry, IGNORE_DEPRECATIONS, sourceBuilderString); + this.sourceBuilder = SearchSourceBuilder.fromXContent(parser); + + cursorKeepAlive = in.readTimeValue(); + pitId = in.readString(); + includes = in.readStringList(); + indexName = new IndexName(in); + + int length = in.readVInt(); + this.searchAfter = new Object[length]; + for (int i = 0; i < length; i++) { + this.searchAfter[i] = in.readGenericValue(); + } + + OpenSearchIndex index = (OpenSearchIndex) engine.getTable(null, indexName.toString()); + exprValueFactory = new OpenSearchExprValueFactory(index.getFieldOpenSearchTypes()); + } + @Override public OpenSearchResponse search( Function searchAction, Function scrollAction) { + if (this.pitId == null) { + // When SearchRequest doesn't contain PitId, fetch single page request + if (searchDone) { + return new OpenSearchResponse(SearchHits.empty(), exprValueFactory, includes); + } else { + searchDone = true; + return new OpenSearchResponse( + searchAction.apply( + new SearchRequest().indices(indexName.getIndexNames()).source(sourceBuilder)), + exprValueFactory, + includes); + } + } else { + // Search with PIT instead of scroll API + return searchWithPIT(searchAction); + } + } + + public OpenSearchResponse searchWithPIT(Function searchAction) { + OpenSearchResponse openSearchResponse; if (searchDone) { - return new OpenSearchResponse(SearchHits.empty(), exprValueFactory, includes); + openSearchResponse = new OpenSearchResponse(SearchHits.empty(), exprValueFactory, includes); } else { - searchDone = true; - return new OpenSearchResponse( - searchAction.apply( - new SearchRequest().indices(indexName.getIndexNames()).source(sourceBuilder)), - exprValueFactory, - includes); + this.sourceBuilder.pointInTimeBuilder(new PointInTimeBuilder(this.pitId)); + this.sourceBuilder.timeout(cursorKeepAlive); + // check for search after + if (searchAfter != null) { + this.sourceBuilder.searchAfter(searchAfter); + } + // Set sort field for search_after + if (this.sourceBuilder.sorts() == null) { + this.sourceBuilder.sort(DOC_FIELD_NAME, ASC); + } + SearchRequest searchRequest = new SearchRequest().source(this.sourceBuilder); + this.searchResponse = searchAction.apply(searchRequest); + + openSearchResponse = new OpenSearchResponse(this.searchResponse, exprValueFactory, includes); + + needClean = openSearchResponse.isEmpty(); + searchDone = openSearchResponse.isEmpty(); + SearchHit[] searchHits = this.searchResponse.getHits().getHits(); + if (searchHits != null && searchHits.length > 0) { + searchAfter = searchHits[searchHits.length - 1].getSortValues(); + this.sourceBuilder.searchAfter(searchAfter); + } } + return openSearchResponse; } @Override public void clean(Consumer cleanAction) { - // do nothing. + try { + // clean on the last page only, to prevent deleting the PitId in the middle of paging. + if (this.pitId != null && needClean) { + cleanAction.accept(this.pitId); + searchDone = true; + } + } finally { + this.pitId = null; + } } @Override public boolean hasAnotherBatch() { + if (this.pitId != null) { + return !needClean; + } return false; } @Override public void writeTo(StreamOutput out) throws IOException { - throw new UnsupportedOperationException( - "OpenSearchQueryRequest serialization " + "is not implemented."); + if (this.pitId != null) { + // Convert SearchSourceBuilder to XContent and write it as a string + out.writeString(sourceBuilder.toString()); + + out.writeTimeValue(sourceBuilder.timeout()); + out.writeString(sourceBuilder.pointInTimeBuilder().getId()); + out.writeStringCollection(includes); + indexName.writeTo(out); + + // Serialize the searchAfter array + if (searchAfter != null) { + out.writeVInt(searchAfter.length); + for (Object obj : searchAfter) { + out.writeGenericValue(obj); + } + } + } else { + // OpenSearch Query request without PIT for single page requests + throw new UnsupportedOperationException( + "OpenSearchQueryRequest serialization is not implemented."); + } } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java index 1df3dcb183..6fa9b17697 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java @@ -26,6 +26,7 @@ import lombok.ToString; import org.apache.commons.lang3.tuple.Pair; import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.action.search.CreatePitRequest; import org.opensearch.common.unit.TimeValue; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.InnerHitBuilder; @@ -39,9 +40,11 @@ import org.opensearch.search.sort.SortBuilder; import org.opensearch.search.sort.SortBuilders; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @@ -67,10 +70,13 @@ public class OpenSearchRequestBuilder { private int startFrom = 0; + private final Settings settings; + /** Constructor. */ public OpenSearchRequestBuilder( - int requestedTotalSize, OpenSearchExprValueFactory exprValueFactory) { + int requestedTotalSize, OpenSearchExprValueFactory exprValueFactory, Settings settings) { this.requestedTotalSize = requestedTotalSize; + this.settings = settings; this.sourceBuilder = new SearchSourceBuilder() .from(startFrom) @@ -82,18 +88,65 @@ public OpenSearchRequestBuilder( /** * Build DSL request. * - * @return query request or scroll request + * @return query request with PIT or scroll request */ public OpenSearchRequest build( - OpenSearchRequest.IndexName indexName, int maxResultWindow, TimeValue scrollTimeout) { + OpenSearchRequest.IndexName indexName, + int maxResultWindow, + TimeValue cursorKeepAlive, + OpenSearchClient client) { + if (this.settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)) { + return buildRequestWithPit(indexName, maxResultWindow, cursorKeepAlive, client); + } else { + return buildRequestWithScroll(indexName, maxResultWindow, cursorKeepAlive); + } + } + + private OpenSearchRequest buildRequestWithPit( + OpenSearchRequest.IndexName indexName, + int maxResultWindow, + TimeValue cursorKeepAlive, + OpenSearchClient client) { + int size = requestedTotalSize; + FetchSourceContext fetchSource = this.sourceBuilder.fetchSource(); + List includes = fetchSource != null ? Arrays.asList(fetchSource.includes()) : List.of(); + + if (pageSize == null) { + if (startFrom + size > maxResultWindow) { + sourceBuilder.size(maxResultWindow - startFrom); + // Search with PIT request + String pitId = createPit(indexName, cursorKeepAlive, client); + return new OpenSearchQueryRequest( + indexName, sourceBuilder, exprValueFactory, includes, cursorKeepAlive, pitId); + } else { + sourceBuilder.from(startFrom); + sourceBuilder.size(requestedTotalSize); + // Search with non-Pit request + return new OpenSearchQueryRequest(indexName, sourceBuilder, exprValueFactory, includes); + } + } else { + if (startFrom != 0) { + throw new UnsupportedOperationException("Non-zero offset is not supported with pagination"); + } + sourceBuilder.size(pageSize); + // Search with PIT request + String pitId = createPit(indexName, cursorKeepAlive, client); + return new OpenSearchQueryRequest( + indexName, sourceBuilder, exprValueFactory, includes, cursorKeepAlive, pitId); + } + } + + private OpenSearchRequest buildRequestWithScroll( + OpenSearchRequest.IndexName indexName, int maxResultWindow, TimeValue cursorKeepAlive) { int size = requestedTotalSize; FetchSourceContext fetchSource = this.sourceBuilder.fetchSource(); List includes = fetchSource != null ? Arrays.asList(fetchSource.includes()) : List.of(); + if (pageSize == null) { if (startFrom + size > maxResultWindow) { sourceBuilder.size(maxResultWindow - startFrom); return new OpenSearchScrollRequest( - indexName, scrollTimeout, sourceBuilder, exprValueFactory, includes); + indexName, cursorKeepAlive, sourceBuilder, exprValueFactory, includes); } else { sourceBuilder.from(startFrom); sourceBuilder.size(requestedTotalSize); @@ -105,10 +158,18 @@ public OpenSearchRequest build( } sourceBuilder.size(pageSize); return new OpenSearchScrollRequest( - indexName, scrollTimeout, sourceBuilder, exprValueFactory, includes); + indexName, cursorKeepAlive, sourceBuilder, exprValueFactory, includes); } } + private String createPit( + OpenSearchRequest.IndexName indexName, TimeValue cursorKeepAlive, OpenSearchClient client) { + // Create PIT ID for request + CreatePitRequest createPitRequest = + new CreatePitRequest(cursorKeepAlive, false, indexName.getIndexNames()); + return client.createPit(createPitRequest); + } + boolean isBoolFilterQuery(QueryBuilder current) { return (current instanceof BoolQueryBuilder); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index c6afdb8511..a6fe83c8c4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -163,13 +163,13 @@ public TableScanBuilder createScanBuilder() { final int querySizeLimit = settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT); final TimeValue cursorKeepAlive = settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); - var builder = new OpenSearchRequestBuilder(querySizeLimit, createExprValueFactory()); + var builder = new OpenSearchRequestBuilder(querySizeLimit, createExprValueFactory(), settings); Function createScanOperator = requestBuilder -> new OpenSearchIndexScan( client, requestBuilder.getMaxResponseSize(), - requestBuilder.build(indexName, getMaxResultWindow(), cursorKeepAlive)); + requestBuilder.build(indexName, getMaxResultWindow(), cursorKeepAlive, client)); return new OpenSearchIndexScanBuilder(builder, createScanOperator); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java index b1e4ccc463..b17773cb03 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java @@ -14,10 +14,12 @@ import lombok.ToString; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.exception.NoCursorException; import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -45,6 +47,8 @@ public class OpenSearchIndexScan extends TableScanOperator implements Serializab /** Search response for current batch. */ private Iterator iterator; + private Settings pluginSettings; + /** Creates index scan based on a provided OpenSearchRequestBuilder. */ public OpenSearchIndexScan( OpenSearchClient client, int maxResponseSize, OpenSearchRequest request) { @@ -121,12 +125,18 @@ public void readExternal(ObjectInput in) throws IOException { (OpenSearchStorageEngine) ((PlanSerializer.CursorDeserializationStream) in).resolveObject("engine"); + client = engine.getClient(); + boolean pointInTimeEnabled = + Boolean.parseBoolean( + client.meta().get(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER.getKeyValue())); try (BytesStreamInput bsi = new BytesStreamInput(requestStream)) { - request = new OpenSearchScrollRequest(bsi, engine); + if (pointInTimeEnabled) { + request = new OpenSearchQueryRequest(bsi, engine); + } else { + request = new OpenSearchScrollRequest(bsi, engine); + } } maxResponseSize = in.readInt(); - - client = engine.getClient(); } @Override diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java index 9da6e05e92..ba0fb85422 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java @@ -13,8 +13,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Answers.RETURNS_DEEP_STUBS; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -31,6 +30,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; import lombok.SneakyThrows; import org.apache.commons.lang3.reflect.FieldUtils; @@ -51,15 +51,16 @@ import org.opensearch.action.admin.indices.get.GetIndexResponse; import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; -import org.opensearch.action.search.ClearScrollRequestBuilder; -import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.*; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.AliasMetadata; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.common.action.ActionFuture; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; @@ -74,6 +75,7 @@ import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.mapping.IndexMapping; +import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -393,6 +395,65 @@ void cleanup_rethrows_exception() { assertThrows(IllegalStateException.class, () -> client.cleanup(request)); } + @Test + @SneakyThrows + void cleanup_pit_request() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder(), + factory, + List.of(), + TimeValue.timeValueMinutes(1L), + "samplePitId"); + // Enforce cleaning by setting a private field. + FieldUtils.writeField(request, "needClean", true, true); + client.cleanup(request); + verify(nodeClient).execute(any(), any()); + } + + @Test + @SneakyThrows + void cleanup_pit_request_throw_exception() { + DeletePitRequest deletePitRequest = new DeletePitRequest("samplePitId"); + ActionFuture actionFuture = mock(ActionFuture.class); + when(actionFuture.get()).thenThrow(new ExecutionException("Execution failed", new Throwable())); + when(nodeClient.execute(eq(DeletePitAction.INSTANCE), any(DeletePitRequest.class))) + .thenReturn(actionFuture); + assertThrows(RuntimeException.class, () -> client.deletePit(deletePitRequest)); + } + + @Test + @SneakyThrows + void create_pit() { + CreatePitRequest createPitRequest = + new CreatePitRequest(TimeValue.timeValueMinutes(5), false, Strings.EMPTY_ARRAY); + ActionFuture actionFuture = mock(ActionFuture.class); + CreatePitResponse createPitResponse = mock(CreatePitResponse.class); + when(createPitResponse.getId()).thenReturn("samplePitId"); + when(actionFuture.get()).thenReturn(createPitResponse); + when(nodeClient.execute(eq(CreatePitAction.INSTANCE), any(CreatePitRequest.class))) + .thenReturn(actionFuture); + + String pitId = client.createPit(createPitRequest); + assertEquals("samplePitId", pitId); + + verify(nodeClient).execute(CreatePitAction.INSTANCE, createPitRequest); + verify(actionFuture).get(); + } + + @Test + @SneakyThrows + void create_pit_request_throw_exception() { + CreatePitRequest createPitRequest = + new CreatePitRequest(TimeValue.timeValueMinutes(5), false, Strings.EMPTY_ARRAY); + ActionFuture actionFuture = mock(ActionFuture.class); + when(actionFuture.get()).thenThrow(new ExecutionException("Execution failed", new Throwable())); + when(nodeClient.execute(eq(CreatePitAction.INSTANCE), any(CreatePitRequest.class))) + .thenReturn(actionFuture); + assertThrows(RuntimeException.class, () -> client.createPit(createPitRequest)); + } + @Test void get_indices() { AliasMetadata aliasMetadata = mock(AliasMetadata.class); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java index b83313de07..82d9e74422 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java @@ -43,6 +43,8 @@ import org.opensearch.action.admin.cluster.settings.ClusterGetSettingsResponse; import org.opensearch.action.admin.indices.settings.get.GetSettingsRequest; import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; +import org.opensearch.action.search.CreatePitRequest; +import org.opensearch.action.search.CreatePitResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.RequestOptions; import org.opensearch.client.RestHighLevelClient; @@ -56,6 +58,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; @@ -69,6 +72,7 @@ import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.mapping.IndexMapping; +import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -411,6 +415,64 @@ void cleanup_with_IOException() { assertThrows(IllegalStateException.class, () -> client.cleanup(request)); } + @Test + @SneakyThrows + void cleanup_pit_request() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder(), + factory, + List.of(), + TimeValue.timeValueMinutes(1L), + "samplePitId"); + // Enforce cleaning by setting a private field. + FieldUtils.writeField(request, "needClean", true, true); + client.cleanup(request); + verify(restClient).deletePit(any(), any()); + } + + @Test + @SneakyThrows + void cleanup_pit_request_throw_exception() { + when(restClient.deletePit(any(), any())).thenThrow(new IOException()); + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder(), + factory, + List.of(), + TimeValue.timeValueMinutes(1L), + "samplePitId"); + // Enforce cleaning by setting a private field. + FieldUtils.writeField(request, "needClean", true, true); + assertThrows(RuntimeException.class, () -> client.cleanup(request)); + } + + @Test + @SneakyThrows + void create_pit() { + CreatePitRequest createPitRequest = + new CreatePitRequest(TimeValue.timeValueMinutes(5), false, Strings.EMPTY_ARRAY); + CreatePitResponse createPitResponse = mock(CreatePitResponse.class); + when(createPitResponse.getId()).thenReturn("samplePitId"); + when(restClient.createPit(any(CreatePitRequest.class), any())).thenReturn(createPitResponse); + + String pitId = client.createPit(createPitRequest); + assertEquals("samplePitId", pitId); + + verify(restClient).createPit(createPitRequest, RequestOptions.DEFAULT); + } + + @Test + @SneakyThrows + void create_pit_request_throw_exception() { + CreatePitRequest createPitRequest = + new CreatePitRequest(TimeValue.timeValueMinutes(5), false, Strings.EMPTY_ARRAY); + when(restClient.createPit(any(), any())).thenThrow(new IOException()); + assertThrows(RuntimeException.class, () -> client.createPit(createPitRequest)); + } + @Test void get_indices() throws IOException { when(restClient.indices().get(any(GetIndexRequest.class), any(RequestOptions.class))) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java index 739b70b1b8..e5cf94eb86 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java @@ -174,18 +174,20 @@ void explain_successfully() { new OpenSearchExecutionEngine(client, protector, new PlanSerializer(null)); Settings settings = mock(Settings.class); when(settings.getSettingValue(SQL_CURSOR_KEEP_ALIVE)).thenReturn(TimeValue.timeValueMinutes(1)); + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(true); OpenSearchExprValueFactory exprValueFactory = mock(OpenSearchExprValueFactory.class); final var name = new OpenSearchRequest.IndexName("test"); final int defaultQuerySize = 100; final int maxResultWindow = 10000; - final var requestBuilder = new OpenSearchRequestBuilder(defaultQuerySize, exprValueFactory); + final var requestBuilder = + new OpenSearchRequestBuilder(defaultQuerySize, exprValueFactory, settings); PhysicalPlan plan = new OpenSearchIndexScan( mock(OpenSearchClient.class), maxResultWindow, requestBuilder.build( - name, maxResultWindow, settings.getSettingValue(SQL_CURSOR_KEEP_ALIVE))); + name, maxResultWindow, settings.getSettingValue(SQL_CURSOR_KEEP_ALIVE), client)); AtomicReference result = new AtomicReference<>(); executor.explain( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index 5cd11c6cd4..da06c1eb66 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -8,9 +8,7 @@ import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertSame; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.*; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; @@ -91,6 +89,8 @@ public void setup() { @Test void test_protect_indexScan() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(true); + String indexName = "test"; final int maxResultWindow = 10000; final int querySizeLimit = 200; @@ -114,11 +114,12 @@ void test_protect_indexScan() { final var name = new OpenSearchRequest.IndexName(indexName); final var request = - new OpenSearchRequestBuilder(querySizeLimit, exprValueFactory) + new OpenSearchRequestBuilder(querySizeLimit, exprValueFactory, settings) .build( name, maxResultWindow, - settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)); + settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE), + client); assertEquals( PhysicalPlanDSL.project( PhysicalPlanDSL.limit( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java index d2bc5b0641..89b51207b5 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java @@ -5,37 +5,40 @@ package org.opensearch.sql.opensearch.request; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.opensearch.sql.opensearch.request.OpenSearchRequest.DEFAULT_QUERY_TIMEOUT; +import java.io.IOException; +import java.lang.reflect.Field; import java.util.List; import java.util.function.Consumer; import java.util.function.Function; +import lombok.SneakyThrows; import org.apache.lucene.search.TotalHits; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.PointInTimeBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; @ExtendWith(MockitoExtension.class) public class OpenSearchQueryRequestTest { @@ -64,6 +67,93 @@ public class OpenSearchQueryRequestTest { private final OpenSearchQueryRequest remoteRequest = new OpenSearchQueryRequest("ccs:test", 200, factory, List.of()); + @Mock private StreamOutput streamOutput; + @Mock private StreamInput streamInput; + @Mock private OpenSearchStorageEngine engine; + @Mock private PointInTimeBuilder pointInTimeBuilder; + + @InjectMocks private OpenSearchQueryRequest serializationRequest; + + private SearchSourceBuilder sourceBuilderForSerializer; + + @BeforeEach + void setup() { + sourceBuilderForSerializer = new SearchSourceBuilder(); + sourceBuilderForSerializer.pointInTimeBuilder(pointInTimeBuilder); + sourceBuilderForSerializer.timeout(TimeValue.timeValueSeconds(30)); + } + + @SneakyThrows + @Test + void testWriteTo() throws IOException { + when(pointInTimeBuilder.getId()).thenReturn("samplePITId"); + sourceBuilderForSerializer.searchAfter(new Object[] {"value1", 123}); + List includes = List.of("field1", "field2"); + serializationRequest = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilderForSerializer, + factory, + includes, + new TimeValue(1000), + "samplePITId"); + + Field searchAfterField = OpenSearchQueryRequest.class.getDeclaredField("searchAfter"); + searchAfterField.setAccessible(true); + searchAfterField.set(serializationRequest, new Object[] {"value1", 123}); + + serializationRequest.writeTo(streamOutput); + + String expectedJson = "{\"timeout\":\"30s\",\"search_after\":[\"value1\",123]}"; + verify(streamOutput).writeString(expectedJson); + verify(streamOutput).writeTimeValue(TimeValue.timeValueSeconds(30)); + verify(streamOutput).writeString("samplePITId"); + verify(streamOutput).writeStringCollection(includes); + + verify(streamOutput).writeVInt(2); + verify(streamOutput).writeGenericValue("value1"); + verify(streamOutput).writeGenericValue(123); + } + + @Test + void testWriteToWithoutSearchAfter() + throws IOException, NoSuchFieldException, IllegalAccessException { + when(pointInTimeBuilder.getId()).thenReturn("samplePITId"); + + List includes = List.of("field1", "field2"); + serializationRequest = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilderForSerializer, + factory, + includes, + new TimeValue(1000), + "samplePITId"); + + serializationRequest.writeTo(streamOutput); + verify(streamOutput).writeString("{\"timeout\":\"30s\"}"); + verify(streamOutput).writeTimeValue(TimeValue.timeValueSeconds(30)); + verify(streamOutput).writeString("samplePITId"); + verify(streamOutput).writeStringCollection(includes); + verify(streamOutput, never()).writeVInt(anyInt()); + verify(streamOutput, never()).writeGenericValue(any()); + } + + @Test + void testWriteToWithoutPIT() { + serializationRequest = new OpenSearchQueryRequest("test", 200, factory, List.of()); + + UnsupportedOperationException exception = + assertThrows( + UnsupportedOperationException.class, + () -> { + request.writeTo(streamOutput); + }); + + assertEquals( + "OpenSearchQueryRequest serialization is not implemented.", exception.getMessage()); + } + @Test void search() { OpenSearchQueryRequest request = @@ -81,6 +171,145 @@ void search() { verify(searchAction, times(1)).apply(any()); } + @Test + void search_with_pit() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "samplePid"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + when(searchHit.getSortValues()).thenReturn(new String[] {"sortedValue"}); + when(sourceBuilder.sorts()).thenReturn(null); + + OpenSearchResponse openSearchResponse = request.searchWithPIT(searchAction); + assertFalse(openSearchResponse.isEmpty()); + verify(searchAction, times(1)).apply(any()); + + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchResponse.getAggregations()).thenReturn(null); + when(searchHits.getHits()).thenReturn(null); + openSearchResponse = request.searchWithPIT(searchAction); + assertTrue(openSearchResponse.isEmpty()); + verify(searchAction, times(2)).apply(any()); + + openSearchResponse = request.searchWithPIT(searchAction); + assertTrue(openSearchResponse.isEmpty()); + } + + @Test + void search_with_pit_hits_null() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "samplePid"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + when(sourceBuilder.sorts()).thenReturn(null); + + OpenSearchResponse openSearchResponse = request.searchWithPIT(searchAction); + assertFalse(openSearchResponse.isEmpty()); + } + + @Test + void search_with_pit_hits_empty() { + SearchResponse searchResponse = mock(SearchResponse.class); + SearchHits searchHits = mock(SearchHits.class); + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "samplePid"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {}); + when(sourceBuilder.sorts()).thenReturn(null); + + OpenSearchResponse openSearchResponse = request.searchWithPIT(searchAction); + assertTrue(openSearchResponse.isEmpty()); + } + + @Test + void search_with_pit_null() { + SearchResponse searchResponse = mock(SearchResponse.class); + SearchHits searchHits = mock(SearchHits.class); + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "sample"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + + OpenSearchResponse openSearchResponse = request.search(searchAction, scrollAction); + assertFalse(openSearchResponse.isEmpty()); + } + + @Test + void has_another_batch() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "sample"); + assertFalse(request.hasAnotherBatch()); + } + + @Test + void has_another_batch_pid_null() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + null); + assertFalse(request.hasAnotherBatch()); + } + + @Test + void has_another_batch_need_clean() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "samplePid"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + OpenSearchResponse openSearchResponse = request.searchWithPIT(searchAction); + assertTrue(request.hasAnotherBatch()); + } + @Test void search_withoutContext() { OpenSearchQueryRequest request = @@ -121,6 +350,68 @@ void clean() { verify(cleanAction, never()).accept(any()); } + @Test + void testCleanConditionTrue() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "samplePid"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(null); + OpenSearchResponse openSearchResponse = request.searchWithPIT(searchAction); + + request.clean(cleanAction); + + verify(cleanAction, times(1)).accept("samplePid"); + assertTrue(request.isSearchDone()); + assertNull(request.getPitId()); + } + + @Test + void testCleanConditionFalse_needCleanFalse() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + "samplePid"); + + when(searchAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + OpenSearchResponse openSearchResponse = request.searchWithPIT(searchAction); + + request.clean(cleanAction); + verify(cleanAction, never()).accept(anyString()); + assertFalse(request.isSearchDone()); + assertNull(request.getPitId()); + } + + @Test + void testCleanConditionFalse_pidNull() { + OpenSearchQueryRequest request = + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + sourceBuilder, + factory, + List.of(), + new TimeValue(1000), + null); + + request.clean(cleanAction); + verify(cleanAction, never()).accept(anyString()); + assertFalse(request.isSearchDone()); + assertNull(request.getPitId()); + } + @Test void searchRequest() { request.getSourceBuilder().query(QueryBuilders.termQuery("name", "John")); @@ -159,6 +450,20 @@ void writeTo_unsupported() { UnsupportedOperationException.class, () -> request.writeTo(mock(StreamOutput.class))); } + @Test + void constructor_serialized() throws IOException { + StreamInput stream = mock(StreamInput.class); + OpenSearchStorageEngine engine = mock(OpenSearchStorageEngine.class); + when(stream.readString()).thenReturn("{}"); + when(stream.readStringArray()).thenReturn(new String[] {"sample"}); + OpenSearchIndex index = mock(OpenSearchIndex.class); + when(engine.getTable(null, "sample")).thenReturn(index); + when(stream.readVInt()).thenReturn(2); + when(stream.readGenericValue()).thenReturn("sampleSearchAfter"); + OpenSearchQueryRequest request = new OpenSearchQueryRequest(stream, engine); + assertNotNull(request); + } + private void assertSearchRequest(SearchRequest expected, OpenSearchQueryRequest request) { Function querySearch = searchRequest -> { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java index 742e76cbd0..bf87840b60 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java @@ -5,13 +5,10 @@ package org.opensearch.sql.opensearch.request; -import static org.junit.Assert.assertThrows; +import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.index.query.QueryBuilders.matchAllQuery; -import static org.opensearch.index.query.QueryBuilders.nestedQuery; +import static org.mockito.Mockito.*; +import static org.opensearch.index.query.QueryBuilders.*; import static org.opensearch.search.sort.FieldSortBuilder.DOC_FIELD_NAME; import static org.opensearch.search.sort.SortOrder.ASC; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; @@ -25,21 +22,16 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.join.ScoreMode; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.DisplayNameGeneration; -import org.junit.jupiter.api.DisplayNameGenerator; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.search.CreatePitRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; import org.opensearch.common.unit.TimeValue; -import org.opensearch.index.query.InnerHitBuilder; -import org.opensearch.index.query.NestedQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.*; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.AggregationBuilder; @@ -47,13 +39,19 @@ import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.ScoreSortBuilder; import org.opensearch.search.sort.SortBuilders; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; @@ -76,11 +74,18 @@ class OpenSearchRequestBuilderTest { @Mock private OpenSearchExprValueFactory exprValueFactory; + @Mock private OpenSearchClient client; + + @Mock private Settings settings; + private OpenSearchRequestBuilder requestBuilder; @BeforeEach void setup() { - requestBuilder = new OpenSearchRequestBuilder(DEFAULT_LIMIT, exprValueFactory); + requestBuilder = new OpenSearchRequestBuilder(DEFAULT_LIMIT, exprValueFactory, settings); + lenient() + .when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)) + .thenReturn(true); } @Test @@ -100,14 +105,148 @@ void build_query_request() { .trackScores(true), exprValueFactory, List.of()), - requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT)); + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void build_query_request_push_down_size() { + Integer limit = 200; + Integer offset = 0; + requestBuilder.pushDownLimit(limit, offset); + requestBuilder.pushDownTrackedScore(true); + + assertNotNull( + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void build_PIT_request_with_correct_size() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(true); + when(client.createPit(any(CreatePitRequest.class))).thenReturn("samplePITId"); + Integer limit = 0; + Integer offset = 0; + requestBuilder.pushDownLimit(limit, offset); + requestBuilder.pushDownPageSize(2); + + assertEquals( + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder().from(offset).size(2).timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of(), + TimeValue.timeValueMinutes(1), + "samplePITId"), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithPit_pageSizeNull_sizeGreaterThanMaxResultWindow() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(true); + when(client.createPit(any(CreatePitRequest.class))).thenReturn("samplePITId"); + Integer limit = 600; + Integer offset = 0; + int requestedTotalSize = 600; + requestBuilder = new OpenSearchRequestBuilder(requestedTotalSize, exprValueFactory, settings); + requestBuilder.pushDownLimit(limit, offset); + + assertEquals( + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder() + .from(offset) + .size(MAX_RESULT_WINDOW - offset) + .timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of(), + TimeValue.timeValueMinutes(1), + "samplePITId"), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithPit_pageSizeNull_sizeLessThanMaxResultWindow() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(true); + Integer limit = 400; + Integer offset = 0; + int requestedTotalSize = 400; + requestBuilder = new OpenSearchRequestBuilder(requestedTotalSize, exprValueFactory, settings); + requestBuilder.pushDownLimit(limit, offset); + + assertEquals( + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder() + .from(offset) + .size(requestedTotalSize) + .timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of()), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithPit_pageSizeNotNull_startFromZero() { + int pageSize = 200; + int offset = 0; + int limit = 400; + requestBuilder.pushDownPageSize(pageSize); + requestBuilder.pushDownLimit(limit, offset); + when(client.createPit(any(CreatePitRequest.class))).thenReturn("samplePITId"); + + assertEquals( + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder().from(offset).size(pageSize).timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of(), + TimeValue.timeValueMinutes(1), + "samplePITId"), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithPit_pageSizeNotNull_startFromNonZero() { + int pageSize = 200; + int offset = 100; + int limit = 400; + requestBuilder.pushDownPageSize(pageSize); + requestBuilder.pushDownLimit(limit, offset); + assertThrows( + UnsupportedOperationException.class, + () -> { + requestBuilder.build(indexName, 500, TimeValue.timeValueMinutes(1), client); + }); } @Test void build_scroll_request_with_correct_size() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(false); Integer limit = 800; Integer offset = 10; requestBuilder.pushDownLimit(limit, offset); + requestBuilder.getSourceBuilder().fetchSource("a", "b"); + + assertEquals( + new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), + TimeValue.timeValueMinutes(1), + new SearchSourceBuilder() + .from(offset) + .size(MAX_RESULT_WINDOW - offset) + .timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of()), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithScroll_pageSizeNull_sizeGreaterThanMaxResultWindow() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(false); + Integer limit = 600; + Integer offset = 0; + int requestedTotalSize = 600; + requestBuilder = new OpenSearchRequestBuilder(requestedTotalSize, exprValueFactory, settings); + requestBuilder.pushDownLimit(limit, offset); assertEquals( new OpenSearchScrollRequest( @@ -119,7 +258,65 @@ void build_scroll_request_with_correct_size() { .timeout(DEFAULT_QUERY_TIMEOUT), exprValueFactory, List.of()), - requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT)); + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithScroll_pageSizeNull_sizeLessThanMaxResultWindow() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(false); + Integer limit = 400; + Integer offset = 0; + int requestedTotalSize = 400; + requestBuilder = new OpenSearchRequestBuilder(requestedTotalSize, exprValueFactory, settings); + requestBuilder.pushDownLimit(limit, offset); + + assertEquals( + new OpenSearchQueryRequest( + new OpenSearchRequest.IndexName("test"), + new SearchSourceBuilder() + .from(offset) + .size(requestedTotalSize) + .timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of()), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithScroll_pageSizeNotNull_startFromZero() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(false); + int pageSize = 200; + int offset = 0; + int limit = 400; + requestBuilder.pushDownPageSize(pageSize); + requestBuilder.pushDownLimit(limit, offset); + + assertEquals( + new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), + TimeValue.timeValueMinutes(1), + new SearchSourceBuilder() + .from(offset) + .size(MAX_RESULT_WINDOW - offset) + .timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory, + List.of()), + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); + } + + @Test + void buildRequestWithScroll_pageSizeNotNull_startFromNonZero() { + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(false); + int pageSize = 200; + int offset = 100; + int limit = 400; + requestBuilder.pushDownPageSize(pageSize); + requestBuilder.pushDownLimit(limit, offset); + assertThrows( + UnsupportedOperationException.class, + () -> { + requestBuilder.build(indexName, 500, TimeValue.timeValueMinutes(1), client); + }); } @Test @@ -127,7 +324,7 @@ void test_push_down_query() { QueryBuilder query = QueryBuilders.termQuery("intA", 1); requestBuilder.pushDownFilter(query); - var r = requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT); + var r = requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client); Function querySearch = searchRequest -> { assertEquals( @@ -203,6 +400,51 @@ void test_push_down_query_and_sort() { requestBuilder); } + @Test + void test_push_down_query_not_null() { + SearchSourceBuilder sourceBuilder = requestBuilder.getSourceBuilder(); + sourceBuilder.query(QueryBuilders.termQuery("name", "John")); + sourceBuilder.sort(DOC_FIELD_NAME, SortOrder.ASC); + + QueryBuilder query = QueryBuilders.termQuery("intA", 1); + requestBuilder.pushDownFilter(query); + + BoolQueryBuilder expectedQuery = + QueryBuilders.boolQuery().filter(QueryBuilders.termQuery("name", "John")).filter(query); + + SearchSourceBuilder expectedSourceBuilder = + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .query(expectedQuery) + .sort(DOC_FIELD_NAME, SortOrder.ASC); + + assertSearchSourceBuilder(expectedSourceBuilder, requestBuilder); + } + + @Test + void test_push_down_query_with_bool_filter() { + BoolQueryBuilder initialBoolQuery = + QueryBuilders.boolQuery().filter(QueryBuilders.termQuery("name", "John")); + + SearchSourceBuilder sourceBuilder = requestBuilder.getSourceBuilder(); + sourceBuilder.query(initialBoolQuery); + + QueryBuilder newQuery = QueryBuilders.termQuery("intA", 1); + requestBuilder.pushDownFilter(newQuery); + initialBoolQuery.filter(newQuery); + SearchSourceBuilder expectedSourceBuilder = + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .query(initialBoolQuery) + .sort(DOC_FIELD_NAME, SortOrder.ASC); + + assertSearchSourceBuilder(expectedSourceBuilder, requestBuilder); + } + void assertSearchSourceBuilder( SearchSourceBuilder expected, OpenSearchRequestBuilder requestBuilder) throws UnsupportedOperationException { @@ -220,7 +462,7 @@ void assertSearchSourceBuilder( throw new UnsupportedOperationException(); }; requestBuilder - .build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT) + .build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client) .search(querySearch, scrollSearch); } @@ -290,7 +532,7 @@ void test_push_down_project() { .fetchSource("intA", null), exprValueFactory, List.of("intA")), - requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT)); + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); } @Test @@ -320,7 +562,7 @@ void test_push_down_project_limit() { .fetchSource("intA", null), exprValueFactory, List.of("intA")), - requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT)); + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); } @Test @@ -350,7 +592,7 @@ void test_push_down_project_limit_and_offset() { .fetchSource("intA", null), exprValueFactory, List.of("intA")), - requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT)); + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); } @Test @@ -377,7 +619,7 @@ void test_push_down_nested() { assertSearchSourceBuilder( new SearchSourceBuilder() - .query(QueryBuilders.boolQuery().filter(QueryBuilders.boolQuery().must(nestedQuery))) + .query(boolQuery().filter(boolQuery().must(nestedQuery))) .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT), @@ -411,7 +653,7 @@ void test_push_down_multiple_nested_with_same_path() { true, new String[] {"message.info", "message.from"}, null))); assertSearchSourceBuilder( new SearchSourceBuilder() - .query(QueryBuilders.boolQuery().filter(QueryBuilders.boolQuery().must(nestedQuery))) + .query(boolQuery().filter(boolQuery().must(nestedQuery))) .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT), @@ -444,9 +686,9 @@ void test_push_down_nested_with_filter() { assertSearchSourceBuilder( new SearchSourceBuilder() .query( - QueryBuilders.boolQuery() + boolQuery() .filter( - QueryBuilders.boolQuery() + boolQuery() .must(QueryBuilders.rangeQuery("myNum").gt(3)) .must(nestedQuery))) .from(DEFAULT_OFFSET) @@ -483,7 +725,7 @@ void testPushDownNestedWithNestedFilter() { assertSearchSourceBuilder( new SearchSourceBuilder() - .query(QueryBuilders.boolQuery().filter(QueryBuilders.boolQuery().must(filterQuery))) + .query(boolQuery().filter(boolQuery().must(filterQuery))) .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT), @@ -507,6 +749,32 @@ void push_down_highlight_with_repeating_fields() { assertEquals("Duplicate field name in highlight", exception.getMessage()); } + @Test + void test_push_down_highlight_with_pre_tags() { + requestBuilder.pushDownHighlight( + "name", Map.of("pre_tags", new Literal("pre1", DataType.STRING))); + + SearchSourceBuilder sourceBuilder = requestBuilder.getSourceBuilder(); + assertNotNull(sourceBuilder.highlighter()); + assertEquals(1, sourceBuilder.highlighter().fields().size()); + HighlightBuilder.Field field = sourceBuilder.highlighter().fields().get(0); + assertEquals("name", field.name()); + assertEquals("pre1", field.preTags()[0]); + } + + @Test + void test_push_down_highlight_with_post_tags() { + requestBuilder.pushDownHighlight( + "name", Map.of("post_tags", new Literal("post1", DataType.STRING))); + + SearchSourceBuilder sourceBuilder = requestBuilder.getSourceBuilder(); + assertNotNull(sourceBuilder.highlighter()); + assertEquals(1, sourceBuilder.highlighter().fields().size()); + HighlightBuilder.Field field = sourceBuilder.highlighter().fields().get(0); + assertEquals("name", field.name()); + assertEquals("post1", field.postTags()[0]); + } + @Test void push_down_page_size() { requestBuilder.pushDownPageSize(3); @@ -521,7 +789,7 @@ void exception_when_non_zero_offset_and_page_size() { requestBuilder.pushDownLimit(300, 2); assertThrows( UnsupportedOperationException.class, - () -> requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT)); + () -> requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT, client)); } @Test diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java index 3ca566fac6..ef6b86c42a 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java @@ -79,6 +79,9 @@ class OpenSearchIndexTest { @BeforeEach void setUp() { this.index = new OpenSearchIndex(client, settings, "test"); + lenient() + .when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)) + .thenReturn(true); } @Test @@ -198,10 +201,11 @@ void implementRelationOperatorOnly() { when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); LogicalPlan plan = index.createScanBuilder(); Integer maxResultWindow = index.getMaxResultWindow(); - final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory); + final var requestBuilder = + new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory, settings); assertEquals( new OpenSearchIndexScan( - client, 200, requestBuilder.build(INDEX_NAME, maxResultWindow, SCROLL_TIMEOUT)), + client, 200, requestBuilder.build(INDEX_NAME, maxResultWindow, SCROLL_TIMEOUT, client)), index.implement(index.optimize(plan))); } @@ -211,10 +215,11 @@ void implementRelationOperatorWithOptimization() { when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); LogicalPlan plan = index.createScanBuilder(); Integer maxResultWindow = index.getMaxResultWindow(); - final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory); + final var requestBuilder = + new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory, settings); assertEquals( new OpenSearchIndexScan( - client, 200, requestBuilder.build(INDEX_NAME, maxResultWindow, SCROLL_TIMEOUT)), + client, 200, requestBuilder.build(INDEX_NAME, maxResultWindow, SCROLL_TIMEOUT, client)), index.implement(plan)); } @@ -243,7 +248,8 @@ void implementOtherLogicalOperators() { include); Integer maxResultWindow = index.getMaxResultWindow(); - final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory); + final var requestBuilder = + new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory, settings); assertEquals( PhysicalPlanDSL.project( PhysicalPlanDSL.dedupe( @@ -255,7 +261,7 @@ void implementOtherLogicalOperators() { client, QUERY_SIZE_LIMIT, requestBuilder.build( - INDEX_NAME, maxResultWindow, SCROLL_TIMEOUT)), + INDEX_NAME, maxResultWindow, SCROLL_TIMEOUT, client)), mappings), exclude), newEvalField), diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanPaginationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanPaginationTest.java index 2085519b12..e6a17aceaf 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanPaginationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanPaginationTest.java @@ -56,6 +56,9 @@ void setup() { lenient() .when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) .thenReturn(TimeValue.timeValueMinutes(1)); + lenient() + .when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)) + .thenReturn(true); } @Mock private OpenSearchClient client; @@ -69,12 +72,12 @@ void setup() { @Test void query_empty_result() { mockResponse(client); - var builder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory); + var builder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory, settings); try (var indexScan = new OpenSearchIndexScan( client, MAX_RESULT_WINDOW, - builder.build(INDEX_NAME, MAX_RESULT_WINDOW, SCROLL_TIMEOUT))) { + builder.build(INDEX_NAME, MAX_RESULT_WINDOW, SCROLL_TIMEOUT, client))) { indexScan.open(); assertFalse(indexScan.hasNext()); } @@ -96,13 +99,13 @@ void dont_serialize_if_no_cursor() { OpenSearchRequestBuilder builder = mock(); OpenSearchRequest request = mock(); OpenSearchResponse response = mock(); - when(builder.build(any(), anyInt(), any())).thenReturn(request); + when(builder.build(any(), anyInt(), any(), any())).thenReturn(request); when(client.search(any())).thenReturn(response); try (var indexScan = new OpenSearchIndexScan( client, MAX_RESULT_WINDOW, - builder.build(INDEX_NAME, MAX_RESULT_WINDOW, SCROLL_TIMEOUT))) { + builder.build(INDEX_NAME, MAX_RESULT_WINDOW, SCROLL_TIMEOUT, client))) { indexScan.open(); when(request.hasAnotherBatch()).thenReturn(false); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java index f813d8f551..e680c6b3a6 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java @@ -5,24 +5,19 @@ package org.opensearch.sql.opensearch.storage.scan; +import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.opensearch.search.sort.FieldSortBuilder.DOC_FIELD_NAME; import static org.opensearch.search.sort.SortOrder.ASC; import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; +import java.io.*; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -52,6 +47,7 @@ import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.exception.NoCursorException; @@ -77,6 +73,7 @@ class OpenSearchIndexScanTest { public static final int MAX_RESULT_WINDOW = 10000; public static final TimeValue CURSOR_KEEP_ALIVE = TimeValue.timeValueMinutes(1); @Mock private OpenSearchClient client; + @Mock private Settings settings; private final OpenSearchExprValueFactory exprValueFactory = new OpenSearchExprValueFactory( @@ -84,7 +81,11 @@ class OpenSearchIndexScanTest { "name", OpenSearchDataType.of(STRING), "department", OpenSearchDataType.of(STRING))); @BeforeEach - void setup() {} + void setup() { + lenient() + .when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)) + .thenReturn(true); + } @Test void explain() { @@ -144,6 +145,49 @@ void serialize(Integer numberOfIncludes) { } } + @SneakyThrows + @ParameterizedTest + @ValueSource(ints = {0, 150}) + void serialize_PIT(Integer numberOfIncludes) { + var searchSourceBuilder = new SearchSourceBuilder().size(4); + + var factory = mock(OpenSearchExprValueFactory.class); + var engine = mock(OpenSearchStorageEngine.class); + var index = mock(OpenSearchIndex.class); + when(engine.getClient()).thenReturn(client); + when(engine.getTable(any(), any())).thenReturn(index); + Map map = mock(Map.class); + when(map.get(any(String.class))).thenReturn("true"); + when(client.meta()).thenReturn(map); + var includes = + Stream.iterate(1, i -> i + 1) + .limit(numberOfIncludes) + .map(i -> "column" + i) + .collect(Collectors.toList()); + var request = + new OpenSearchQueryRequest( + INDEX_NAME, searchSourceBuilder, factory, includes, CURSOR_KEEP_ALIVE, "samplePitId"); + // make a response, so OpenSearchResponse::isEmpty would return true and unset needClean + var response = mock(SearchResponse.class); + when(response.getAggregations()).thenReturn(mock()); + var hits = mock(SearchHits.class); + when(response.getHits()).thenReturn(hits); + SearchHit hit = mock(SearchHit.class); + when(hit.getSortValues()).thenReturn(new String[] {"sample1"}); + when(hits.getHits()).thenReturn(new SearchHit[] {hit}); + request.search((req) -> response, null); + + try (var indexScan = new OpenSearchIndexScan(client, QUERY_SIZE, request)) { + var planSerializer = new PlanSerializer(engine); + var cursor = planSerializer.convertToCursor(indexScan); + var newPlan = planSerializer.convertToPlan(cursor.toString()); + assertNotNull(newPlan); + + verify(client).meta(); + verify(map).get(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER.getKeyValue()); + } + } + @SneakyThrows @Test void throws_io_exception_if_too_short() { @@ -172,10 +216,12 @@ void plan_for_serialization() { void query_empty_result() { mockResponse(client); final var name = new OpenSearchRequest.IndexName("test"); - final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory); + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory, settings); try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan( - client, QUERY_SIZE, requestBuilder.build(name, MAX_RESULT_WINDOW, CURSOR_KEEP_ALIVE))) { + client, + QUERY_SIZE, + requestBuilder.build(name, MAX_RESULT_WINDOW, CURSOR_KEEP_ALIVE, client))) { indexScan.open(); assertFalse(indexScan.hasNext()); } @@ -190,10 +236,10 @@ void query_all_results_with_query() { employee(1, "John", "IT"), employee(2, "Smith", "HR"), employee(3, "Allen", "IT") }); - final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory); + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory, settings); try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan( - client, 10, requestBuilder.build(INDEX_NAME, 10000, CURSOR_KEEP_ALIVE))) { + client, 10, requestBuilder.build(INDEX_NAME, 10000, CURSOR_KEEP_ALIVE, client))) { indexScan.open(); assertAll( @@ -218,10 +264,10 @@ void query_all_results_with_scroll() { new ExprValue[] {employee(1, "John", "IT"), employee(2, "Smith", "HR")}, new ExprValue[] {employee(3, "Allen", "IT")}); - final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory); + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory, settings); try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan( - client, 10, requestBuilder.build(INDEX_NAME, 10000, CURSOR_KEEP_ALIVE))) { + client, 10, requestBuilder.build(INDEX_NAME, 10000, CURSOR_KEEP_ALIVE, client))) { indexScan.open(); assertAll( @@ -248,10 +294,12 @@ void query_some_results_with_query() { }); final int limit = 3; - OpenSearchRequestBuilder builder = new OpenSearchRequestBuilder(0, exprValueFactory); + OpenSearchRequestBuilder builder = new OpenSearchRequestBuilder(0, exprValueFactory, settings); try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan( - client, limit, builder.build(INDEX_NAME, MAX_RESULT_WINDOW, CURSOR_KEEP_ALIVE))) { + client, + limit, + builder.build(INDEX_NAME, MAX_RESULT_WINDOW, CURSOR_KEEP_ALIVE, client))) { indexScan.open(); assertAll( @@ -269,10 +317,12 @@ void query_some_results_with_query() { @Test void query_some_results_with_scroll() { mockTwoPageResponse(client); - final var requestuilder = new OpenSearchRequestBuilder(10, exprValueFactory); + final var requestuilder = new OpenSearchRequestBuilder(10, exprValueFactory, settings); try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan( - client, 3, requestuilder.build(INDEX_NAME, MAX_RESULT_WINDOW, CURSOR_KEEP_ALIVE))) { + client, + 3, + requestuilder.build(INDEX_NAME, MAX_RESULT_WINDOW, CURSOR_KEEP_ALIVE, client))) { indexScan.open(); assertAll( @@ -306,12 +356,13 @@ void query_results_limited_by_query_size() { }); final int defaultQuerySize = 2; - final var requestBuilder = new OpenSearchRequestBuilder(defaultQuerySize, exprValueFactory); + final var requestBuilder = + new OpenSearchRequestBuilder(defaultQuerySize, exprValueFactory, settings); try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan( client, defaultQuerySize, - requestBuilder.build(INDEX_NAME, QUERY_SIZE, CURSOR_KEEP_ALIVE))) { + requestBuilder.build(INDEX_NAME, QUERY_SIZE, CURSOR_KEEP_ALIVE, client))) { indexScan.open(); assertAll( @@ -368,7 +419,7 @@ void push_down_highlight_with_arguments() { } private PushDownAssertion assertThat() { - return new PushDownAssertion(client, exprValueFactory); + return new PushDownAssertion(client, exprValueFactory, settings); } private static class PushDownAssertion { @@ -377,9 +428,10 @@ private static class PushDownAssertion { private final OpenSearchResponse response; private final OpenSearchExprValueFactory factory; - public PushDownAssertion(OpenSearchClient client, OpenSearchExprValueFactory valueFactory) { + public PushDownAssertion( + OpenSearchClient client, OpenSearchExprValueFactory valueFactory, Settings settings) { this.client = client; - this.requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, valueFactory); + this.requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, valueFactory, settings); this.response = mock(OpenSearchResponse.class); this.factory = valueFactory; @@ -411,7 +463,9 @@ PushDownAssertion shouldQueryHighlight(QueryBuilder query, HighlightBuilder high when(client.search(request)).thenReturn(response); var indexScan = new OpenSearchIndexScan( - client, QUERY_SIZE, requestBuilder.build(EMPLOYEES_INDEX, 10000, CURSOR_KEEP_ALIVE)); + client, + QUERY_SIZE, + requestBuilder.build(EMPLOYEES_INDEX, 10000, CURSOR_KEEP_ALIVE, client)); indexScan.open(); return this; } @@ -429,7 +483,9 @@ PushDownAssertion shouldQuery(QueryBuilder expected) { when(client.search(request)).thenReturn(response); var indexScan = new OpenSearchIndexScan( - client, 10000, requestBuilder.build(EMPLOYEES_INDEX, 10000, CURSOR_KEEP_ALIVE)); + client, + 10000, + requestBuilder.build(EMPLOYEES_INDEX, 10000, CURSOR_KEEP_ALIVE, client)); indexScan.open(); return this; }