diff --git a/core/build.gradle b/core/build.gradle index a338b8f368..624c10fd6b 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -57,6 +57,7 @@ dependencies { testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' testImplementation group: 'org.mockito', name: 'mockito-core', version: '3.12.4' + testImplementation group: 'org.mockito', name: 'mockito-inline', version: '3.12.4' testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '3.12.4' } diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index de7d72fd0d..6418f92686 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -6,6 +6,7 @@ package org.opensearch.sql.analysis; +import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_FIRST; import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_LAST; import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; @@ -42,13 +43,16 @@ import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Aggregation; +import org.opensearch.sql.ast.tree.CloseCursor; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.ML; +import org.opensearch.sql.ast.tree.Paginate; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; @@ -80,12 +84,15 @@ import 
org.opensearch.sql.expression.parse.ParseExpression; import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalCloseCursor; import org.opensearch.sql.planner.logical.LogicalDedupe; import org.opensearch.sql.planner.logical.LogicalEval; +import org.opensearch.sql.planner.logical.LogicalFetchCursor; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalML; import org.opensearch.sql.planner.logical.LogicalMLCommons; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalRareTopN; @@ -208,7 +215,6 @@ public LogicalPlan visitTableFunction(TableFunction node, AnalysisContext contex tableFunctionImplementation.applyArguments()); } - @Override public LogicalPlan visitLimit(Limit node, AnalysisContext context) { LogicalPlan child = node.getChild().get(0).accept(this, context); @@ -561,6 +567,23 @@ public LogicalPlan visitML(ML node, AnalysisContext context) { return new LogicalML(child, node.getArguments()); } + @Override + public LogicalPlan visitPaginate(Paginate paginate, AnalysisContext context) { + LogicalPlan child = paginate.getChild().get(0).accept(this, context); + return new LogicalPaginate(paginate.getPageSize(), List.of(child)); + } + + @Override + public LogicalPlan visitFetchCursor(FetchCursor cursor, AnalysisContext context) { + return new LogicalFetchCursor(cursor.getCursor(), + dataSourceService.getDataSource(DEFAULT_DATASOURCE_NAME).getStorageEngine()); + } + + @Override + public LogicalPlan visitCloseCursor(CloseCursor closeCursor, AnalysisContext context) { + return new LogicalCloseCursor(closeCursor.getChild().get(0).accept(this, context)); + } + /** * The first argument is 
always "asc", others are optional. * Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC. @@ -576,5 +599,4 @@ private SortOption analyzeSortOption(List fieldArgs) { } return asc ? SortOption.DEFAULT_ASC : SortOption.DEFAULT_DESC; } - } diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index d2ebb9eb99..3e81509fae 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -41,13 +41,16 @@ import org.opensearch.sql.ast.statement.Statement; import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Aggregation; +import org.opensearch.sql.ast.tree.CloseCursor; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.ML; +import org.opensearch.sql.ast.tree.Paginate; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; @@ -294,4 +297,16 @@ public T visitQuery(Query node, C context) { public T visitExplain(Explain node, C context) { return visitStatement(node, context); } + + public T visitPaginate(Paginate paginate, C context) { + return visitChildren(paginate, context); + } + + public T visitFetchCursor(FetchCursor cursor, C context) { + return visitChildren(cursor, context); + } + + public T visitCloseCursor(CloseCursor closeCursor, C context) { + return visitChildren(closeCursor, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/statement/Query.java b/core/src/main/java/org/opensearch/sql/ast/statement/Query.java index 17682cd47b..82efdde4dd 100644 --- 
a/core/src/main/java/org/opensearch/sql/ast/statement/Query.java +++ b/core/src/main/java/org/opensearch/sql/ast/statement/Query.java @@ -27,6 +27,7 @@ public class Query extends Statement { protected final UnresolvedPlan plan; + protected final int fetchSize; @Override public R accept(AbstractNodeVisitor visitor, C context) { diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/CloseCursor.java b/core/src/main/java/org/opensearch/sql/ast/tree/CloseCursor.java new file mode 100644 index 0000000000..cf82c2b070 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/CloseCursor.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import java.util.List; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; + +/** + * AST node to represent close cursor operation. + * Actually a wrapper to the AST. + */ +public class CloseCursor extends UnresolvedPlan { + + /** + * An instance of {@link FetchCursor}. 
+ */ + private UnresolvedPlan cursor; + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitCloseCursor(this, context); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.cursor = child; + return this; + } + + @Override + public List getChild() { + return List.of(cursor); + } +} diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/FetchCursor.java b/core/src/main/java/org/opensearch/sql/ast/tree/FetchCursor.java new file mode 100644 index 0000000000..aa327c295b --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/FetchCursor.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** + * An unresolved plan that represents fetching the next + * batch in paginated plan. 
+ */ +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class FetchCursor extends UnresolvedPlan { + @Getter + final String cursor; + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFetchCursor(this, context); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + throw new UnsupportedOperationException("Cursor unresolved plan does not support children"); + } +} diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Paginate.java b/core/src/main/java/org/opensearch/sql/ast/tree/Paginate.java new file mode 100644 index 0000000000..55e0e8c7a6 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Paginate.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; + +/** + * AST node to represent pagination operation. + * Actually a wrapper to the AST. 
+ */ +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +@ToString +public class Paginate extends UnresolvedPlan { + @Getter + private final int pageSize; + private UnresolvedPlan child; + + public Paginate(int pageSize, UnresolvedPlan child) { + this.pageSize = pageSize; + this.child = child; + } + + @Override + public List getChild() { + return List.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitPaginate(this, context); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } +} diff --git a/core/src/main/java/org/opensearch/sql/exception/NoCursorException.java b/core/src/main/java/org/opensearch/sql/exception/NoCursorException.java new file mode 100644 index 0000000000..9383bece57 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/exception/NoCursorException.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.exception; + +/** + * This should be thrown on serialization of a PhysicalPlan tree if paging is finished. + * Processing of such an exception should result in responding with no cursor to the user. + */ +public class NoCursorException extends RuntimeException { +} diff --git a/core/src/main/java/org/opensearch/sql/exception/UnsupportedCursorRequestException.java b/core/src/main/java/org/opensearch/sql/exception/UnsupportedCursorRequestException.java new file mode 100644 index 0000000000..6ed8e02e5f --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/exception/UnsupportedCursorRequestException.java @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.exception; + +/** + * This should be thrown by the V2 engine to support the fallback scenario. 
+ */ +public class UnsupportedCursorRequestException extends RuntimeException { +} diff --git a/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java b/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java index 1936a0f517..9465da22c9 100644 --- a/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java +++ b/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java @@ -14,6 +14,7 @@ import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.executor.pagination.Cursor; import org.opensearch.sql.planner.physical.PhysicalPlan; /** @@ -53,6 +54,7 @@ void execute(PhysicalPlan plan, ExecutionContext context, class QueryResponse { private final Schema schema; private final List results; + private final Cursor cursor; } @Data diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/CommandPlan.java b/core/src/main/java/org/opensearch/sql/executor/execution/CommandPlan.java new file mode 100644 index 0000000000..0ea5266084 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/executor/execution/CommandPlan.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sql.executor.execution; + +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.executor.QueryId; +import org.opensearch.sql.executor.QueryService; + +/** + * Query plan which does not reflect a search query being executed. + * It contains a command or an action, for example, a DDL query. + */ +public class CommandPlan extends AbstractPlan { + + /** + * The query plan ast. 
+ */ + protected final UnresolvedPlan plan; + + /** + * Query service. + */ + protected final QueryService queryService; + + protected final ResponseListener listener; + + /** Constructor. */ + public CommandPlan(QueryId queryId, UnresolvedPlan plan, QueryService queryService, + ResponseListener listener) { + super(queryId); + this.plan = plan; + this.queryService = queryService; + this.listener = listener; + } + + @Override + public void execute() { + queryService.execute(plan, listener); + } + + @Override + public void explain(ResponseListener listener) { + throw new UnsupportedOperationException("CommandPlan does not support explain"); + } +} diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlan.java b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlan.java index af5c032d49..aeecf3e76f 100644 --- a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlan.java +++ b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlan.java @@ -8,6 +8,9 @@ package org.opensearch.sql.executor.execution; +import java.util.Optional; +import org.apache.commons.lang3.NotImplementedException; +import org.opensearch.sql.ast.tree.Paginate; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.executor.ExecutionEngine; @@ -15,9 +18,7 @@ import org.opensearch.sql.executor.QueryService; /** - * Query plan. Which includes. - * - *

select query. + * Query plan which includes a select query. */ public class QueryPlan extends AbstractPlan { @@ -33,25 +34,51 @@ public class QueryPlan extends AbstractPlan { protected final ResponseListener listener; - /** constructor. */ + protected final Optional pageSize; + + /** Constructor. */ + public QueryPlan( + QueryId queryId, + UnresolvedPlan plan, + QueryService queryService, + ResponseListener listener) { + super(queryId); + this.plan = plan; + this.queryService = queryService; + this.listener = listener; + this.pageSize = Optional.empty(); + } + + /** Constructor with page size. */ public QueryPlan( QueryId queryId, UnresolvedPlan plan, + int pageSize, QueryService queryService, ResponseListener listener) { super(queryId); this.plan = plan; this.queryService = queryService; this.listener = listener; + this.pageSize = Optional.of(pageSize); } @Override public void execute() { - queryService.execute(plan, listener); + if (pageSize.isPresent()) { + queryService.execute(new Paginate(pageSize.get(), plan), listener); + } else { + queryService.execute(plan, listener); + } } @Override public void explain(ResponseListener listener) { - queryService.explain(plan, listener); + if (pageSize.isPresent()) { + listener.onFailure(new NotImplementedException( + "`explain` feature for paginated requests is not implemented yet.")); + } else { + queryService.explain(plan, listener); + } } } diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java index 851381cc7a..3273eb3c18 100644 --- a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java +++ b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java @@ -17,10 +17,15 @@ import org.opensearch.sql.ast.statement.Explain; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; +import 
org.opensearch.sql.ast.tree.CloseCursor; +import org.opensearch.sql.ast.tree.FetchCursor; +import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.exception.UnsupportedCursorRequestException; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.QueryId; import org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.executor.pagination.CanPaginateVisitor; /** * QueryExecution Factory. @@ -39,7 +44,7 @@ public class QueryPlanFactory private final QueryService queryService; /** - * NO_CONSUMER_RESPONSE_LISTENER should never been called. It is only used as constructor + * NO_CONSUMER_RESPONSE_LISTENER should never be called. It is only used as constructor * parameter of {@link QueryPlan}. */ @VisibleForTesting @@ -69,32 +74,67 @@ public AbstractPlan create( return statement.accept(this, Pair.of(queryListener, explainListener)); } + /** + * Creates a QueryPlan from a cursor. + */ + public AbstractPlan create(String cursor, boolean isExplain, + ResponseListener queryResponseListener, + ResponseListener explainListener) { + QueryId queryId = QueryId.queryId(); + var plan = new QueryPlan(queryId, new FetchCursor(cursor), queryService, queryResponseListener); + return isExplain ? new ExplainPlan(queryId, plan, explainListener) : plan; + } + + boolean canConvertToCursor(UnresolvedPlan plan) { + return plan.accept(new CanPaginateVisitor(), null); + } + + /** + * Creates a {@link CloseCursor} command on a cursor. 
+ */ + public AbstractPlan createCloseCursor(String cursor, + ResponseListener queryResponseListener) { + return new CommandPlan(QueryId.queryId(), new CloseCursor().attach(new FetchCursor(cursor)), + queryService, queryResponseListener); + } + @Override public AbstractPlan visitQuery( Query node, - Pair< - Optional>, - Optional>> + Pair>, + Optional>> context) { Preconditions.checkArgument( context.getLeft().isPresent(), "[BUG] query listener must be not null"); - return new QueryPlan(QueryId.queryId(), node.getPlan(), queryService, context.getLeft().get()); + if (node.getFetchSize() > 0) { + if (canConvertToCursor(node.getPlan())) { + return new QueryPlan(QueryId.queryId(), node.getPlan(), node.getFetchSize(), + queryService, + context.getLeft().get()); + } else { + // This should be picked up by the legacy engine. + throw new UnsupportedCursorRequestException(); + } + } else { + return new QueryPlan(QueryId.queryId(), node.getPlan(), queryService, + context.getLeft().get()); + } } @Override public AbstractPlan visitExplain( Explain node, - Pair< - Optional>, - Optional>> + Pair>, + Optional>> context) { Preconditions.checkArgument( context.getRight().isPresent(), "[BUG] explain listener must be not null"); return new ExplainPlan( QueryId.queryId(), - create(node.getStatement(), Optional.of(NO_CONSUMER_RESPONSE_LISTENER), Optional.empty()), + create(node.getStatement(), + Optional.of(NO_CONSUMER_RESPONSE_LISTENER), Optional.empty()), context.getRight().get()); } } diff --git a/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java b/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java new file mode 100644 index 0000000000..3164794abb --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.pagination; + +import 
org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.Relation; + +/** + * Use this unresolved plan visitor to check if a plan can be serialized by PaginatedPlanCache. + * If plan.accept(new CanPaginateVisitor(...)) returns true, + * then PaginatedPlanCache.convertToCursor will succeed. Otherwise, it will fail. + * The purpose of this visitor is to activate legacy engine fallback mechanism. + * Currently, the conditions are: + * - only projection of a relation is supported. + * - projection only has * (a.k.a. allFields). + * - Relation only scans one table + * - The table is an open search index. + * So it accepts only queries like `select * from $index` + * See PaginatedPlanCache.canConvertToCursor for usage. + */ +public class CanPaginateVisitor extends AbstractNodeVisitor { + + @Override + public Boolean visitRelation(Relation node, Object context) { + if (!node.getChild().isEmpty()) { + // Relation instance should never have a child, but check just in case. + return Boolean.FALSE; + } + + return Boolean.TRUE; + } + + @Override + public Boolean visitChildren(Node node, Object context) { + return Boolean.FALSE; + } + + @Override + public Boolean visitProject(Project node, Object context) { + // Allow queries with 'SELECT *' only. Those restriction could be removed, but consider + // in-memory aggregation performed by window function (see WindowOperator). + // SELECT max(age) OVER (PARTITION BY city) ... 
+ var projections = node.getProjectList(); + if (projections.size() != 1) { + return Boolean.FALSE; + } + + if (!(projections.get(0) instanceof AllFields)) { + return Boolean.FALSE; + } + + var children = node.getChild(); + if (children.size() != 1) { + return Boolean.FALSE; + } + + return children.get(0).accept(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/executor/pagination/Cursor.java b/core/src/main/java/org/opensearch/sql/executor/pagination/Cursor.java new file mode 100644 index 0000000000..bb320f5c67 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/executor/pagination/Cursor.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.pagination; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +@EqualsAndHashCode +@RequiredArgsConstructor +public class Cursor { + public static final Cursor None = new Cursor(null); + + @Getter + private final String data; + + public String toString() { + return data; + } +} diff --git a/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java b/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java new file mode 100644 index 0000000000..07cf174d73 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.pagination; + +import com.google.common.hash.HashCode; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.NotSerializableException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.zip.Deflater; +import java.util.zip.GZIPInputStream; +import 
java.util.zip.GZIPOutputStream; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.planner.SerializablePlan; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.StorageEngine; + +/** + * This class is the entry point to paged requests. It is responsible for cursor serialization + * and deserialization. + */ +@RequiredArgsConstructor +public class PlanSerializer { + public static final String CURSOR_PREFIX = "n:"; + + private final StorageEngine engine; + + + /** + * Converts a physical plan tree to a cursor. + */ + public Cursor convertToCursor(PhysicalPlan plan) { + try { + return new Cursor(CURSOR_PREFIX + + serialize(((SerializablePlan) plan).getPlanForSerialization())); + // ClassCastException thrown when a plan in the tree doesn't implement SerializablePlan + } catch (NotSerializableException | ClassCastException | NoCursorException e) { + return Cursor.None; + } + } + + /** + * Serializes and compresses the object. + * @param object The object. + * @return Encoded binary data. + */ + protected String serialize(Serializable object) throws NotSerializableException { + try { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject(object); + objectOutput.flush(); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + // GZIP provides 35-45%, lzma from apache commons-compress has few % better compression + GZIPOutputStream gzip = new GZIPOutputStream(out) { { + this.def.setLevel(Deflater.BEST_COMPRESSION); + } }; + gzip.write(output.toByteArray()); + gzip.close(); + + return HashCode.fromBytes(out.toByteArray()).toString(); + } catch (NotSerializableException e) { + throw e; + } catch (IOException e) { + throw new IllegalStateException("Failed to serialize: " + object, e); + } + } + + /** + * Decompresses and deserializes the binary data. 
+ * @param code Encoded binary data. + * @return An object. + */ + protected Serializable deserialize(String code) { + try { + GZIPInputStream gzip = new GZIPInputStream( + new ByteArrayInputStream(HashCode.fromString(code).asBytes())); + ObjectInputStream objectInput = new CursorDeserializationStream( + new ByteArrayInputStream(gzip.readAllBytes())); + return (Serializable) objectInput.readObject(); + } catch (Exception e) { + throw new IllegalStateException("Failed to deserialize object", e); + } + } + + /** + * Converts a cursor to a physical plan tree. + */ + public PhysicalPlan convertToPlan(String cursor) { + if (!cursor.startsWith(CURSOR_PREFIX)) { + throw new UnsupportedOperationException("Unsupported cursor"); + } + try { + return (PhysicalPlan) deserialize(cursor.substring(CURSOR_PREFIX.length())); + } catch (Exception e) { + throw new UnsupportedOperationException("Unsupported cursor", e); + } + } + + /** + * This function is used in testing only, to get access to {@link CursorDeserializationStream}. + */ + public CursorDeserializationStream getCursorDeserializationStream(InputStream in) + throws IOException { + return new CursorDeserializationStream(in); + } + + public class CursorDeserializationStream extends ObjectInputStream { + public CursorDeserializationStream(InputStream in) throws IOException { + super(in); + } + + @Override + public Object resolveObject(Object obj) throws IOException { + return obj.equals("engine") ? 
engine : obj; + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java index d4cdb528fa..af234027e6 100644 --- a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java +++ b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java @@ -3,12 +3,14 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.planner; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalCloseCursor; import org.opensearch.sql.planner.logical.LogicalDedupe; import org.opensearch.sql.planner.logical.LogicalEval; +import org.opensearch.sql.planner.logical.LogicalFetchCursor; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalNested; @@ -23,6 +25,7 @@ import org.opensearch.sql.planner.logical.LogicalValues; import org.opensearch.sql.planner.logical.LogicalWindow; import org.opensearch.sql.planner.physical.AggregationOperator; +import org.opensearch.sql.planner.physical.CursorCloseOperator; import org.opensearch.sql.planner.physical.DedupeOperator; import org.opensearch.sql.planner.physical.EvalOperator; import org.opensearch.sql.planner.physical.FilterOperator; @@ -148,9 +151,18 @@ public PhysicalPlan visitRelation(LogicalRelation node, C context) { + "implementing and optimizing logical plan with relation involved"); } + @Override + public PhysicalPlan visitFetchCursor(LogicalFetchCursor plan, C context) { + return new PlanSerializer(plan.getEngine()).convertToPlan(plan.getCursor()); + } + + @Override + public PhysicalPlan visitCloseCursor(LogicalCloseCursor node, C context) { + return new CursorCloseOperator(visitChild(node, context)); + } + protected PhysicalPlan visitChild(LogicalPlan node, C context) { // 
Logical operators visited here must have a single child return node.getChild().get(0).accept(this, context); } - } diff --git a/core/src/main/java/org/opensearch/sql/planner/SerializablePlan.java b/core/src/main/java/org/opensearch/sql/planner/SerializablePlan.java new file mode 100644 index 0000000000..ab195da5bf --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/SerializablePlan.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner; + +import java.io.Externalizable; + +/** + * All subtypes of PhysicalPlan which needs to be serialized (in cursor, for pagination feature) + * should follow one of the following options. + *

    + *
  • Both: + *
      + *
    • Override both methods from {@link Externalizable}.
    • + *
    • Define a public no-arg constructor.
    • + *
    + *
  • + *
  • + * Overwrite {@link #getPlanForSerialization} to return + * another instance of {@link SerializablePlan}. + *
  • + *
+ */ +public interface SerializablePlan extends Externalizable { + + /** + * Override to return child or delegated plan, so parent plan should skip this one + * for serialization, but it should try to serialize grandchild plan. + * Imagine plan structure like this + *
+   *    A         -> this
+   *    `- B      -> child
+   *      `- C    -> this
+   * 
+ * In that case only plans A and C should be attempted to serialize. + * It is needed to skip a `ResourceMonitorPlan` instance only, actually. + * + *
{@code
+   *    * A.writeObject(B.getPlanForSerialization());
+   *  }
+ * @return Next plan for serialization. + */ + default SerializablePlan getPlanForSerialization() { + return this; + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalCloseCursor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalCloseCursor.java new file mode 100644 index 0000000000..e5c30a4f4f --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalCloseCursor.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.logical; + +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.ToString; + +/** + * A logical plan node which wraps {@link org.opensearch.sql.planner.logical.LogicalFetchCursor} + * and represents a cursor close operation. + */ +@ToString +@EqualsAndHashCode(callSuper = false) +public class LogicalCloseCursor extends LogicalPlan { + + public LogicalCloseCursor(LogicalPlan child) { + super(List.of(child)); + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitCloseCursor(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalFetchCursor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalFetchCursor.java new file mode 100644 index 0000000000..e4a0482aac --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalFetchCursor.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.logical; + +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; +import org.opensearch.sql.storage.StorageEngine; + +/** + * A plan node which represents the operation of fetching the next page from the cursor. 
+ */ +@EqualsAndHashCode(callSuper = false) +@ToString +public class LogicalFetchCursor extends LogicalPlan { + @Getter + private final String cursor; + + @Getter + private final StorageEngine engine; + + /** + * LogicalCursor constructor. Does not have child plans. + */ + public LogicalFetchCursor(String cursor, StorageEngine engine) { + super(List.of()); + this.cursor = cursor; + this.engine = engine; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitFetchCursor(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPaginate.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPaginate.java new file mode 100644 index 0000000000..372f9dcf0b --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPaginate.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.logical; + +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; + +/** + * LogicalPaginate represents pagination operation for underlying plan. 
+ */ +@ToString +@EqualsAndHashCode(callSuper = false) +public class LogicalPaginate extends LogicalPlan { + @Getter + private final int pageSize; + + public LogicalPaginate(int pageSize, List childPlans) { + super(childPlans); + this.pageSize = pageSize; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitPaginate(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index 411d9a51be..c0e253ca50 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -11,19 +11,18 @@ import java.util.Arrays; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import lombok.experimental.UtilityClass; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; -import org.opensearch.sql.data.model.ExprCollectionValue; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.window.WindowDefinition; +import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; /** @@ -32,6 +31,10 @@ @UtilityClass public class LogicalPlanDSL { + public static LogicalPlan fetchCursor(String cursor, StorageEngine engine) { + return new LogicalFetchCursor(cursor, engine); + } + public static LogicalPlan write(LogicalPlan input, Table table, List columns) { return new LogicalWrite(input, table, columns); } @@ -54,6 +57,10 @@ public static LogicalPlan rename( 
return new LogicalRename(input, renameMap); } + public static LogicalPlan paginate(LogicalPlan input, int fetchSize) { + return new LogicalPaginate(fetchSize, List.of(input)); + } + public static LogicalPlan project(LogicalPlan input, NamedExpression... fields) { return new LogicalProject(input, Arrays.asList(fields), ImmutableList.of()); } diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index d7ab75f869..dbe21d38e0 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -104,4 +104,16 @@ public R visitML(LogicalML plan, C context) { public R visitAD(LogicalAD plan, C context) { return visitNode(plan, context); } + + public R visitPaginate(LogicalPaginate plan, C context) { + return visitNode(plan, context); + } + + public R visitFetchCursor(LogicalFetchCursor plan, C context) { + return visitNode(plan, context); + } + + public R visitCloseCursor(LogicalCloseCursor plan, C context) { + return visitNode(plan, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java index 097c5ff8ce..be1227c1da 100644 --- a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java @@ -55,6 +55,7 @@ public static LogicalPlanOptimizer create() { TableScanPushDown.PUSH_DOWN_AGGREGATION, TableScanPushDown.PUSH_DOWN_SORT, TableScanPushDown.PUSH_DOWN_LIMIT, + new PushDownPageSize(), TableScanPushDown.PUSH_DOWN_HIGHLIGHT, TableScanPushDown.PUSH_DOWN_NESTED, TableScanPushDown.PUSH_DOWN_PROJECT, diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/PushDownPageSize.java 
b/core/src/main/java/org/opensearch/sql/planner/optimizer/PushDownPageSize.java new file mode 100644 index 0000000000..8150de824d --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/PushDownPageSize.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.optimizer; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Optional; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * A {@link LogicalPlanOptimizer} rule that pushes down page size + * to table scan builder. + */ +public class PushDownPageSize implements Rule { + @Override + public Pattern pattern() { + return Pattern.typeOf(LogicalPaginate.class) + .matching(lp -> findTableScanBuilder(lp).isPresent()); + } + + @Override + public LogicalPlan apply(LogicalPaginate plan, Captures captures) { + + var builder = findTableScanBuilder(plan).orElseThrow(); + if (!builder.pushDownPageSize(plan)) { + throw new IllegalStateException("Failed to push down LogicalPaginate"); + } + return plan.getChild().get(0); + } + + private Optional findTableScanBuilder(LogicalPaginate logicalPaginate) { + Deque plans = new ArrayDeque<>(); + plans.add(logicalPaginate); + do { + var plan = plans.removeFirst(); + var children = plan.getChild(); + if (children.stream().anyMatch(TableScanBuilder.class::isInstance)) { + if (children.size() > 1) { + throw new UnsupportedOperationException( + "Unsupported plan: relation operator cannot have siblings"); + } + return Optional.of((TableScanBuilder) children.get(0)); + } + plans.addAll(children); + } while (!plans.isEmpty()); + return Optional.empty(); + } +} diff --git 
a/core/src/main/java/org/opensearch/sql/planner/physical/CursorCloseOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/CursorCloseOperator.java new file mode 100644 index 0000000000..7921d0dd50 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/CursorCloseOperator.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import java.util.List; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.executor.ExecutionEngine; + +/** + * A plan node which blocks issuing a request in {@link #open} and + * getting results in {@link #hasNext}, but doesn't block releasing resources in {@link #close}. + * Designed to be on top of the deserialized tree. + */ +@RequiredArgsConstructor +public class CursorCloseOperator extends PhysicalPlan { + + // Entire deserialized from cursor plan tree + private final PhysicalPlan input; + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return visitor.visitCursorClose(this, context); + } + + @Override + public boolean hasNext() { + return false; + } + + @Override + public ExprValue next() { + throw new IllegalStateException(); + } + + @Override + public List getChild() { + return List.of(input); + } + + /** + * Provides an empty schema, because this plan node is always located on the top of the tree. + */ + @Override + public ExecutionEngine.Schema schema() { + return new ExecutionEngine.Schema(List.of()); + } + + @Override + public void open() { + // no-op, no search should be invoked. 
+ } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java index 86cd411a2d..4b5045d24e 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java @@ -17,8 +17,9 @@ import org.opensearch.sql.storage.bindingtuple.BindingTuple; /** - * The Filter operator use the conditions to evaluate the input {@link BindingTuple}. - * The Filter operator only return the results that evaluated to true. + * The Filter operator represents WHERE clause and + * uses the conditions to evaluate the input {@link BindingTuple}. + * The Filter operator only returns the results that evaluated to true. * The NULL and MISSING are handled by the logic defined in {@link BinaryPredicateOperator}. */ @EqualsAndHashCode(callSuper = false) @@ -29,7 +30,8 @@ public class FilterOperator extends PhysicalPlan { private final PhysicalPlan input; @Getter private final Expression conditions; - @ToString.Exclude private ExprValue next = null; + @ToString.Exclude + private ExprValue next = null; @Override public R accept(PhysicalPlanNodeVisitor visitor, C context) { diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/NestedOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/NestedOperator.java index 049e9fd16e..54cd541519 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/NestedOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/NestedOperator.java @@ -99,7 +99,6 @@ public boolean hasNext() { return input.hasNext() || flattenedResult.hasNext(); } - @Override public ExprValue next() { if (!flattenedResult.hasNext()) { @@ -233,7 +232,6 @@ boolean containSamePath(Map newMap) { return false; } - /** * Retrieve nested field(s) in row. 
* diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlan.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlan.java index b476b01557..247b347940 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlan.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlan.java @@ -15,9 +15,8 @@ /** * Physical plan. */ -public abstract class PhysicalPlan implements PlanNode, - Iterator, - AutoCloseable { +public abstract class PhysicalPlan + implements PlanNode, Iterator, AutoCloseable { /** * Accept the {@link PhysicalPlanNodeVisitor}. * @@ -43,6 +42,6 @@ public void add(Split split) { public ExecutionEngine.Schema schema() { throw new IllegalStateException(String.format("[BUG] schema can been only applied to " - + "ProjectOperator, instead of %s", toString())); + + "ProjectOperator, instead of %s", this.getClass().getSimpleName())); } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index cb488700a0..1e8f08d39f 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -92,4 +92,8 @@ public R visitAD(PhysicalPlan node, C context) { public R visitML(PhysicalPlan node, C context) { return visitNode(node, context); } + + public R visitCursorClose(CursorCloseOperator node, C context) { + return visitNode(node, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/ProjectOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/ProjectOperator.java index 496e4e6ddb..1699c97c15 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/ProjectOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/ProjectOperator.java @@ -8,13 +8,16 @@ import 
com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap.Builder; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; +import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; -import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; @@ -22,20 +25,21 @@ import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.parse.ParseExpression; +import org.opensearch.sql.planner.SerializablePlan; /** * Project the fields specified in {@link ProjectOperator#projectList} from input. */ @ToString @EqualsAndHashCode(callSuper = false) -@RequiredArgsConstructor -public class ProjectOperator extends PhysicalPlan { +@AllArgsConstructor +public class ProjectOperator extends PhysicalPlan implements SerializablePlan { @Getter - private final PhysicalPlan input; + private PhysicalPlan input; @Getter - private final List projectList; + private List projectList; @Getter - private final List namedParseExpressions; + private List namedParseExpressions; @Override public R accept(PhysicalPlanNodeVisitor visitor, C context) { @@ -94,4 +98,24 @@ public ExecutionEngine.Schema schema() { .map(expr -> new ExecutionEngine.Schema.Column(expr.getName(), expr.getAlias(), expr.type())).collect(Collectors.toList())); } + + /** Don't use, it is for deserialization needs only. 
*/ + @Deprecated + public ProjectOperator() { + } + + @SuppressWarnings("unchecked") + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + projectList = (List) in.readObject(); + // note: namedParseExpressions aren't serialized and deserialized + namedParseExpressions = List.of(); + input = (PhysicalPlan) in.readObject(); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(projectList); + out.writeObject(((SerializablePlan) input).getPlanForSerialization()); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/ValuesOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/ValuesOperator.java index 51d2850df7..4ac9d6a30a 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/ValuesOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/ValuesOperator.java @@ -15,6 +15,7 @@ import lombok.ToString; import org.opensearch.sql.data.model.ExprCollectionValue; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; /** @@ -58,7 +59,7 @@ public boolean hasNext() { @Override public ExprValue next() { List values = valuesIterator.next().stream() - .map(expr -> expr.valueOf()) + .map(Expression::valueOf) .collect(Collectors.toList()); return new ExprCollectionValue(values); } diff --git a/core/src/main/java/org/opensearch/sql/storage/StorageEngine.java b/core/src/main/java/org/opensearch/sql/storage/StorageEngine.java index 246a50ea09..ffcc0911de 100644 --- a/core/src/main/java/org/opensearch/sql/storage/StorageEngine.java +++ b/core/src/main/java/org/opensearch/sql/storage/StorageEngine.java @@ -29,5 +29,4 @@ public interface StorageEngine { default Collection getFunctions() { return Collections.emptyList(); } - } diff --git a/core/src/main/java/org/opensearch/sql/storage/Table.java 
b/core/src/main/java/org/opensearch/sql/storage/Table.java index e2586ed22c..fc1def5a2e 100644 --- a/core/src/main/java/org/opensearch/sql/storage/Table.java +++ b/core/src/main/java/org/opensearch/sql/storage/Table.java @@ -99,4 +99,5 @@ default TableWriteBuilder createWriteBuilder(LogicalWrite plan) { default StreamingSource asStreamingSource() { throw new UnsupportedOperationException(); } + } diff --git a/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java b/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java index 9af66e219f..f0158c52b8 100644 --- a/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java +++ b/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java @@ -11,6 +11,7 @@ import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; import org.opensearch.sql.planner.logical.LogicalProject; @@ -28,7 +29,7 @@ public abstract class TableScanBuilder extends LogicalPlan { /** * Construct and initialize children to empty list. 
*/ - public TableScanBuilder() { + protected TableScanBuilder() { super(Collections.emptyList()); } @@ -116,6 +117,10 @@ public boolean pushDownNested(LogicalNested nested) { return false; } + public boolean pushDownPageSize(LogicalPaginate paginate) { + return false; + } + @Override public R accept(LogicalPlanNodeVisitor visitor, C context) { return visitor.visitTableScanBuilder(this, context); diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 9cabaccb0e..59edde6f86 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -7,6 +7,7 @@ package org.opensearch.sql.analysis; import static java.util.Collections.emptyList; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -75,8 +76,11 @@ import org.opensearch.sql.ast.expression.ScoreFunction; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.tree.AD; +import org.opensearch.sql.ast.tree.CloseCursor; +import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.ML; +import org.opensearch.sql.ast.tree.Paginate; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.antlr.SyntaxCheckException; @@ -89,8 +93,11 @@ import org.opensearch.sql.expression.function.OpenSearchFunctions; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.planner.logical.LogicalAD; +import org.opensearch.sql.planner.logical.LogicalCloseCursor; +import org.opensearch.sql.planner.logical.LogicalFetchCursor; import 
org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalMLCommons; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.opensearch.sql.planner.logical.LogicalProject; @@ -1609,4 +1616,29 @@ public void ml_relation_predict_rcf_without_time_field() { assertTrue(((LogicalProject) actual).getProjectList() .contains(DSL.named(RCF_ANOMALOUS, DSL.ref(RCF_ANOMALOUS, BOOLEAN)))); } + + @Test + public void visit_paginate() { + LogicalPlan actual = analyze(new Paginate(10, AstDSL.relation("dummy"))); + assertTrue(actual instanceof LogicalPaginate); + assertEquals(10, ((LogicalPaginate) actual).getPageSize()); + } + + @Test + void visit_cursor() { + LogicalPlan actual = analyze((new FetchCursor("test"))); + assertTrue(actual instanceof LogicalFetchCursor); + assertEquals(new LogicalFetchCursor("test", + dataSourceService.getDataSource("@opensearch").getStorageEngine()), actual); + } + + @Test + public void visit_close_cursor() { + var analyzed = analyze(new CloseCursor().attach(new FetchCursor("pewpew"))); + assertAll( + () -> assertTrue(analyzed instanceof LogicalCloseCursor), + () -> assertTrue(analyzed.getChild().get(0) instanceof LogicalFetchCursor), + () -> assertEquals("pewpew", ((LogicalFetchCursor) analyzed.getChild().get(0)).getCursor()) + ); + } } diff --git a/core/src/test/java/org/opensearch/sql/executor/QueryServiceTest.java b/core/src/test/java/org/opensearch/sql/executor/QueryServiceTest.java index 4df38027f4..1510b304e6 100644 --- a/core/src/test/java/org/opensearch/sql/executor/QueryServiceTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/QueryServiceTest.java @@ -15,11 +15,9 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.lenient; -import static org.mockito.Mockito.when; import java.util.Collections; 
import java.util.Optional; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -27,6 +25,7 @@ import org.opensearch.sql.analysis.Analyzer; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.executor.pagination.Cursor; import org.opensearch.sql.planner.PlanContext; import org.opensearch.sql.planner.Planner; import org.opensearch.sql.planner.logical.LogicalPlan; @@ -134,7 +133,8 @@ Helper executeSuccess(Split split) { invocation -> { ResponseListener listener = invocation.getArgument(2); listener.onResponse( - new ExecutionEngine.QueryResponse(schema, Collections.emptyList())); + new ExecutionEngine.QueryResponse(schema, Collections.emptyList(), + Cursor.None)); return null; }) .when(executionEngine) diff --git a/core/src/test/java/org/opensearch/sql/executor/execution/CommandPlanTest.java b/core/src/test/java/org/opensearch/sql/executor/execution/CommandPlanTest.java new file mode 100644 index 0000000000..aa300cb0da --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/executor/execution/CommandPlanTest.java @@ -0,0 +1,75 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.sql.executor.execution; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.CALLS_REAL_METHODS; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.withSettings; + +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.executor.QueryId; +import org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.planner.logical.LogicalPlan; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class CommandPlanTest { + + @Test + public void execute_without_error() { + QueryService qs = mock(QueryService.class); + ResponseListener listener = mock(ResponseListener.class); + doNothing().when(qs).execute(any(), any()); + + new CommandPlan(QueryId.queryId(), mock(UnresolvedPlan.class), qs, listener).execute(); + + verify(qs).execute(any(), any()); + verify(listener, never()).onFailure(any()); + } + + @Test + public void execute_with_error() { + QueryService qs = mock(QueryService.class, withSettings().defaultAnswer(CALLS_REAL_METHODS)); + ResponseListener listener = mock(ResponseListener.class); + doThrow(new RuntimeException()) + .when(qs).executePlan(any(LogicalPlan.class), any(), any()); + + new CommandPlan(QueryId.queryId(), mock(UnresolvedPlan.class), qs, listener).execute(); + + verify(listener).onFailure(any()); + } + + @Test + @SuppressWarnings("unchecked") + public void explain_not_supported() { + QueryService qs = 
mock(QueryService.class); + ResponseListener listener = mock(ResponseListener.class); + ResponseListener explainListener = mock(ResponseListener.class); + + var exception = assertThrows(Throwable.class, () -> + new CommandPlan(QueryId.queryId(), mock(UnresolvedPlan.class), qs, listener) + .explain(explainListener)); + assertEquals("CommandPlan does not support explain", exception.getMessage()); + + verify(listener, never()).onResponse(any()); + verify(listener, never()).onFailure(any()); + verify(explainListener, never()).onResponse(any()); + verify(explainListener, never()).onFailure(any()); + } +} diff --git a/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanFactoryTest.java b/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanFactoryTest.java index cc4bf070fb..2d346e4c2a 100644 --- a/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanFactoryTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanFactoryTest.java @@ -8,26 +8,37 @@ package org.opensearch.sql.executor.execution; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.opensearch.sql.executor.execution.QueryPlanFactory.NO_CONSUMER_RESPONSE_LISTENER; import java.util.Optional; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.statement.Explain; import 
org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; +import org.opensearch.sql.ast.tree.CloseCursor; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.exception.UnsupportedCursorRequestException; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.executor.pagination.CanPaginateVisitor; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class QueryPlanFactoryTest { @Mock @@ -53,43 +64,55 @@ void init() { } @Test - public void createFromQueryShouldSuccess() { - Statement query = new Query(plan); + public void create_from_query_should_success() { + Statement query = new Query(plan, 0); AbstractPlan queryExecution = factory.create(query, Optional.of(queryListener), Optional.empty()); assertTrue(queryExecution instanceof QueryPlan); } @Test - public void createFromExplainShouldSuccess() { - Statement query = new Explain(new Query(plan)); + public void create_from_explain_should_success() { + Statement query = new Explain(new Query(plan, 0)); AbstractPlan queryExecution = factory.create(query, Optional.empty(), Optional.of(explainListener)); assertTrue(queryExecution instanceof ExplainPlan); } @Test - public void createFromQueryWithoutQueryListenerShouldThrowException() { - Statement query = new Query(plan); + public void create_from_cursor_should_success() { + AbstractPlan queryExecution = factory.create("", false, + queryListener, explainListener); + AbstractPlan explainExecution = factory.create("", true, + queryListener, explainListener); + assertAll( + () -> assertTrue(queryExecution instanceof QueryPlan), + () -> assertTrue(explainExecution instanceof ExplainPlan) + ); + } + + @Test + public void create_from_query_without_query_listener_should_throw_exception() { + Statement query = new Query(plan, 0); 
IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, () -> factory.create(query, - Optional.empty(), Optional.empty())); + assertThrows(IllegalArgumentException.class, () -> factory.create( + query, Optional.empty(), Optional.empty())); assertEquals("[BUG] query listener must be not null", exception.getMessage()); } @Test - public void createFromExplainWithoutExplainListenerShouldThrowException() { - Statement query = new Explain(new Query(plan)); + public void create_from_explain_without_explain_listener_should_throw_exception() { + Statement query = new Explain(new Query(plan, 0)); IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, () -> factory.create(query, - Optional.empty(), Optional.empty())); + assertThrows(IllegalArgumentException.class, () -> factory.create( + query, Optional.empty(), Optional.empty())); assertEquals("[BUG] explain listener must be not null", exception.getMessage()); } @Test - public void noConsumerResponseChannel() { + public void no_consumer_response_channel() { IllegalStateException exception = assertThrows( IllegalStateException.class, @@ -104,4 +127,35 @@ public void noConsumerResponseChannel() { assertEquals( "[BUG] exception response should not sent to unexpected channel", exception.getMessage()); } + + @Test + public void create_query_with_fetch_size_which_can_be_paged() { + when(plan.accept(any(CanPaginateVisitor.class), any())).thenReturn(Boolean.TRUE); + factory = new QueryPlanFactory(queryService); + Statement query = new Query(plan, 10); + AbstractPlan queryExecution = + factory.create(query, Optional.of(queryListener), Optional.empty()); + assertTrue(queryExecution instanceof QueryPlan); + } + + @Test + public void create_query_with_fetch_size_which_cannot_be_paged() { + when(plan.accept(any(CanPaginateVisitor.class), any())).thenReturn(Boolean.FALSE); + factory = new QueryPlanFactory(queryService); + Statement query = new Query(plan, 10); + 
assertThrows(UnsupportedCursorRequestException.class, + () -> factory.create(query, + Optional.of(queryListener), Optional.empty())); + } + + @Test + public void create_close_cursor() { + factory = new QueryPlanFactory(queryService); + var plan = factory.createCloseCursor("pewpew", queryListener); + assertTrue(plan instanceof CommandPlan); + plan.execute(); + var captor = ArgumentCaptor.forClass(UnresolvedPlan.class); + verify(queryService).execute(captor.capture(), any()); + assertTrue(captor.getValue() instanceof CloseCursor); + } } diff --git a/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanTest.java b/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanTest.java index 834db76996..a0a98e2be7 100644 --- a/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanTest.java @@ -8,21 +8,30 @@ package org.opensearch.sql.executor.execution; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import org.apache.commons.lang3.NotImplementedException; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.executor.DefaultExecutionEngine; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.QueryId; import org.opensearch.sql.executor.QueryService; @ExtendWith(MockitoExtension.class) 
+@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class QueryPlanTest { @Mock @@ -41,7 +50,7 @@ class QueryPlanTest { private ResponseListener queryListener; @Test - public void execute() { + public void execute_no_page_size() { QueryPlan query = new QueryPlan(queryId, plan, queryService, queryListener); query.execute(); @@ -49,10 +58,62 @@ public void execute() { } @Test - public void explain() { + public void explain_no_page_size() { QueryPlan query = new QueryPlan(queryId, plan, queryService, queryListener); query.explain(explainListener); verify(queryService, times(1)).explain(plan, explainListener); } + + @Test + public void can_execute_paginated_plan() { + var listener = new ResponseListener() { + @Override + public void onResponse(ExecutionEngine.QueryResponse response) { + assertNotNull(response); + } + + @Override + public void onFailure(Exception e) { + fail(); + } + }; + var plan = new QueryPlan(QueryId.queryId(), mock(UnresolvedPlan.class), 10, + queryService, listener); + plan.execute(); + } + + @Test + // Same as previous test, but with incomplete QueryService + public void can_handle_error_while_executing_plan() { + var listener = new ResponseListener() { + @Override + public void onResponse(ExecutionEngine.QueryResponse response) { + fail(); + } + + @Override + public void onFailure(Exception e) { + assertNotNull(e); + } + }; + var plan = new QueryPlan(QueryId.queryId(), mock(UnresolvedPlan.class), 10, + new QueryService(null, new DefaultExecutionEngine(), null), listener); + plan.execute(); + } + + @Test + public void explain_is_not_supported_for_pagination() { + new QueryPlan(null, null, 0, null, null).explain(new ResponseListener<>() { + @Override + public void onResponse(ExecutionEngine.ExplainResponse response) { + fail(); + } + + @Override + public void onFailure(Exception e) { + assertTrue(e instanceof NotImplementedException); + } + }); + } } diff --git 
a/core/src/test/java/org/opensearch/sql/executor/pagination/CanPaginateVisitorTest.java b/core/src/test/java/org/opensearch/sql/executor/pagination/CanPaginateVisitorTest.java new file mode 100644 index 0000000000..02a0dbc05e --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/executor/pagination/CanPaginateVisitorTest.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.pagination; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import java.util.List; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.Relation; +import org.opensearch.sql.executor.pagination.CanPaginateVisitor; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class CanPaginateVisitorTest { + + static final CanPaginateVisitor visitor = new CanPaginateVisitor(); + + @Test + // select * from y + public void accept_query_with_select_star_and_from() { + var plan = AstDSL.project(AstDSL.relation("dummy"), AstDSL.allFields()); + assertTrue(plan.accept(visitor, null)); + } + + @Test + // select x from y + public void reject_query_with_select_field_and_from() { + var plan = AstDSL.project(AstDSL.relation("dummy"), AstDSL.field("pewpew")); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select x,z from y + public void reject_query_with_select_fields_and_from() { + var plan = AstDSL.project(AstDSL.relation("dummy"), + AstDSL.field("pewpew"), AstDSL.field("pewpew")); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select x + public void 
reject_query_without_from() { + var plan = AstDSL.project(AstDSL.values(List.of(AstDSL.intLiteral(1))), + AstDSL.alias("1",AstDSL.intLiteral(1))); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select * from y limit z + public void reject_query_with_limit() { + var plan = AstDSL.project(AstDSL.limit(AstDSL.relation("dummy"), 1, 2), AstDSL.allFields()); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select * from y where z + public void reject_query_with_where() { + var plan = AstDSL.project(AstDSL.filter(AstDSL.relation("dummy"), + AstDSL.booleanLiteral(true)), AstDSL.allFields()); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select * from y order by z + public void reject_query_with_order_by() { + var plan = AstDSL.project(AstDSL.sort(AstDSL.relation("dummy"), AstDSL.field("1")), + AstDSL.allFields()); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select * from y group by z + public void reject_query_with_group_by() { + var plan = AstDSL.project(AstDSL.agg( + AstDSL.relation("dummy"), List.of(), List.of(), List.of(AstDSL.field("1")), List.of()), + AstDSL.allFields()); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select agg(x) from y + public void reject_query_with_aggregation_function() { + var plan = AstDSL.project(AstDSL.agg( + AstDSL.relation("dummy"), + List.of(AstDSL.alias("agg", AstDSL.aggregate("func", AstDSL.field("pewpew")))), + List.of(), List.of(), List.of()), + AstDSL.allFields()); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select window(x) from y + public void reject_query_with_window_function() { + var plan = AstDSL.project(AstDSL.relation("dummy"), + AstDSL.alias("pewpew", + AstDSL.window( + AstDSL.aggregate("func", AstDSL.field("pewpew")), + List.of(AstDSL.qualifiedName("1")), List.of()))); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select * from y, z + public void reject_query_with_select_from_multiple_indices() { + var plan = 
mock(Project.class); + when(plan.getChild()).thenReturn(List.of(AstDSL.relation("dummy"), AstDSL.relation("pummy"))); + when(plan.getProjectList()).thenReturn(List.of(AstDSL.allFields())); + assertFalse(visitor.visitProject(plan, null)); + } + + @Test + // unreal case, added for coverage only + public void reject_project_when_relation_has_child() { + var relation = mock(Relation.class, withSettings().useConstructor(AstDSL.qualifiedName("42"))); + when(relation.getChild()).thenReturn(List.of(AstDSL.relation("pewpew"))); + when(relation.accept(visitor, null)).thenCallRealMethod(); + var plan = mock(Project.class); + when(plan.getChild()).thenReturn(List.of(relation)); + when(plan.getProjectList()).thenReturn(List.of(AstDSL.allFields())); + assertFalse(visitor.visitProject((Project) plan, null)); + } +} diff --git a/core/src/test/java/org/opensearch/sql/executor/pagination/CursorTest.java b/core/src/test/java/org/opensearch/sql/executor/pagination/CursorTest.java new file mode 100644 index 0000000000..e3e2c8cf33 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/executor/pagination/CursorTest.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.pagination; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.executor.pagination.Cursor; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +class CursorTest { + + @Test + void empty_array_is_none() { + Assertions.assertEquals(Cursor.None, new Cursor(null)); + } + + @Test + void toString_is_array_value() { + String cursorTxt = "This is a test"; + Assertions.assertEquals(cursorTxt, new Cursor(cursorTxt).toString()); + } +} diff --git a/core/src/test/java/org/opensearch/sql/executor/pagination/PlanSerializerTest.java 
b/core/src/test/java/org/opensearch/sql/executor/pagination/PlanSerializerTest.java new file mode 100644 index 0000000000..8211a3bc12 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/executor/pagination/PlanSerializerTest.java @@ -0,0 +1,178 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.pagination; + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import lombok.SneakyThrows; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.planner.SerializablePlan; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.StorageEngine; +import org.opensearch.sql.utils.TestOperator; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class PlanSerializerTest { + + StorageEngine storageEngine; + + PlanSerializer planCache; + + @BeforeEach + void setUp() { + storageEngine = mock(StorageEngine.class); + planCache = new PlanSerializer(storageEngine); + } + + @ParameterizedTest + @ValueSource(strings = {"pewpew", "asdkfhashdfjkgakgfwuigfaijkb", "ajdhfgajklghadfjkhgjkadhgad" + + 
"kadfhgadhjgfjklahdgqheygvskjfbvgsdklgfuirehiluANUIfgauighbahfuasdlhfnhaughsdlfhaughaggf" + + "and_some_other_funny_stuff_which_could_be_generated_while_sleeping_on_the_keyboard"}) + void serialize_deserialize_str(String input) { + var compressed = serialize(input); + assertEquals(input, deserialize(compressed)); + if (input.length() > 200) { + // Compression of short strings isn't profitable, because encoding into string and gzip + // headers add more bytes than input string has. + assertTrue(compressed.length() < input.length()); + } + } + + public static class SerializableTestClass implements Serializable { + public int field; + + @Override + public boolean equals(Object obj) { + return field == ((SerializableTestClass) obj).field; + } + } + + // Can't serialize private classes because they are not accessible + private class NotSerializableTestClass implements Serializable { + public int field; + + @Override + public boolean equals(Object obj) { + return field == ((SerializableTestClass) obj).field; + } + } + + @Test + void serialize_deserialize_obj() { + var obj = new SerializableTestClass(); + obj.field = 42; + assertEquals(obj, deserialize(serialize(obj))); + assertNotSame(obj, deserialize(serialize(obj))); + } + + @Test + void serialize_throws() { + assertThrows(Throwable.class, () -> serialize(new NotSerializableTestClass())); + var testObj = new TestOperator(); + testObj.setThrowIoOnWrite(true); + assertThrows(Throwable.class, () -> serialize(testObj)); + } + + @Test + void deserialize_throws() { + assertAll( + // from gzip - damaged header + () -> assertThrows(Throwable.class, () -> deserialize("00")), + // from HashCode::fromString + () -> assertThrows(Throwable.class, () -> deserialize("000")) + ); + } + + @Test + @SneakyThrows + void convertToCursor_returns_no_cursor_if_cant_serialize() { + var plan = new TestOperator(42); + plan.setThrowNoCursorOnWrite(true); + assertAll( + () -> assertThrows(NoCursorException.class, () -> serialize(plan)), + () -> 
assertEquals(Cursor.None, planCache.convertToCursor(plan)) + ); + } + + @Test + @SneakyThrows + void convertToCursor_returns_no_cursor_if_plan_is_not_paginate() { + var plan = mock(PhysicalPlan.class); + assertEquals(Cursor.None, planCache.convertToCursor(plan)); + } + + @Test + void convertToPlan_throws_cursor_has_no_prefix() { + assertThrows(UnsupportedOperationException.class, () -> + planCache.convertToPlan("abc")); + } + + @Test + void convertToPlan_throws_if_failed_to_deserialize() { + assertThrows(UnsupportedOperationException.class, () -> + planCache.convertToPlan("n:" + serialize(mock(Serializable.class)))); + } + + @Test + @SneakyThrows + void serialize_and_deserialize() { + var plan = new TestOperator(42); + var roundTripPlan = planCache.deserialize(planCache.serialize(plan)); + assertEquals(roundTripPlan, plan); + assertNotSame(roundTripPlan, plan); + } + + @Test + void convertToCursor_and_convertToPlan() { + var plan = new TestOperator(100500); + var roundTripPlan = (SerializablePlan) + planCache.convertToPlan(planCache.convertToCursor(plan).toString()); + assertEquals(plan, roundTripPlan); + assertNotSame(plan, roundTripPlan); + } + + @Test + @SneakyThrows + void resolveObject() { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject("Hello, world!"); + objectOutput.flush(); + + var cds = planCache.getCursorDeserializationStream( + new ByteArrayInputStream(output.toByteArray())); + assertEquals(storageEngine, cds.resolveObject("engine")); + var object = new Object(); + assertSame(object, cds.resolveObject(object)); + } + + // Helpers and auxiliary classes section below + + @SneakyThrows + private String serialize(Serializable input) { + return new PlanSerializer(null).serialize(input); + } + + private Serializable deserialize(String input) { + return new PlanSerializer(null).deserialize(input); + } +} diff --git 
a/core/src/test/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecutionTest.java b/core/src/test/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecutionTest.java index 1a2b6e3f2a..f0974db13e 100644 --- a/core/src/test/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecutionTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecutionTest.java @@ -26,6 +26,7 @@ import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.executor.pagination.Cursor; import org.opensearch.sql.planner.PlanContext; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.storage.split.Split; @@ -169,7 +170,8 @@ Helper executeSuccess(Long... offsets) { ResponseListener listener = invocation.getArgument(2); listener.onResponse( - new ExecutionEngine.QueryResponse(null, Collections.emptyList())); + new ExecutionEngine.QueryResponse(null, Collections.emptyList(), + Cursor.None)); PlanContext planContext = invocation.getArgument(1); assertTrue(planContext.getSplit().isPresent()); diff --git a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java index a717c4ed8f..c382f2634e 100644 --- a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java @@ -3,12 +3,15 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.planner; import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import 
static org.mockito.Mockito.when; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.literal; @@ -35,15 +38,17 @@ import java.util.Set; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.NamedExpression; @@ -52,36 +57,32 @@ import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.expression.window.ranking.RowNumberFunction; +import org.opensearch.sql.planner.logical.LogicalCloseCursor; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.physical.CursorCloseOperator; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; +import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; import org.opensearch.sql.storage.TableScanOperator; import org.opensearch.sql.storage.read.TableScanBuilder; import org.opensearch.sql.storage.write.TableWriteBuilder; import 
org.opensearch.sql.storage.write.TableWriteOperator; +import org.opensearch.sql.utils.TestOperator; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class DefaultImplementorTest { - @Mock - private Expression filter; - - @Mock - private NamedAggregator aggregator; - - @Mock - private NamedExpression groupBy; - @Mock private Table table; private final DefaultImplementor implementor = new DefaultImplementor<>(); @Test - public void visitShouldReturnDefaultPhysicalOperator() { + public void visit_should_return_default_physical_operator() { String indexName = "test"; NamedExpression include = named("age", ref("age", INTEGER)); ReferenceExpression exclude = ref("name", STRING); @@ -181,14 +182,14 @@ public void visitShouldReturnDefaultPhysicalOperator() { } @Test - public void visitRelationShouldThrowException() { + public void visitRelation_should_throw_an_exception() { assertThrows(UnsupportedOperationException.class, () -> new LogicalRelation("test", table).accept(implementor, null)); } @SuppressWarnings({"rawtypes", "unchecked"}) @Test - public void visitWindowOperatorShouldReturnPhysicalWindowOperator() { + public void visitWindowOperator_should_return_PhysicalWindowOperator() { NamedExpression windowFunction = named(new RowNumberFunction()); WindowDefinition windowDefinition = new WindowDefinition( Collections.singletonList(ref("state", STRING)), @@ -228,8 +229,18 @@ public void visitWindowOperatorShouldReturnPhysicalWindowOperator() { } @Test - public void visitTableScanBuilderShouldBuildTableScanOperator() { - TableScanOperator tableScanOperator = Mockito.mock(TableScanOperator.class); + void visitLogicalCursor_deserializes_it() { + var engine = mock(StorageEngine.class); + + var physicalPlan = new TestOperator(); + var logicalPlan = LogicalPlanDSL.fetchCursor(new PlanSerializer(engine) + .convertToCursor(physicalPlan).toString(), engine); + assertEquals(physicalPlan, logicalPlan.accept(implementor, 
null)); + } + + @Test + public void visitTableScanBuilder_should_build_TableScanOperator() { + TableScanOperator tableScanOperator = mock(TableScanOperator.class); TableScanBuilder tableScanBuilder = new TableScanBuilder() { @Override public TableScanOperator build() { @@ -240,9 +251,9 @@ public TableScanOperator build() { } @Test - public void visitTableWriteBuilderShouldBuildTableWriteOperator() { + public void visitTableWriteBuilder_should_build_TableWriteOperator() { LogicalPlan child = values(); - TableWriteOperator tableWriteOperator = Mockito.mock(TableWriteOperator.class); + TableWriteOperator tableWriteOperator = mock(TableWriteOperator.class); TableWriteBuilder logicalPlan = new TableWriteBuilder(child) { @Override public TableWriteOperator build(PhysicalPlan child) { @@ -251,4 +262,15 @@ public TableWriteOperator build(PhysicalPlan child) { }; assertEquals(tableWriteOperator, logicalPlan.accept(implementor, null)); } + + @Test + public void visitCloseCursor_should_build_CursorCloseOperator() { + var logicalChild = mock(LogicalPlan.class); + var physicalChild = mock(PhysicalPlan.class); + when(logicalChild.accept(implementor, null)).thenReturn(physicalChild); + var logicalPlan = new LogicalCloseCursor(logicalChild); + var implemented = logicalPlan.accept(implementor, null); + assertTrue(implemented instanceof CursorCloseOperator); + assertSame(physicalChild, implemented.getChild().get(0)); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/SerializablePlanTest.java b/core/src/test/java/org/opensearch/sql/planner/SerializablePlanTest.java new file mode 100644 index 0000000000..8073445dc0 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/SerializablePlanTest.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner; + +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; 
+import static org.mockito.Answers.CALLS_REAL_METHODS; + +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class SerializablePlanTest { + @Mock(answer = CALLS_REAL_METHODS) + SerializablePlan plan; + + @Test + void getPlanForSerialization_defaults_to_self() { + assertSame(plan, plan.getPlanForSerialization()); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index fe76589066..d4d5c89c9b 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -8,23 +8,24 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.mock; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.named; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import 
org.opensearch.sql.ast.expression.DataType; -import org.opensearch.sql.ast.expression.Literal; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.data.model.ExprValueUtils; @@ -36,6 +37,7 @@ import org.opensearch.sql.expression.aggregation.Aggregator; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; import org.opensearch.sql.storage.TableScanOperator; import org.opensearch.sql.storage.read.TableScanBuilder; @@ -45,20 +47,24 @@ /** * Todo. Temporary added for UT coverage, Will be removed. */ -@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class LogicalPlanNodeVisitorTest { - @Mock - Expression expression; - @Mock - ReferenceExpression ref; - @Mock - Aggregator aggregator; - @Mock - Table table; + static Expression expression; + static ReferenceExpression ref; + static Aggregator aggregator; + static Table table; + + @BeforeAll + private static void initMocks() { + expression = mock(Expression.class); + ref = mock(ReferenceExpression.class); + aggregator = mock(Aggregator.class); + table = mock(Table.class); + } @Test - public void logicalPlanShouldTraversable() { + public void logical_plan_should_be_traversable() { LogicalPlan logicalPlan = LogicalPlanDSL.rename( LogicalPlanDSL.aggregation( @@ -75,85 +81,42 @@ public void logicalPlanShouldTraversable() { assertEquals(5, result); } - @Test - public void testAbstractPlanNodeVisitorShouldReturnNull() { + @SuppressWarnings("unchecked") + private static Stream getLogicalPlansForVisitorTest() { LogicalPlan relation = LogicalPlanDSL.relation("schema", table); - 
assertNull(relation.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan tableScanBuilder = new TableScanBuilder() { @Override public TableScanOperator build() { return null; } }; - assertNull(tableScanBuilder.accept(new LogicalPlanNodeVisitor() { - }, null)); - - LogicalPlan write = LogicalPlanDSL.write(null, table, Collections.emptyList()); - assertNull(write.accept(new LogicalPlanNodeVisitor() { - }, null)); - TableWriteBuilder tableWriteBuilder = new TableWriteBuilder(null) { @Override public TableWriteOperator build(PhysicalPlan child) { return null; } }; - assertNull(tableWriteBuilder.accept(new LogicalPlanNodeVisitor() { - }, null)); - + LogicalPlan write = LogicalPlanDSL.write(null, table, Collections.emptyList()); LogicalPlan filter = LogicalPlanDSL.filter(relation, expression); - assertNull(filter.accept(new LogicalPlanNodeVisitor() { - }, null)); - - LogicalPlan aggregation = - LogicalPlanDSL.aggregation( - filter, ImmutableList.of(DSL.named("avg", aggregator)), ImmutableList.of(DSL.named( - "group", expression))); - assertNull(aggregation.accept(new LogicalPlanNodeVisitor() { - }, null)); - + LogicalPlan aggregation = LogicalPlanDSL.aggregation( + filter, ImmutableList.of(DSL.named("avg", aggregator)), ImmutableList.of(DSL.named( + "group", expression))); LogicalPlan rename = LogicalPlanDSL.rename(aggregation, ImmutableMap.of(ref, ref)); - assertNull(rename.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan project = LogicalPlanDSL.project(relation, named("ref", ref)); - assertNull(project.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan remove = LogicalPlanDSL.remove(relation, ref); - assertNull(remove.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan eval = LogicalPlanDSL.eval(relation, Pair.of(ref, expression)); - assertNull(eval.accept(new LogicalPlanNodeVisitor() { - }, null)); - - LogicalPlan sort = LogicalPlanDSL.sort(relation, - Pair.of(SortOption.DEFAULT_ASC, expression)); - 
assertNull(sort.accept(new LogicalPlanNodeVisitor() { - }, null)); - + LogicalPlan sort = LogicalPlanDSL.sort(relation, Pair.of(SortOption.DEFAULT_ASC, expression)); LogicalPlan dedup = LogicalPlanDSL.dedupe(relation, 1, false, false, expression); - assertNull(dedup.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan window = LogicalPlanDSL.window(relation, named(expression), new WindowDefinition( ImmutableList.of(ref), ImmutableList.of(Pair.of(SortOption.DEFAULT_ASC, expression)))); - assertNull(window.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan rareTopN = LogicalPlanDSL.rareTopN( relation, CommandType.TOP, ImmutableList.of(expression), expression); - assertNull(rareTopN.accept(new LogicalPlanNodeVisitor() { - }, null)); - - Map args = new HashMap<>(); LogicalPlan highlight = new LogicalHighlight(filter, - new LiteralExpression(ExprValueUtils.stringValue("fieldA")), args); - assertNull(highlight.accept(new LogicalPlanNodeVisitor() { - }, null)); + new LiteralExpression(ExprValueUtils.stringValue("fieldA")), Map.of()); + LogicalPlan mlCommons = new LogicalMLCommons(relation, "kmeans", Map.of()); + LogicalPlan ad = new LogicalAD(relation, Map.of()); + LogicalPlan ml = new LogicalML(relation, Map.of()); + LogicalPlan paginate = new LogicalPaginate(42, List.of(relation)); List> nestedArgs = List.of( Map.of( @@ -167,42 +130,26 @@ public TableWriteOperator build(PhysicalPlan child) { ); LogicalNested nested = new LogicalNested(null, nestedArgs, projectList); - assertNull(nested.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan mlCommons = new LogicalMLCommons(LogicalPlanDSL.relation("schema", table), - "kmeans", - ImmutableMap.builder() - .put("centroids", new Literal(3, DataType.INTEGER)) - .put("iterations", new Literal(3, DataType.DOUBLE)) - .put("distance_type", new Literal(null, DataType.STRING)) - .build()); - assertNull(mlCommons.accept(new LogicalPlanNodeVisitor() { - }, null)); + LogicalFetchCursor cursor = new 
LogicalFetchCursor("n:test", mock(StorageEngine.class)); - LogicalPlan ad = new LogicalAD(LogicalPlanDSL.relation("schema", table), - new HashMap() {{ - put("shingle_size", new Literal(8, DataType.INTEGER)); - put("time_decay", new Literal(0.0001, DataType.DOUBLE)); - put("time_field", new Literal(null, DataType.STRING)); - } - }); - assertNull(ad.accept(new LogicalPlanNodeVisitor() { - }, null)); + LogicalCloseCursor closeCursor = new LogicalCloseCursor(cursor); + + return Stream.of( + relation, tableScanBuilder, write, tableWriteBuilder, filter, aggregation, rename, project, + remove, eval, sort, dedup, window, rareTopN, highlight, mlCommons, ad, ml, paginate, nested, + cursor, closeCursor + ).map(Arguments::of); + } - LogicalPlan ml = new LogicalML(LogicalPlanDSL.relation("schema", table), - new HashMap() {{ - put("action", new Literal("train", DataType.STRING)); - put("algorithm", new Literal("rcf", DataType.STRING)); - put("shingle_size", new Literal(8, DataType.INTEGER)); - put("time_decay", new Literal(0.0001, DataType.DOUBLE)); - put("time_field", new Literal(null, DataType.STRING)); - } - }); - assertNull(ml.accept(new LogicalPlanNodeVisitor() { + @ParameterizedTest + @MethodSource("getLogicalPlansForVisitorTest") + public void abstract_plan_node_visitor_should_return_null(LogicalPlan plan) { + assertNull(plan.accept(new LogicalPlanNodeVisitor() { }, null)); } + private static class NodesCount extends LogicalPlanNodeVisitor { @Override public Integer visitRelation(LogicalRelation plan, Object context) { @@ -213,32 +160,28 @@ public Integer visitRelation(LogicalRelation plan, Object context) { public Integer visitFilter(LogicalFilter plan, Object context) { return 1 + plan.getChild().stream() - .map(child -> child.accept(this, context)) - .collect(Collectors.summingInt(Integer::intValue)); + .map(child -> child.accept(this, context)).mapToInt(Integer::intValue).sum(); } @Override public Integer visitAggregation(LogicalAggregation plan, Object context) { 
return 1 + plan.getChild().stream() - .map(child -> child.accept(this, context)) - .collect(Collectors.summingInt(Integer::intValue)); + .map(child -> child.accept(this, context)).mapToInt(Integer::intValue).sum(); } @Override public Integer visitRename(LogicalRename plan, Object context) { return 1 + plan.getChild().stream() - .map(child -> child.accept(this, context)) - .collect(Collectors.summingInt(Integer::intValue)); + .map(child -> child.accept(this, context)).mapToInt(Integer::intValue).sum(); } @Override public Integer visitRareTopN(LogicalRareTopN plan, Object context) { return 1 + plan.getChild().stream() - .map(child -> child.accept(this, context)) - .collect(Collectors.summingInt(Integer::intValue)); + .map(child -> child.accept(this, context)).mapToInt(Integer::intValue).sum(); } } } diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java index d220f599f8..faedb88111 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java @@ -9,6 +9,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; import static org.opensearch.sql.data.model.ExprValueUtils.longValue; @@ -20,6 +22,7 @@ import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.nested; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.paginate; import static 
org.opensearch.sql.planner.logical.LogicalPlanDSL.project; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; @@ -32,6 +35,8 @@ import java.util.Map; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -43,13 +48,17 @@ import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; +import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.storage.Table; import org.opensearch.sql.storage.read.TableScanBuilder; import org.opensearch.sql.storage.write.TableWriteBuilder; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class LogicalPlanOptimizerTest { @Mock @@ -60,7 +69,7 @@ class LogicalPlanOptimizerTest { @BeforeEach void setUp() { - when(table.createScanBuilder()).thenReturn(tableScanBuilder); + lenient().when(table.createScanBuilder()).thenReturn(tableScanBuilder); } /** @@ -279,7 +288,6 @@ void table_scan_builder_support_nested_push_down_can_apply_its_rule() { @Test void table_not_support_scan_builder_should_not_be_impact() { - Mockito.reset(table, tableScanBuilder); Table table = new Table() { @Override public Map getFieldTypes() { @@ -300,7 +308,6 @@ public PhysicalPlan implement(LogicalPlan plan) { @Test void table_support_write_builder_should_be_replaced() { - Mockito.reset(table, tableScanBuilder); 
TableWriteBuilder writeBuilder = Mockito.mock(TableWriteBuilder.class); when(table.createWriteBuilder(any())).thenReturn(writeBuilder); @@ -312,7 +319,6 @@ void table_support_write_builder_should_be_replaced() { @Test void table_not_support_write_builder_should_report_error() { - Mockito.reset(table, tableScanBuilder); Table table = new Table() { @Override public Map getFieldTypes() { @@ -329,9 +335,75 @@ public PhysicalPlan implement(LogicalPlan plan) { () -> table.createWriteBuilder(null)); } + @Test + void paged_table_scan_builder_support_project_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownPageSize(any())).thenReturn(true); + + var relation = relation("schema", table); + var optimized = LogicalPlanOptimizer.create() + .optimize(paginate(project(relation), 4)); + verify(tableScanBuilder).pushDownPageSize(any()); + + assertEquals(project(tableScanBuilder), optimized); + } + + @Test + void push_down_page_size_multiple_children() { + var relation = relation("schema", table); + var twoChildrenPlan = new LogicalPlan(List.of(relation, relation)) { + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return null; + } + }; + var queryPlan = paginate(twoChildrenPlan, 4); + var optimizer = LogicalPlanOptimizer.create(); + final var exception = assertThrows(UnsupportedOperationException.class, + () -> optimizer.optimize(queryPlan)); + assertEquals("Unsupported plan: relation operator cannot have siblings", + exception.getMessage()); + } + + @Test + void push_down_page_size_push_failed() { + when(tableScanBuilder.pushDownPageSize(any())).thenReturn(false); + + var queryPlan = paginate( + project( + relation("schema", table)), 4); + var optimizer = LogicalPlanOptimizer.create(); + final var exception = assertThrows(IllegalStateException.class, + () -> optimizer.optimize(queryPlan)); + assertEquals("Failed to push down LogicalPaginate", exception.getMessage()); + } + + @Test + void push_page_size_noop_if_no_relation() { + var 
paginate = new LogicalPaginate(42, List.of(project(values()))); + assertEquals(paginate, LogicalPlanOptimizer.create().optimize(paginate)); + } + + @Test + void push_page_size_noop_if_no_sub_plans() { + var paginate = new LogicalPaginate(42, List.of()); + assertEquals(paginate, + LogicalPlanOptimizer.create().optimize(paginate)); + } + + @Test + void table_scan_builder_support_offset_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownPageSize(any())).thenReturn(true); + + var relation = new LogicalRelation("schema", table); + var optimized = LogicalPlanOptimizer.create() + .optimize(new LogicalPaginate(42, List.of(project(relation)))); + // `optimized` structure: LogicalProject -> TableScanBuilder + // LogicalRelation replaced by a TableScanBuilder instance + assertEquals(project(tableScanBuilder), optimized); + } + private LogicalPlan optimize(LogicalPlan plan) { final LogicalPlanOptimizer optimizer = LogicalPlanOptimizer.create(); - final LogicalPlan optimize = optimizer.optimize(plan); - return optimize; + return optimizer.optimize(plan); } } diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java index 9f90fd8d05..ef310e3b0e 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java @@ -6,35 +6,39 @@ package org.opensearch.sql.planner.optimizer.pattern; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import java.util.Collections; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; -import 
org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; -@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class PatternsTest { - @Mock - LogicalPlan plan; - @Test void source_is_empty() { + var plan = mock(LogicalPlan.class); when(plan.getChild()).thenReturn(Collections.emptyList()); - assertFalse(Patterns.source().getFunction().apply(plan).isPresent()); - assertFalse(Patterns.source(null).getProperty().getFunction().apply(plan).isPresent()); + assertAll( + () -> assertFalse(Patterns.source().getFunction().apply(plan).isPresent()), + () -> assertFalse(Patterns.source(null).getProperty().getFunction().apply(plan).isPresent()) + ); } @Test void table_is_empty() { - plan = Mockito.mock(LogicalFilter.class); - assertFalse(Patterns.table().getFunction().apply(plan).isPresent()); - assertFalse(Patterns.writeTable().getFunction().apply(plan).isPresent()); + var plan = mock(LogicalFilter.class); + assertAll( + () -> assertFalse(Patterns.table().getFunction().apply(plan).isPresent()), + () -> assertFalse(Patterns.writeTable().getFunction().apply(plan).isPresent()) + ); } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/CursorCloseOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/CursorCloseOperatorTest.java new file mode 100644 index 0000000000..5ae30faa30 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/physical/CursorCloseOperatorTest.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static 
org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class CursorCloseOperatorTest { + + @Test + public void never_hasNext() { + var plan = new CursorCloseOperator(null); + assertFalse(plan.hasNext()); + plan.open(); + assertFalse(plan.hasNext()); + } + + @Test + public void open_is_not_propagated() { + var child = mock(PhysicalPlan.class); + var plan = new CursorCloseOperator(child); + plan.open(); + verify(child, never()).open(); + } + + @Test + public void close_is_propagated() { + var child = mock(PhysicalPlan.class); + var plan = new CursorCloseOperator(child); + plan.close(); + verify(child).close(); + } + + @Test + public void next_always_throws() { + var plan = new CursorCloseOperator(null); + assertThrows(Throwable.class, plan::next); + plan.open(); + assertThrows(Throwable.class, plan::next); + } + + @Test + public void produces_empty_schema() { + var child = mock(PhysicalPlan.class); + var plan = new CursorCloseOperator(child); + assertEquals(0, plan.schema().getColumns().size()); + verify(child, never()).schema(); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java index be8080ad3c..6a8bcad203 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java @@ -17,22 +17,30 @@ import com.google.common.collect.ImmutableMap; import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import 
java.util.stream.Collectors; +import java.util.stream.Stream; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.AdditionalAnswers; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.expression.DSL; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class FilterOperatorTest extends PhysicalPlanTestBase { @Mock private PhysicalPlan inputPlan; @Test - public void filterTest() { + public void filter_test() { FilterOperator plan = new FilterOperator(new TestScan(), DSL.and(DSL.notequal(DSL.ref("response", INTEGER), DSL.literal(200)), DSL.notequal(DSL.ref("response", INTEGER), DSL.literal(500)))); @@ -45,7 +53,7 @@ public void filterTest() { } @Test - public void nullValueShouldBeenIgnored() { + public void null_value_should_been_ignored() { LinkedHashMap value = new LinkedHashMap<>(); value.put("response", LITERAL_NULL); when(inputPlan.hasNext()).thenReturn(true, false); @@ -58,7 +66,7 @@ public void nullValueShouldBeenIgnored() { } @Test - public void missingValueShouldBeenIgnored() { + public void missing_value_should_been_ignored() { LinkedHashMap value = new LinkedHashMap<>(); value.put("response", LITERAL_MISSING); when(inputPlan.hasNext()).thenReturn(true, false); diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/NestedOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/NestedOperatorTest.java index 5d8b893869..5f8bf99b0d 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/NestedOperatorTest.java +++ 
b/core/src/test/java/org/opensearch/sql/planner/physical/NestedOperatorTest.java @@ -7,6 +7,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.when; import static org.opensearch.sql.data.model.ExprValueUtils.collectionValue; @@ -91,8 +92,10 @@ public void nested_one_nested_field() { Map> groupedFieldsByPath = Map.of("message", List.of("message.info")); + var nested = new NestedOperator(inputPlan, fields, groupedFieldsByPath); + assertThat( - execute(new NestedOperator(inputPlan, fields, groupedFieldsByPath)), + execute(nested), contains( tupleValue( new LinkedHashMap<>() {{ @@ -176,8 +179,10 @@ public void nested_two_nested_field() { "field", new ReferenceExpression("comment.data", STRING), "path", new ReferenceExpression("comment", STRING)) ); + var nested = new NestedOperator(inputPlan, fields); + assertThat( - execute(new NestedOperator(inputPlan, fields)), + execute(nested), contains( tupleValue( new LinkedHashMap<>() {{ @@ -252,8 +257,10 @@ public void nested_two_nested_fields_with_same_path() { "field", new ReferenceExpression("message.id", STRING), "path", new ReferenceExpression("message", STRING)) ); + var nested = new NestedOperator(inputPlan, fields); + assertThat( - execute(new NestedOperator(inputPlan, fields)), + execute(nested), contains( tupleValue( new LinkedHashMap<>() {{ @@ -286,8 +293,10 @@ public void non_nested_field_tests() { Set fields = Set.of("message"); Map> groupedFieldsByPath = Map.of("message", List.of("message.info")); + + var nested = new NestedOperator(inputPlan, fields, groupedFieldsByPath); assertThat( - execute(new NestedOperator(inputPlan, fields, groupedFieldsByPath)), + execute(nested), contains( tupleValue(new LinkedHashMap<>(Map.of("message", "val"))) ) @@ -302,8 +311,10 @@ public void nested_missing_tuple_field() { Set 
fields = Set.of("message.val"); Map> groupedFieldsByPath = Map.of("message", List.of("message.val")); + + var nested = new NestedOperator(inputPlan, fields, groupedFieldsByPath); assertThat( - execute(new NestedOperator(inputPlan, fields, groupedFieldsByPath)), + execute(nested), contains( tupleValue(new LinkedHashMap<>(Map.of("message.val", ExprNullValue.of()))) ) @@ -318,11 +329,11 @@ public void nested_missing_array_field() { Set fields = Set.of("missing.data"); Map> groupedFieldsByPath = Map.of("message", List.of("message.data")); - assertTrue( - execute(new NestedOperator(inputPlan, fields, groupedFieldsByPath)) - .get(0) - .tupleValue() - .size() == 0 - ); + + var nested = new NestedOperator(inputPlan, fields, groupedFieldsByPath); + assertEquals(0, execute(nested) + .get(0) + .tupleValue() + .size()); } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index fb687277ce..8ed4881d33 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -9,9 +9,22 @@ import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.mock; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.expression.DSL.named; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.agg; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.dedupe; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.eval; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.filter; +import static 
org.opensearch.sql.planner.physical.PhysicalPlanDSL.limit; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.project; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.rareTopN; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.remove; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.rename; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.sort; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.values; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.window; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; @@ -19,9 +32,15 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Stream; import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.tree.RareTopN.CommandType; @@ -34,6 +53,7 @@ * Todo, testing purpose, delete later. 
*/ @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class PhysicalPlanNodeVisitorTest extends PhysicalPlanTestBase { @Mock PhysicalPlan plan; @@ -43,13 +63,13 @@ class PhysicalPlanNodeVisitorTest extends PhysicalPlanTestBase { @Test public void print_physical_plan() { PhysicalPlan plan = - PhysicalPlanDSL.remove( - PhysicalPlanDSL.project( - PhysicalPlanDSL.rename( - PhysicalPlanDSL.agg( - PhysicalPlanDSL.rareTopN( - PhysicalPlanDSL.filter( - PhysicalPlanDSL.limit( + remove( + project( + rename( + agg( + rareTopN( + filter( + limit( new TestScan(), 1, 1 ), @@ -76,71 +96,59 @@ public void print_physical_plan() { printer.print(plan)); } - @Test - public void test_PhysicalPlanVisitor_should_return_null() { + public static Stream getPhysicalPlanForTest() { + PhysicalPlan plan = mock(PhysicalPlan.class); + ReferenceExpression ref = mock(ReferenceExpression.class); + PhysicalPlan filter = - PhysicalPlanDSL.filter( - new TestScan(), DSL.equal(DSL.ref("response", INTEGER), DSL.literal(10))); - assertNull(filter.accept(new PhysicalPlanNodeVisitor() { - }, null)); + filter(new TestScan(), DSL.equal(DSL.ref("response", INTEGER), DSL.literal(10))); PhysicalPlan aggregation = - PhysicalPlanDSL.agg( - filter, ImmutableList.of(DSL.named("avg(response)", + agg(filter, ImmutableList.of(DSL.named("avg(response)", DSL.avg(DSL.ref("response", INTEGER)))), ImmutableList.of()); - assertNull(aggregation.accept(new PhysicalPlanNodeVisitor() { - }, null)); PhysicalPlan rename = - PhysicalPlanDSL.rename( - aggregation, ImmutableMap.of(DSL.ref("ivalue", INTEGER), DSL.ref("avg(response)", + rename(aggregation, ImmutableMap.of(DSL.ref("ivalue", INTEGER), DSL.ref("avg(response)", DOUBLE))); - assertNull(rename.accept(new PhysicalPlanNodeVisitor() { - }, null)); - PhysicalPlan project = PhysicalPlanDSL.project(plan, named("ref", ref)); - assertNull(project.accept(new PhysicalPlanNodeVisitor() { - }, null)); + PhysicalPlan project = 
project(plan, named("ref", ref)); - PhysicalPlan window = PhysicalPlanDSL.window(plan, named(DSL.rowNumber()), + PhysicalPlan window = window(plan, named(DSL.rowNumber()), new WindowDefinition(emptyList(), emptyList())); - assertNull(window.accept(new PhysicalPlanNodeVisitor() { - }, null)); - PhysicalPlan remove = PhysicalPlanDSL.remove(plan, ref); - assertNull(remove.accept(new PhysicalPlanNodeVisitor() { - }, null)); + PhysicalPlan remove = remove(plan, ref); - PhysicalPlan eval = PhysicalPlanDSL.eval(plan, Pair.of(ref, ref)); - assertNull(eval.accept(new PhysicalPlanNodeVisitor() { - }, null)); + PhysicalPlan eval = eval(plan, Pair.of(ref, ref)); - PhysicalPlan sort = PhysicalPlanDSL.sort(plan, Pair.of(SortOption.DEFAULT_ASC, ref)); - assertNull(sort.accept(new PhysicalPlanNodeVisitor() { - }, null)); + PhysicalPlan sort = sort(plan, Pair.of(SortOption.DEFAULT_ASC, ref)); - PhysicalPlan dedupe = PhysicalPlanDSL.dedupe(plan, ref); - assertNull(dedupe.accept(new PhysicalPlanNodeVisitor() { - }, null)); + PhysicalPlan dedupe = dedupe(plan, ref); - PhysicalPlan values = PhysicalPlanDSL.values(emptyList()); - assertNull(values.accept(new PhysicalPlanNodeVisitor() { - }, null)); + PhysicalPlan values = values(emptyList()); - PhysicalPlan rareTopN = - PhysicalPlanDSL.rareTopN(plan, CommandType.TOP, 5, ImmutableList.of(), ref); - assertNull(rareTopN.accept(new PhysicalPlanNodeVisitor() { - }, null)); + PhysicalPlan rareTopN = rareTopN(plan, CommandType.TOP, 5, ImmutableList.of(), ref); - PhysicalPlan limit = PhysicalPlanDSL.limit(plan, 1, 1); - assertNull(limit.accept(new PhysicalPlanNodeVisitor() { - }, null)); + PhysicalPlan limit = limit(plan, 1, 1); Set nestedArgs = Set.of("nested.test"); - Map> groupedFieldsByPath = - Map.of("nested", List.of("nested.test")); + Map> groupedFieldsByPath = Map.of("nested", List.of("nested.test")); PhysicalPlan nested = new NestedOperator(plan, nestedArgs, groupedFieldsByPath); - assertNull(nested.accept(new PhysicalPlanNodeVisitor() 
{ + + PhysicalPlan cursorClose = new CursorCloseOperator(plan); + + return Stream.of(Arguments.of(filter, "filter"), Arguments.of(aggregation, "aggregation"), + Arguments.of(rename, "rename"), Arguments.of(project, "project"), + Arguments.of(window, "window"), Arguments.of(remove, "remove"), + Arguments.of(eval, "eval"), Arguments.of(sort, "sort"), Arguments.of(dedupe, "dedupe"), + Arguments.of(values, "values"), Arguments.of(rareTopN, "rareTopN"), + Arguments.of(limit, "limit"), Arguments.of(nested, "nested"), + Arguments.of(cursorClose, "cursorClose")); + } + + @ParameterizedTest(name = "{1}") + @MethodSource("getPhysicalPlanForTest") + public void test_PhysicalPlanVisitor_should_return_null(PhysicalPlan plan, String name) { + assertNull(plan.accept(new PhysicalPlanNodeVisitor() { }, null)); } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTest.java index 0a93c96bbb..ab3f0ef36d 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTest.java @@ -5,9 +5,19 @@ package org.opensearch.sql.planner.physical; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.CALLS_REAL_METHODS; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.util.List; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -16,6 +26,7 @@ import 
org.opensearch.sql.storage.split.Split; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class PhysicalPlanTest { @Mock Split split; @@ -46,7 +57,7 @@ public List getChild() { }; @Test - void addSplitToChildByDefault() { + void add_split_to_child_by_default() { testPlan.add(split); verify(child).add(split); } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/ProjectOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/ProjectOperatorTest.java index 24be5eb2b8..f5ecf76bd0 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/ProjectOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/ProjectOperatorTest.java @@ -11,6 +11,7 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.hasItems; import static org.hamcrest.Matchers.iterableWithSize; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.when; import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_MISSING; import static org.opensearch.sql.data.model.ExprValueUtils.stringValue; @@ -20,7 +21,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.util.List; +import lombok.SneakyThrows; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -30,11 +36,12 @@ import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.utils.TestOperator; @ExtendWith(MockitoExtension.class) class ProjectOperatorTest extends PhysicalPlanTestBase { - @Mock + @Mock(serializable = true) private PhysicalPlan inputPlan; @Test @@ -206,4 +213,21 @@ 
public void project_parse_missing_will_fallback() { ExprValueUtils.tupleValue(ImmutableMap.of("action", "GET", "response", "200")), ExprValueUtils.tupleValue(ImmutableMap.of("action", "POST"))))); } + + @Test + @SneakyThrows + public void serializable() { + var projects = List.of(DSL.named("action", DSL.ref("action", STRING))); + var project = new ProjectOperator(new TestOperator(), projects, List.of()); + + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject(project); + objectOutput.flush(); + + ObjectInputStream objectInput = new ObjectInputStream( + new ByteArrayInputStream(output.toByteArray())); + var roundTripPlan = (ProjectOperator) objectInput.readObject(); + assertEquals(project, roundTripPlan); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/RemoveOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/RemoveOperatorTest.java index bf046bf0a6..ec950e6016 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/RemoveOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/RemoveOperatorTest.java @@ -113,12 +113,11 @@ public void remove_nothing_with_none_tuple_value() { @Test public void invalid_to_retrieve_schema_from_remove() { - PhysicalPlan plan = remove(inputPlan, DSL.ref("response", STRING), DSL.ref("referer", STRING)); + PhysicalPlan plan = remove(inputPlan); IllegalStateException exception = assertThrows(IllegalStateException.class, () -> plan.schema()); assertEquals( - "[BUG] schema can been only applied to ProjectOperator, " - + "instead of RemoveOperator(input=inputPlan, removeList=[response, referer])", + "[BUG] schema can been only applied to ProjectOperator, instead of RemoveOperator", exception.getMessage()); } } diff --git a/core/src/test/java/org/opensearch/sql/storage/StorageEngineTest.java b/core/src/test/java/org/opensearch/sql/storage/StorageEngineTest.java 
index 0e969c6dac..67014b76bd 100644 --- a/core/src/test/java/org/opensearch/sql/storage/StorageEngineTest.java +++ b/core/src/test/java/org/opensearch/sql/storage/StorageEngineTest.java @@ -13,11 +13,9 @@ public class StorageEngineTest { - @Test void testFunctionsMethod() { StorageEngine k = (dataSourceSchemaName, tableName) -> null; Assertions.assertEquals(Collections.emptyList(), k.getFunctions()); } - } diff --git a/core/src/test/java/org/opensearch/sql/utils/TestOperator.java b/core/src/test/java/org/opensearch/sql/utils/TestOperator.java new file mode 100644 index 0000000000..584cf6f3fd --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/utils/TestOperator.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.utils; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.List; +import lombok.Setter; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.planner.SerializablePlan; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; + +public class TestOperator extends PhysicalPlan implements SerializablePlan { + private int field; + @Setter + private boolean throwNoCursorOnWrite = false; + @Setter + private boolean throwIoOnWrite = false; + + public TestOperator() { + } + + public TestOperator(int value) { + field = value; + } + + @Override + public void readExternal(ObjectInput in) throws IOException { + field = in.readInt(); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + if (throwNoCursorOnWrite) { + throw new NoCursorException(); + } + if (throwIoOnWrite) { + throw new IOException(); + } + out.writeInt(field); + } + + @Override + public boolean equals(Object o) { + return field == ((TestOperator) o).field; + } + + @Override + 
public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return null; + } + + @Override + public boolean hasNext() { + return false; + } + + @Override + public ExprValue next() { + return null; + } + + @Override + public List getChild() { + return null; + } +} diff --git a/core/src/testFixtures/java/org/opensearch/sql/executor/DefaultExecutionEngine.java b/core/src/testFixtures/java/org/opensearch/sql/executor/DefaultExecutionEngine.java index e4f9a185a3..db72498a1d 100644 --- a/core/src/testFixtures/java/org/opensearch/sql/executor/DefaultExecutionEngine.java +++ b/core/src/testFixtures/java/org/opensearch/sql/executor/DefaultExecutionEngine.java @@ -9,6 +9,7 @@ import java.util.List; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.executor.pagination.Cursor; import org.opensearch.sql.planner.physical.PhysicalPlan; /** @@ -32,7 +33,8 @@ public void execute( while (plan.hasNext()) { result.add(plan.next()); } - QueryResponse response = new QueryResponse(new Schema(new ArrayList<>()), new ArrayList<>()); + QueryResponse response = new QueryResponse(new Schema(new ArrayList<>()), new ArrayList<>(), + Cursor.None); listener.onResponse(response); } catch (Exception e) { listener.onFailure(e); diff --git a/docs/dev/Pagination-v2.md b/docs/dev/Pagination-v2.md new file mode 100644 index 0000000000..1c3510b116 --- /dev/null +++ b/docs/dev/Pagination-v2.md @@ -0,0 +1,774 @@ +# Pagination in v2 Engine + +Pagination allows a SQL plugin client to retrieve arbitrarily large results sets one subset at a time. + +A cursor is a SQL abstraction for pagination. A client can open a cursor, retrieve a subset of data given a cursor and close a cursor. + +Currently, SQL plugin does not provide SQL cursor syntax. However, the SQL REST endpoint can return result a page at a time. This feature is used by JDBC and ODBC drivers. 
+ +# Scope +This document describes pagination in V2 sql engine for non-aggregate queries -- queries +without `GROUP BY` clause or use of window functions. + +# Demo +https://user-images.githubusercontent.com/88679692/224208630-8d38d833-abf8-4035-8d15-d5fb4382deca.mp4 + +# REST API +## Initial Query Request + +Initial query request contains the search request and page size. Search query to OpenSearch is built during processing of this request. Neither the query nor page size can be changed while scrolling through pages based on this request. +The only difference between paged and non-paged requests is `fetch_size` parameter supplied in paged request. + +```json +POST /_plugins/_sql +{ + "query" : "...", + "fetch_size": N +} +``` + +Response: +```json +{ + "cursor": "", + "datarows": [ + ... + ], + "schema" : [ + ... + ] +} +``` +`query` is a DQL statement. `fetch_size` is a positive integer, indicating number of rows to return in each page. + +If `query` is a DML statement then pagination does not apply, the `fetch_size` parameter is ignored and a cursor is not created. This is existing behaviour in v1 engine. + +The client receives an [error response](#error-response) if: +- `fetch_size` is not a positive integer +- evaluating `query` results in a server-side error +- `fetch_size` is bigger than `max_window_size` cluster-wide parameter. + +## Subsequent Query Request + +Subsequent query request contains a cursor only. + +```json +POST /_plugins/_sql +{ + "cursor": "" +} +``` +Similarly to v1 engine, the response object is the same as initial response if this is not the last page. + +`cursor_id` will be different with each request. + +## End of scrolling/paging +The last page in a response will not have a cursor id property. + +## Cursor Keep Alive Timeout + +Each cursor has a keep alive timer associated with it. When the timer runs out, the cursor is automatically closed by OpenSearch. + +This timer is reset every time a page is retrieved. 
+ +The client will receive an [error response](#error-response) if it sends a cursor request for an expired cursor. + +Keep alive timeout is [configurable](../user/admin/settings.rst#plugins.sql.cursor.keep_alive) by setting `plugins.sql.cursor.keep_alive` and has default value of 1 minute. + +## Error Response + +The client will receive an error response if any of the above REST calls result in a server-side error. + +The response object has the following format: +```json +{ + "error": { + "details": "", + "reason": "", + "type": "" + }, + "status": +} +``` + +`details`, `reason`, and `type` properties are string values. The exact values will depend on the error state encountered. +`status` is an HTTP status code + +## OpenSearch Data Retrieval Strategy + +OpenSearch provides several data retrieval APIs that are optimized for different use cases. + +At this time, SQL plugin uses simple search API and scroll API. + +Simple retrieval API returns at most `max_result_window` number of documents. `max_result_window` is an index setting. + +Scroll API requests returns all documents but can incur high memory costs on OpenSearch coordination node. + +Efficient implementation of pagination needs to be aware of retrieval API used. Each retrieval strategy will be considered separately. + +The discussion below uses *under max_result_window* to refer to scenarios that can be implemented with simple retrieval API and *over max_result_window* for scenarios that require scroll API to implement. + +## SQL Node Load Balancing + +V2 SQL engine supports *sql node load balancing* — a cursor request can be routed to any SQL node in a cluster. This is achieved by encoding all data necessary to retrieve the next page in the `cursor_id` property in the response. + +## Feature Design +To support pagination, v2 SQL engine needs to: +1. in REST front-end: + 1. Route supported paginated query to v2 engine for + 1. Initial requests, + 2. Next page requests. + 2. 
Fallback to v1 engine for queries not supported by v2 engine. + 3. Create correct JSON response from execution of paginated physical plan by v2 engine. +2. during query planning: + 1. Differentiate between paginated and normal query plans. + 2. Push down pagination to table scan. + 3. Create a physical query plan from a cursor id. +3. during query execution: + 1. Serialize an executing query and generate a cursor id after returning `fetch_size` number of elements. +4. in OpenSearch data source: + 1. Support pagination push down. + 2. Support other push down optimizations with pagination. + +### Query Plan Changes + +All three kinds of query requests — non-paged, initial page, or subsequent page — are processed in the same way. Simplified workflow of query plan processing is shown below for reference. + +```mermaid +stateDiagram-v2 + state "Request" as NonPaged { + direction LR + state "Parse Tree" as Parse + state "Unresolved Query Plan" as Unresolved + state "Abstract Query Plan" as Abstract + state "Logical Query Plan" as Logical + state "Optimized Query Plan" as Optimized + state "Physical Query Plan" as Physical + + [*] --> Parse : ANTLR + Parse --> Unresolved : AstBuilder + Unresolved --> Abstract : QueryPlanner + Abstract --> Logical : Planner + Logical --> Optimized : Optimizer + Optimized --> Physical : Implementor + } +``` + + +#### Unresolved Query Plan + +Unresolved Query Plan for non-paged requests remains unchanged. + +To support initial query requests, the `QueryPlan` class has a new optional field `pageSize`. + +```mermaid +classDiagram + direction LR + class QueryPlan { + <> + -Optional~int~ pageSize + -UnresolvedPlan plan + -QueryService queryService + } + class UnresolvedQueryPlan { + <> + } + QueryPlan --* UnresolvedQueryPlan +``` + +When `QueryPlanFactory.create` is passed initial query request, it: +1. Adds an instance of `Paginate` unresolved plan as the root of the unresolved query plan. +2. Sets `pageSize` parameter in `QueryPlan`. 
+ +```mermaid +classDiagram + direction LR + class QueryPlan { + <> + -Optional~int~ pageSize + -UnresolvedPlan plan + -QueryService queryService + } + class Paginate { + <> + -int pageSize + -UnresolvedPlan child + } + class UnresolvedQueryPlan { + <> + } + QueryPlan --* Paginate + Paginate --* UnresolvedQueryPlan +``` + +When `QueryPlanFactory.create` is passed a subsequent query request, it: +1. Creates an instance of `FetchCursor` unresolved plan as the sole node in the unresolved query plan. + +```mermaid +classDiagram + direction LR + class QueryPlan { + <> + -Optional~int~ pageSize + -UnresolvedPlan plan + -QueryService queryService + } + class FetchCursor { + <> + -String cursorId + } + QueryPlan --* FetchCursor +``` + +The examples below show Abstract Query Plan for the same query in different request types: + +```mermaid +stateDiagram-v2 + state "Non Paged Request" as NonPaged { + state "QueryPlan" as QueryPlanNP + state "Project" as ProjectNP + state "Limit" as LimitNP + state "Filter" as FilterNP + state "Aggregation" as AggregationNP + state "Relation" as RelationNP + + QueryPlanNP --> ProjectNP + ProjectNP --> LimitNP + LimitNP --> FilterNP + FilterNP --> AggregationNP + AggregationNP --> RelationNP + } + + state "Initial Query Request" as Paged { + state "QueryPlan" as QueryPlanIP + state "Project" as ProjectIP + state "Limit" as LimitIP + state "Filter" as FilterIP + state "Aggregation" as AggregationIP + state "Relation" as RelationIP + + Paginate --> QueryPlanIP + QueryPlanIP --> ProjectIP + ProjectIP --> LimitIP + LimitIP --> FilterIP + FilterIP --> AggregationIP + AggregationIP --> RelationIP + } + + state "Subsequent Query Request" As Sub { + FetchCursor + } +``` + +#### Logical Query Plan + +There are no changes for non-paging requests. + +Changes to logical query plan to support Initial Query Request: +1. `LogicalPaginate` is added to the top of the tree. 
It stores information about how paging should be done in a private field `pageSize`, which is pushed down in the `Optimizer`.
+This method is called by `PushDownPageSize` rule during plan optimization. `LogicalPaginate` is removed from the query plan during push down operation in `Optimizer`. + +See [article about `TableScanBuilder`](query-optimizer-improvement.md#TableScanBuilder) for more details. + +The examples below show optimized Logical Query Plan for the same query in different request types: + +```mermaid +stateDiagram-v2 + state "Non Paged Request" as NonPaged { + state "LogicalProject" as ProjectNP + state "LogicalLimit" as LimitNP + state "LogicalSort" as SortNP + state "OpenSearchIndexScanQueryBuilder" as RelationNP + + ProjectNP --> LimitNP + LimitNP --> SortNP + SortNP --> RelationNP + } + +``` + +#### Physical Query Plan and Execution + +Changes: +1. `OpenSearchIndexScanBuilder` is converted to `OpenSearchIndexScan` by `Implementor`. +2. `LogicalPlan.pageSize` is mapped to `OpenSearchIndexScan.maxResponseSize`. This is the limit to the number of elements in a response. +2. Entire Physical Query Plan is created by `PlanSerializer` for Subsequent Query requests. The deserialized plan has the same structure as the Initial Query Request. +3. Implemented serialization and deserialization for `OpenSearchScrollRequest`. 
+ + +The examples below show physical query plan for the same query in different request types: + +```mermaid +stateDiagram-v2 + state "Non Paged Request" as NonPaged { + state "ProjectOperator" as ProjectNP + state "LimitOperator" as LimitNP + state "SortOperator" as SortNP + state "OpenSearchIndexScan" as RelationNP + state "OpenSearchQueryRequest" as QRequestNP + + ProjectNP --> LimitNP + LimitNP --> SortNP + SortNP --> RelationNP + RelationNP --> QRequestNP + } + + state "Initial Query Request" as Paged { + state "ProjectOperator" as ProjectIP + state "LimitOperator" as LimitIP + state "SortOperator" as SortIP + state "OpenSearchIndexScan" as RelationIP + state "OpenSearchQueryRequest" as QRequestIP + + ProjectIP --> LimitIP + LimitIP --> SortIP + SortIP --> RelationIP + RelationIP --> QRequestIP + } + + state "Subsequent Query Request" As Sub { + state "ProjectOperator" as ProjectSP + state "LimitOperator" as LimitSP + state "SortOperator" as SortSP + state "OpenSearchIndexScan" as RelationSP + state "OpenSearchScrollRequest" as RequestSP + + ProjectSP --> LimitSP + LimitSP --> SortSP + SortSP --> RelationSP + RelationSP --> RequestSP + } +``` + +### Architecture Diagrams + +New code workflows which added by Pagination feature are highlighted. 
+ +#### Non Paging Query Request + +A non-paging request sequence diagram is shown below for comparison: + +```mermaid +sequenceDiagram + participant SQLService + participant QueryPlanFactory + participant QueryService + participant Planner + participant CreateTableScanBuilder + participant OpenSearchExecutionEngine + +SQLService ->>+ QueryPlanFactory: execute + QueryPlanFactory ->>+ QueryService: execute + QueryService ->>+ Planner: optimize + Planner ->>+ CreateTableScanBuilder: apply + CreateTableScanBuilder -->>- Planner: index scan + Planner -->>- QueryService: Logical Query Plan + QueryService ->>+ OpenSearchExecutionEngine: execute + OpenSearchExecutionEngine -->>- QueryService: execution completed + QueryService -->>- QueryPlanFactory: execution completed + QueryPlanFactory -->>- SQLService: execution completed +``` + +#### Initial Query Request + +Processing of an Initial Query Request has few extra steps comparing versus processing a regular Query Request: +1. Query validation with `CanPaginateVisitor`. This is required to validate whether incoming query can be paged. This also activate legacy engine fallback mechanism. +2. `Serialization` is performed by `PlanSerializer` - it converts Physical Plan Tree into a cursor, which could be used query a next page. 
+ +```mermaid +sequenceDiagram + participant SQLService + participant QueryPlanFactory + participant CanPaginateVisitor + participant QueryService + participant Planner + participant CreatePagingScanBuilder + participant OpenSearchExecutionEngine + participant PlanSerializer + +SQLService ->>+ QueryPlanFactory : execute + rect rgb(91, 123, 155) + QueryPlanFactory ->>+ CanPaginateVisitor : canConvertToCursor + CanPaginateVisitor -->>- QueryPlanFactory : true + end + QueryPlanFactory ->>+ QueryService : execute + QueryService ->>+ Planner : optimize + rect rgb(91, 123, 155) + Planner ->>+ CreateTableScanBuilder : apply + CreateTableScanBuilder -->>- Planner : paged index scan + end + Planner -->>- QueryService : Logical Query Plan + QueryService ->>+ OpenSearchExecutionEngine : execute + rect rgb(91, 123, 155) + Note over OpenSearchExecutionEngine, PlanSerializer : Serialization + OpenSearchExecutionEngine ->>+ PlanSerializer : convertToCursor + PlanSerializer -->>- OpenSearchExecutionEngine : cursor + end + OpenSearchExecutionEngine -->>- QueryService : execution completed + QueryService -->>- QueryPlanFactory : execution completed + QueryPlanFactory -->>- SQLService : execution completed +``` + +#### Subsequent Query Request + +Subsequent pages are processed by a new workflow. The key point there: +1. `Deserialization` is performed by `PlanSerializer` to restore entire Physical Plan Tree encoded into the cursor. +2. Since query already contains the Physical Plan Tree, all tree processing steps are skipped. +3. `Serialization` is performed by `PlanSerializer` - it converts Physical Plan Tree into a cursor, which could be used query a next page. 
+ +```mermaid +sequenceDiagram + participant QueryPlanFactory + participant QueryService + participant Analyzer + participant Planner + participant DefaultImplementor + participant PlanSerializer + participant OpenSearchExecutionEngine + +QueryPlanFactory ->>+ QueryService : execute + QueryService ->>+ Analyzer : analyze + Analyzer -->>- QueryService : new LogicalFetchCursor + QueryService ->>+ Planner : plan + Planner ->>+ DefaultImplementor : implement + rect rgb(91, 123, 155) + DefaultImplementor ->>+ PlanSerializer : deserialize + PlanSerializer -->>- DefaultImplementor: physical query plan + end + DefaultImplementor -->>- Planner : physical query plan + Planner -->>- QueryService : physical query plan + QueryService ->>+ OpenSearchExecutionEngine : execute + OpenSearchExecutionEngine -->>- QueryService: execution completed + QueryService -->>- QueryPlanFactory : execution completed +``` + +#### Legacy Engine Fallback + +Since pagination in V2 engine supports fewer SQL commands than pagination in legacy engine, a fallback mechanism is created to keep V1 engine features still available for the end user. Pagination fallback is backed by a new exception type which allows legacy engine to intersect execution of a request. 
+ +```mermaid +sequenceDiagram + participant RestSQLQueryAction + participant Legacy Engine + participant SQLService + participant QueryPlanFactory + participant CanPaginateVisitor + +RestSQLQueryAction ->>+ SQLService : prepareRequest + SQLService ->>+ QueryPlanFactory : execute + rect rgb(91, 123, 155) + note over SQLService, CanPaginateVisitor : V2 support check + QueryPlanFactory ->>+ CanPaginateVisitor : canConvertToCursor + CanPaginateVisitor -->>- QueryPlanFactory : false + QueryPlanFactory -->>- RestSQLQueryAction : UnsupportedCursorRequestException + deactivate SQLService + end + RestSQLQueryAction ->> Legacy Engine: accept + Note over Legacy Engine : Processing in Legacy engine + Legacy Engine -->> RestSQLQueryAction : complete +``` + +#### Serialization and Deserialization round trip + +The SQL engine should be able to completely recover the Physical Query Plan to continue its execution to get the next page. Serialization mechanism is responsible for recovering the query plan. note: `ResourceMonitorPlan` isn't serialized, because a new object of this type would be created for the restored query plan before execution. +Serialization and Deserialization are performed by Java object serialization API. + +```mermaid +stateDiagram-v2 + direction LR + state "Initial Query Request Query Plan" as FirstPage + state FirstPage { + state "ProjectOperator" as logState1_1 + state "..." as logState1_2 + state "ResourceMonitorPlan" as logState1_3 + state "OpenSearchIndexScan" as logState1_4 + state "OpenSearchScrollRequest" as logState1_5 + logState1_1 --> logState1_2 + logState1_2 --> logState1_3 + logState1_3 --> logState1_4 + logState1_4 --> logState1_5 + } + + state "Deserialized Query Plan" as SecondPageTree + state SecondPageTree { + state "ProjectOperator" as logState2_1 + state "..." 
as logState2_2 + state "OpenSearchIndexScan" as logState2_3 + state "OpenSearchScrollRequest" as logState2_4 + logState2_1 --> logState2_2 + logState2_2 --> logState2_3 + logState2_3 --> logState2_4 + } + + state "Subsequent Query Request Query Plan" as SecondPage + state SecondPage { + state "ProjectOperator" as logState3_1 + state "..." as logState3_2 + state "ResourceMonitorPlan" as logState3_3 + state "OpenSearchIndexScan" as logState3_4 + state "OpenSearchScrollRequest" as logState3_5 + logState3_1 --> logState3_2 + logState3_2 --> logState3_3 + logState3_3 --> logState3_4 + logState3_4 --> logState3_5 + } + + FirstPage --> SecondPageTree : Serialization and\nDeserialization + SecondPageTree --> SecondPage : Execution\nPreparation +``` + +#### Serialization + +All query plan nodes which are supported by pagination should implement [`SerializablePlan`](https://github.com/opensearch-project/sql/blob/f40bb6d68241e76728737d88026e4c8b1e6b3b8b/core/src/main/java/org/opensearch/sql/planner/SerializablePlan.java) interface. `getPlanForSerialization` method of this interface allows serialization mechanism to skip a tree node from serialization. OpenSearch search request objects are not serialized, but search context provided by the OpenSearch cluster is extracted from them. + +```mermaid +sequenceDiagram + participant PlanSerializer + participant ProjectOperator + participant ResourceMonitorPlan + participant OpenSearchIndexScan + participant OpenSearchScrollRequest + +PlanSerializer ->>+ ProjectOperator : getPlanForSerialization + ProjectOperator -->>- PlanSerializer : this +PlanSerializer ->>+ ProjectOperator : serialize + Note over ProjectOperator : dump private fields + ProjectOperator ->>+ ResourceMonitorPlan : getPlanForSerialization + ResourceMonitorPlan -->>- ProjectOperator : delegate + Note over ResourceMonitorPlan : ResourceMonitorPlan
is not serialized + ProjectOperator ->>+ OpenSearchIndexScan : writeExternal + OpenSearchIndexScan ->>+ OpenSearchScrollRequest : writeTo + Note over OpenSearchScrollRequest : dump private fields + OpenSearchScrollRequest -->>- OpenSearchIndexScan : serialized request + Note over OpenSearchIndexScan : dump private fields + OpenSearchIndexScan -->>- ProjectOperator : serialized + ProjectOperator -->>- PlanSerializer : serialized +Note over PlanSerializer : Zip to reduce size +``` + +#### Deserialization + +Deserialization restores previously serialized Physical Query Plan. The recovered plan is ready to execute and returns the next page of the search response. To complete the query plan restoration, SQL engine will build a new request to the OpenSearch node. This request doesn't contain a search query, but it contains a search context reference — `scrollID`. To create a new `OpenSearchScrollRequest` object it requires access to the instance of `OpenSearchStorageEngine`. Note: `OpenSearchStorageEngine` can't be serialized, and it exists as a singleton in the SQL plugin engine. `PlanSerializer` creates a customized deserialization binary object stream — `CursorDeserializationStream`. This stream provides an interface to access the `OpenSearchStorageEngine` object. 
+ +```mermaid +sequenceDiagram + participant PlanSerializer + participant CursorDeserializationStream + participant ProjectOperator + participant OpenSearchIndexScan + participant OpenSearchScrollRequest + +Note over PlanSerializer : Unzip +Note over PlanSerializer : Validate cursor integrity +PlanSerializer ->>+ CursorDeserializationStream : deserialize + CursorDeserializationStream ->>+ ProjectOperator : create new + Note over ProjectOperator: load private fields + ProjectOperator -->> CursorDeserializationStream : deserialize input + activate CursorDeserializationStream + CursorDeserializationStream ->>+ OpenSearchIndexScan : create new + deactivate CursorDeserializationStream + OpenSearchIndexScan -->>+ CursorDeserializationStream : resolve engine + CursorDeserializationStream ->>- OpenSearchIndexScan : OpenSearchStorageEngine + Note over OpenSearchIndexScan : load private fields + OpenSearchIndexScan ->>+ OpenSearchScrollRequest : create new + OpenSearchScrollRequest -->>- OpenSearchIndexScan : created + OpenSearchIndexScan -->>- ProjectOperator : deserialized + ProjectOperator -->>- PlanSerializer : deserialized + deactivate CursorDeserializationStream +``` + +#### Close Cursor + +A user can forcibly close a cursor (scroll) at any moment of paging. Automatic close occurs when paging is complete and no more results left. +Close cursor protocol defined by following: +1. REST endpoint: `/_plugins/_sql/close` +2. Request type: `POST` +3. Request format: +```json +{ + "cursor" : "" +} +``` +4. Response format: +```json +{ + "succeeded": true +} +``` +5. Failure or error: [error response](#error-response) +6. Use or sequential close of already closed cursor produces the same error as use of expired/auto-closed/non-existing cursor. 
+ +```mermaid +sequenceDiagram +SQLService ->>+ QueryPlanFactory : execute + QueryPlanFactory ->>+ QueryService : execute + QueryService ->>+ Analyzer : analyze + Analyzer -->>- QueryService : new LogicalCloseCursor + QueryService ->>+ Planner : plan + Planner ->>+ DefaultImplementor : implement + DefaultImplementor ->>+ PlanSerializer : deserialize + PlanSerializer -->>- DefaultImplementor: physical query plan + DefaultImplementor -->>- Planner : new CloseOperator + Planner -->>- QueryService : CloseOperator + QueryService ->>+ OpenSearchExecutionEngine : execute + Note over OpenSearchExecutionEngine : Open is no-op, no request issued,
no results received and processed + Note over OpenSearchExecutionEngine : Clean-up (clear scroll) on auto-close + OpenSearchExecutionEngine -->>- QueryService: execution completed + QueryService -->>- QueryPlanFactory : execution completed + QueryPlanFactory -->>- SQLService : execution completed +``` + +```mermaid +stateDiagram-v2 + direction LR + state "Abstract Query Plan" as Abstract { + state "CommandPlan" as CommandPlan { + state "Unresolved Query Plan" as Unresolved { + state "CloseCursor" as CloseCursor + state "FetchCursor" as FetchCursor + + CloseCursor --> FetchCursor + } + } + } + state "Logical Query Plan" as Logical { + state "LogicalCloseCursor" as LogicalCloseCursor + state "LogicalFetchCursor" as LogicalFetchCursor + + LogicalCloseCursor --> LogicalFetchCursor + } + state "Optimized Query Plan" as Optimized { + state "LogicalCloseCursor" as LogicalCloseCursorO + state "LogicalFetchCursor" as LogicalFetchCursorO + + LogicalCloseCursorO --> LogicalFetchCursorO + } + state "Physical Query Plan" as Physical { + state "CursorCloseOperator" as CursorCloseOperator + state "ProjectOperator" as ProjectOperator + state "..." as ... + state "OpenSearchIndexScan" as OpenSearchIndexScan + + CursorCloseOperator --> ProjectOperator + ProjectOperator --> ... + ... --> OpenSearchIndexScan + } + + [*] --> Unresolved : QueryPlanner + Unresolved --> Logical : Planner + Logical --> Optimized : Optimizer + Optimized --> Physical : Implementor +``` + +`CursorCloseOperator` provides a dummy (empty, since not used) `Schema`, does not perform `open` and always returns `false` by `hasNext`. Such behavior makes it a no-op operator which blocks underlying Physical Plan Tree from issuing any search request, but does not block auto-close provided by `AutoCloseable`. Default close action clears scroll context. +Regular paging doesn't execute scroll clear, because it checks whether paging is finished or not and raises a flag to prevent clear. 
This check is performed when a search response is received, which never happens due to `CursorCloseOperator`.
b/docs/dev/query-optimizer-improvement.md @@ -0,0 +1,208 @@ +### Background + +This section introduces the current architecture of logical optimizer and physical transformation. + +#### Logical-to-Logical Optimization + +Currently each storage engine adds its own logical operator as concrete implementation for `TableScanOperator` abstraction. Typically each data source needs to add 2 logical operators for table scan with without aggregation. Take OpenSearch for example, there are `OpenSearchLogicalIndexScan` and `OpenSearchLogicalIndexAgg` and a bunch of pushdown optimization rules for each accordingly. + +```py +class LogicalPlanOptimizer: + /* + * OpenSearch rules include: + * MergeFilterAndRelation + * MergeAggAndIndexScan + * MergeAggAndRelation + * MergeSortAndRelation + * MergeSortAndIndexScan + * MergeSortAndIndexAgg + * MergeSortAndIndexScan + * MergeLimitAndRelation + * MergeLimitAndIndexScan + * PushProjectAndRelation + * PushProjectAndIndexScan + * + * that return *OpenSearchLogicalIndexAgg* + * or *OpenSearchLogicalIndexScan* finally + */ + val rules: List + + def optimize(plan: LogicalPlan): + for rule in rules: + if rule.match(plan): + plan = rules.apply(plan) + return plan.children().forEach(this::optimize) +``` + +#### Logical-to-Physical Transformation + +After logical transformation, planner will let the `Table` in `LogicalRelation` (identified before logical transformation above) transform the logical plan to physical plan. + +```py +class OpenSearchIndex: + + def implement(plan: LogicalPlan): + return plan.accept( + DefaultImplementor(): + def visitNode(node): + if node is OpenSearchLogicalIndexScan: + return OpenSearchIndexScan(...) + else if node is OpenSearchLogicalIndexAgg: + return OpenSearchIndexScan(...) +``` + +### Problem Statement + +The current planning architecture causes 2 serious problems: + +1. Each data source adds special logical operator and explode the optimizer rule space. 
For example, Prometheus also has `PrometheusLogicalMetricAgg` and `PrometheusLogicalMetricScan` accordingly. They have the exactly same pattern to match query plan tree as OpenSearch. +2. A bigger problem is the difficulty of transforming from logical to physical when there are 2 `Table`s in query plan. Because only 1 of them has the chance to do the `implement()`. This is a blocker for supporting `INSERT ... SELECT ...` statement or JOIN query. See code below. + +```java + public PhysicalPlan plan(LogicalPlan plan) { + Table table = findTable(plan); + if (table == null) { + return plan.accept(new DefaultImplementor<>(), null); + } + return table.implement( + table.optimize(optimize(plan))); + } +``` + +### Solution + +#### TableScanBuilder + +A new abstraction `TableScanBuilder` is added as a transition operator during logical planning and optimization. Each data source provides its implementation class by `Table` interface. The push down difference in non-aggregate and aggregate query is hidden inside specific scan builder, for example `OpenSearchIndexScanBuilder` rather than exposed to core module. 
+ +```mermaid +classDiagram +%% Mermaid fails to render `LogicalPlanNodeVisitor~R, C~` https://github.com/mermaid-js/mermaid/issues/3188, using `<R, C>` as a workaround + class LogicalPlan { + -List~LogicalPlan~ childPlans + +LogicalPlan(List~LogicalPlan~) + +accept(LogicalPlanNodeVisitor<R, C>, C)* R + +replaceChildPlans(List~LogicalPlan~ childPlans) LogicalPlan + } + class TableScanBuilder { + +TableScanBuilder() + +build()* TableScanOperator + +pushDownFilter(LogicalFilter) boolean + +pushDownAggregation(LogicalAggregation) boolean + +pushDownSort(LogicalSort) boolean + +pushDownLimit(LogicalLimit) boolean + +pushDownPageSize(LogicalPaginate) boolean + +pushDownProject(LogicalProject) boolean + +pushDownHighlight(LogicalHighlight) boolean + +pushDownNested(LogicalNested) boolean + +accept(LogicalPlanNodeVisitor<R, C>, C) R + } + class OpenSearchIndexScanQueryBuilder { + OpenSearchIndexScanQueryBuilder(OpenSearchIndexScan) + +build() TableScanOperator + +pushDownFilter(LogicalFilter) boolean + +pushDownAggregation(LogicalAggregation) boolean + +pushDownSort(LogicalSort) boolean + +pushDownLimit(LogicalLimit) boolean + +pushDownPageSize(LogicalPaginate) boolean + +pushDownProject(LogicalProject) boolean + +pushDownHighlight(LogicalHighlight) boolean + +pushDownNested(LogicalNested) boolean + +findReferenceExpression(NamedExpression)$ List~ReferenceExpression~ + +findReferenceExpressions(List~NamedExpression~)$ Set~ReferenceExpression~ + } + class OpenSearchIndexScanBuilder { + -TableScanBuilder delegate + -boolean isLimitPushedDown + +OpenSearchIndexScanBuilder(OpenSearchIndexScan) + OpenSearchIndexScanBuilder(TableScanBuilder) + +build() TableScanOperator + +pushDownFilter(LogicalFilter) boolean + +pushDownAggregation(LogicalAggregation) boolean + +pushDownSort(LogicalSort) boolean + +pushDownLimit(LogicalLimit) boolean + +pushDownProject(LogicalProject) boolean + +pushDownHighlight(LogicalHighlight) boolean + +pushDownNested(LogicalNested) boolean + 
-sortByFieldsOnly(LogicalSort) boolean + } + + LogicalPlan <|-- TableScanBuilder + TableScanBuilder <|-- OpenSearchIndexScanQueryBuilder + TableScanBuilder <|-- OpenSearchIndexScanBuilder + OpenSearchIndexScanBuilder *-- "1" TableScanBuilder : delegate + OpenSearchIndexScanBuilder <.. OpenSearchIndexScanQueryBuilder : creates +``` + +#### Table Push Down Rules + +In this way, `LogicalPlanOptimizer` in core module always have the same set of rule for all push down optimization. + +```mermaid +classDiagram + class LogicalPlanOptimizer { + +create()$ LogicalPlanOptimizer + +optimize(LogicalPlan) LogicalPlan + -internalOptimize(LogicalPlan) LogicalPlan + } + class CreateTableScanBuilder { + +apply(LogicalRelation, Captures) LogicalPlan + -pattern() Pattern~LogicalRelation~ + } + class CreatePagingTableScanBuilder { + +apply(LogicalPaginate, Captures) LogicalPlan + -pattern() Pattern~LogicalRelation~ + -findLogicalRelation(LogicalPaginate) boolean + } + class Table { + +TableScanBuilder createScanBuilder() + } + class TableScanPushDown~T~ { + +Rule~T~ PUSH_DOWN_FILTER$ + +Rule~T~ PUSH_DOWN_AGGREGATION$ + +Rule~T~ PUSH_DOWN_SORT$ + +Rule~T~ PUSH_DOWN_LIMIT$ + +Rule~T~ PUSH_DOWN_PROJECT$ + +Rule~T~ PUSH_DOWN_HIGHLIGHT$ + +Rule~T~ PUSH_DOWN_NESTED$ + +apply(T, Captures) LogicalPlan + +pattern() Pattern~T~ + } + class TableScanBuilder { + +pushDownFilter(LogicalFilter) boolean + +pushDownAggregation(LogicalAggregation) boolean + +pushDownSort(LogicalSort) boolean + +pushDownLimit(LogicalLimit) boolean + +pushDownProject(LogicalProject) boolean + +pushDownHighlight(LogicalHighlight) boolean + +pushDownNested(LogicalNested) boolean + } + TableScanPushDown~T~ -- TableScanBuilder + LogicalPlanOptimizer ..> CreateTableScanBuilder : creates + LogicalPlanOptimizer ..> CreatePagingTableScanBuilder : creates + CreateTableScanBuilder ..> Table + CreatePagingTableScanBuilder ..> Table + LogicalPlanOptimizer ..* TableScanPushDown~T~ + Table ..> TableScanBuilder : creates +``` + +### 
Examples + +The following diagram illustrates how `TableScanBuilder` along with `TablePushDownRule` solve the problem aforementioned. + +![optimizer-Page-1](https://user-images.githubusercontent.com/46505291/203645359-3f2fff73-a210-4bc0-a582-951a27de684d.jpg) + + +Similarly, `TableWriteBuilder` will be added and work in the same way in separate PR: https://github.com/opensearch-project/sql/pull/1094 + +![optimizer-Page-2](https://user-images.githubusercontent.com/46505291/203645380-5155fd22-71b4-49ca-8ed7-9652b005f761.jpg) + +### TODO + +1. Refactor Prometheus optimize rule and enforce table scan builder +2. Figure out how to implement AD commands +4. Deprecate `optimize()` and `implement()` if item 1 and 2 complete +5. Introduce fixed point or maximum iteration limit for iterative optimization +6. Investigate if CBO should be part of current optimizer or distributed planner in future +7. Remove `pushdownHighlight` once it's moved to OpenSearch storage +8. Move `TableScanOperator` to the new `read` package (leave it in this PR to avoid even more file changed) diff --git a/docs/dev/query-optimizier-improvement.md b/docs/dev/query-optimizier-improvement.md deleted file mode 100644 index 753abcc844..0000000000 --- a/docs/dev/query-optimizier-improvement.md +++ /dev/null @@ -1,106 +0,0 @@ -### Background - -This section introduces the current architecture of logical optimizer and physical transformation. - -#### Logical-to-Logical Optimization - -Currently each storage engine adds its own logical operator as concrete implementation for `TableScanOperator` abstraction. Typically each data source needs to add 2 logical operators for table scan with and without aggregation. Take OpenSearch for example, there are `OpenSearchLogicalIndexScan` and `OpenSearchLogicalIndexAgg` and a bunch of pushdown optimization rules for each accordingly. 
- -``` -class LogicalPlanOptimizer: - /* - * OpenSearch rules include: - * MergeFilterAndRelation - * MergeAggAndIndexScan - * MergeAggAndRelation - * MergeSortAndRelation - * MergeSortAndIndexScan - * MergeSortAndIndexAgg - * MergeSortAndIndexScan - * MergeLimitAndRelation - * MergeLimitAndIndexScan - * PushProjectAndRelation - * PushProjectAndIndexScan - * - * that return *OpenSearchLogicalIndexAgg* - * or *OpenSearchLogicalIndexScan* finally - */ - val rules: List - - def optimize(plan: LogicalPlan): - for rule in rules: - if rule.match(plan): - plan = rules.apply(plan) - return plan.children().forEach(this::optimize) -``` - -#### Logical-to-Physical Transformation - -After logical transformation, planner will let the `Table` in `LogicalRelation` (identified before logical transformation above) transform the logical plan to physical plan. - -``` -class OpenSearchIndex: - - def implement(plan: LogicalPlan): - return plan.accept( - DefaultImplementor(): - def visitNode(node): - if node is OpenSearchLogicalIndexScan: - return OpenSearchIndexScan(...) - else if node is OpenSearchLogicalIndexAgg: - return OpenSearchIndexScan(...) -``` - -### Problem Statement - -The current planning architecture causes 2 serious problems: - -1. Each data source adds special logical operator and explode the optimizer rule space. For example, Prometheus also has `PrometheusLogicalMetricAgg` and `PrometheusLogicalMetricScan` accordingly. They have the exactly same pattern to match query plan tree as OpenSearch. -2. A bigger problem is the difficulty of transforming from logical to physical when there are 2 `Table`s in query plan. Because only 1 of them has the chance to do the `implement()`. This is a blocker for supporting `INSERT ... SELECT ...` statement or JOIN query. See code below. 
- -``` - public PhysicalPlan plan(LogicalPlan plan) { - Table table = findTable(plan); - if (table == null) { - return plan.accept(new DefaultImplementor<>(), null); - } - return table.implement( - table.optimize(optimize(plan))); - } -``` - -### Solution - -#### TableScanBuilder - -A new abstraction `TableScanBuilder` is added as a transition operator during logical planning and optimization. Each data source provides its implementation class by `Table` interface. The push down difference in non-aggregate and aggregate query is hidden inside specific scan builder, for example `OpenSearchIndexScanBuilder` rather than exposed to core module. - -![TableScanBuilder](https://user-images.githubusercontent.com/46505291/204355538-e54f7679-3585-423e-97d5-5832b2038cc1.png) - -#### TablePushDownRules - -In this way, `LogicalOptimizier` in core module always have the same set of rule for all push down optimization. - -![LogicalPlanOptimizer](https://user-images.githubusercontent.com/46505291/203142195-9b38f1e9-1116-469d-9709-3cbf893ec522.png) - - -### Examples - -The following diagram illustrates how `TableScanBuilder` along with `TablePushDownRule` solve the problem aforementioned. - -![optimizer-Page-1](https://user-images.githubusercontent.com/46505291/203645359-3f2fff73-a210-4bc0-a582-951a27de684d.jpg) - - -Similarly, `TableWriteBuilder` will be added and work in the same way in separate PR: https://github.com/opensearch-project/sql/pull/1094 - -![optimizer-Page-2](https://user-images.githubusercontent.com/46505291/203645380-5155fd22-71b4-49ca-8ed7-9652b005f761.jpg) - -### TODO - -1. Refactor Prometheus optimize rule and enforce table scan builder -2. Figure out how to implement AD commands -4. Deprecate `optimize()` and `implement()` if item 1 and 2 complete -5. Introduce fixed point or maximum iteration limit for iterative optimization -6. Investigate if CBO should be part of current optimizer or distributed planner in future -7. 
Remove `pushdownHighlight` once it's moved to OpenSearch storage -8. Move `TableScanOperator` to the new `read` package (leave it in this PR to avoid even more file changed) \ No newline at end of file diff --git a/docs/user/optimization/optimization.rst b/docs/user/optimization/optimization.rst index e0fe943560..8ab998309d 100644 --- a/docs/user/optimization/optimization.rst +++ b/docs/user/optimization/optimization.rst @@ -287,7 +287,7 @@ The Aggregation operator will merge into OpenSearch Aggregation:: { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone=false)" }, "children": [] } @@ -313,7 +313,7 @@ The Sort operator will merge into OpenSearch Aggregation.:: { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"last\",\"order\":\"desc\"}}}]},\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, 
sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"last\",\"order\":\"desc\"}}}]},\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone=false)" }, "children": [] } @@ -348,7 +348,7 @@ Because the OpenSearch Composite Aggregation doesn't support order by metrics fi { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone=false)" }, "children": [] } diff --git a/docs/user/ppl/interfaces/endpoint.rst b/docs/user/ppl/interfaces/endpoint.rst index fb64eff688..793b94eb8d 100644 --- a/docs/user/ppl/interfaces/endpoint.rst +++ b/docs/user/ppl/interfaces/endpoint.rst @@ -91,7 +91,7 @@ The following PPL query demonstrated that where and stats command were pushed do { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName=accounts, 
sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}, searchDone=false)" + "request": "OpenSearchQueryRequest(indexName=accounts, sourceBuilder={\"from\":0,\"size\":200,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":10,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"avg(age)\":{\"avg\":{\"field\":\"age\"}}}}, searchDone=false)" }, "children": [] } diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 555121f9c7..2e20965ddd 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -126,6 +126,11 @@ compileTestJava { testClusters.all { testDistribution = 'archive' + + // debug with command, ./gradlew opensearch-sql:run -DdebugJVM. --debug-jvm does not work with keystore. + if (System.getProperty("debugJVM") != null) { + jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:5005' + } } testClusters.integTest { @@ -256,10 +261,16 @@ integTest { // Tell the test JVM if the cluster JVM is running under a debugger so that tests can use longer timeouts for // requests. The 'doFirst' delays reading the debug setting on the cluster till execution time. 
- doFirst { systemProperty 'cluster.debug', getDebug() } + doFirst { + if (System.getProperty("debug-jvm") != null) { + setDebug(true); + } + systemProperty 'cluster.debug', getDebug() + } + if (System.getProperty("test.debug") != null) { - jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:5005' + jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:5006' } if (System.getProperty("tests.rest.bwcsuite") == null) { diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/CursorIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/CursorIT.java index 113a19885a..5b9a583d04 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/CursorIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/CursorIT.java @@ -123,11 +123,16 @@ public void validNumberOfPages() throws IOException { String selectQuery = StringUtils.format("SELECT firstname, state FROM %s", TEST_INDEX_ACCOUNT); JSONObject response = new JSONObject(executeFetchQuery(selectQuery, 50, JDBC)); String cursor = response.getString(CURSOR); + verifyIsV1Cursor(cursor); + int pageCount = 1; while (!cursor.isEmpty()) { //this condition also checks that there is no cursor on last page response = executeCursorQuery(cursor); cursor = response.optString(CURSOR); + if (!cursor.isEmpty()) { + verifyIsV1Cursor(cursor); + } pageCount++; } @@ -136,12 +141,16 @@ public void validNumberOfPages() throws IOException { // using random value here, with fetch size of 28 we should get 36 pages (ceil of 1000/28) response = new JSONObject(executeFetchQuery(selectQuery, 28, JDBC)); cursor = response.getString(CURSOR); + verifyIsV1Cursor(cursor); System.out.println(response); pageCount = 1; while (!cursor.isEmpty()) { response = executeCursorQuery(cursor); cursor = response.optString(CURSOR); + if (!cursor.isEmpty()) { + verifyIsV1Cursor(cursor); + } pageCount++; } assertThat(pageCount, equalTo(36)); @@ -223,6 +232,7 @@ public void testCursorWithPreparedStatement() 
throws IOException { "}", TestsConstants.TEST_INDEX_ACCOUNT)); assertTrue(response.has(CURSOR)); + verifyIsV1Cursor(response.getString(CURSOR)); } @Test @@ -244,11 +254,13 @@ public void testRegressionOnDateFormatChange() throws IOException { StringUtils.format("SELECT login_time FROM %s LIMIT 500", TEST_INDEX_DATE_TIME); JSONObject response = new JSONObject(executeFetchQuery(selectQuery, 1, JDBC)); String cursor = response.getString(CURSOR); + verifyIsV1Cursor(cursor); actualDateList.add(response.getJSONArray(DATAROWS).getJSONArray(0).getString(0)); while (!cursor.isEmpty()) { response = executeCursorQuery(cursor); cursor = response.optString(CURSOR); + verifyIsV1Cursor(cursor); actualDateList.add(response.getJSONArray(DATAROWS).getJSONArray(0).getString(0)); } @@ -274,7 +286,6 @@ public void defaultBehaviorWhenCursorSettingIsDisabled() throws IOException { query = StringUtils.format("SELECT firstname, email, state FROM %s", TEST_INDEX_ACCOUNT); response = new JSONObject(executeFetchQuery(query, 100, JDBC)); assertTrue(response.has(CURSOR)); - wipeAllClusterSettings(); } @@ -305,12 +316,14 @@ public void testDefaultFetchSizeFromClusterSettings() throws IOException { JSONObject response = new JSONObject(executeFetchLessQuery(query, JDBC)); JSONArray datawRows = response.optJSONArray(DATAROWS); assertThat(datawRows.length(), equalTo(1000)); + verifyIsV1Cursor(response.getString(CURSOR)); updateClusterSettings(new ClusterSetting(TRANSIENT, "opensearch.sql.cursor.fetch_size", "786")); response = new JSONObject(executeFetchLessQuery(query, JDBC)); datawRows = response.optJSONArray(DATAROWS); assertThat(datawRows.length(), equalTo(786)); assertTrue(response.has(CURSOR)); + verifyIsV1Cursor(response.getString(CURSOR)); wipeAllClusterSettings(); } @@ -323,11 +336,12 @@ public void testCursorCloseAPI() throws IOException { "SELECT firstname, state FROM %s WHERE balance > 100 and age < 40", TEST_INDEX_ACCOUNT); JSONObject result = new 
JSONObject(executeFetchQuery(selectQuery, 50, JDBC)); String cursor = result.getString(CURSOR); - + verifyIsV1Cursor(cursor); // Retrieving next 10 pages out of remaining 19 pages for (int i = 0; i < 10; i++) { result = executeCursorQuery(cursor); cursor = result.optString(CURSOR); + verifyIsV1Cursor(cursor); } //Closing the cursor JSONObject closeResp = executeCursorCloseQuery(cursor); @@ -386,12 +400,14 @@ public void respectLimitPassedInSelectClause() throws IOException { StringUtils.format("SELECT age, balance FROM %s LIMIT %s", TEST_INDEX_ACCOUNT, limit); JSONObject response = new JSONObject(executeFetchQuery(selectQuery, 50, JDBC)); String cursor = response.getString(CURSOR); + verifyIsV1Cursor(cursor); int actualDataRowCount = response.getJSONArray(DATAROWS).length(); int pageCount = 1; while (!cursor.isEmpty()) { response = executeCursorQuery(cursor); cursor = response.optString(CURSOR); + verifyIsV1Cursor(cursor); actualDataRowCount += response.getJSONArray(DATAROWS).length(); pageCount++; } @@ -432,10 +448,12 @@ public void verifyWithAndWithoutPaginationResponse(String sqlQuery, String curso response.optJSONArray(DATAROWS).forEach(dataRows::put); String cursor = response.getString(CURSOR); + verifyIsV1Cursor(cursor); while (!cursor.isEmpty()) { response = executeCursorQuery(cursor); response.optJSONArray(DATAROWS).forEach(dataRows::put); cursor = response.optString(CURSOR); + verifyIsV1Cursor(cursor); } verifySchema(withoutCursorResponse.optJSONArray(SCHEMA), @@ -465,6 +483,13 @@ public String executeFetchAsStringQuery(String query, String fetchSize, String r return responseString; } + private void verifyIsV1Cursor(String cursor) { + if (cursor.isEmpty()) { + return; + } + assertTrue("The cursor '" + cursor + "' is not from v1 engine.", cursor.startsWith("d:")); + } + private String makeRequest(String query, String fetch_size) { return String.format("{" + " \"fetch_size\": \"%s\"," + diff --git 
a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java index d1bcc94506..7216c03d08 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java @@ -269,6 +269,17 @@ protected String executeFetchQuery(String query, int fetchSize, String requestTy return responseString; } + protected JSONObject executeQueryTemplate(String queryTemplate, String index, int fetchSize) + throws IOException { + var query = String.format(queryTemplate, index); + return new JSONObject(executeFetchQuery(query, fetchSize, "jdbc")); + } + + protected JSONObject executeQueryTemplate(String queryTemplate, String index) throws IOException { + // Delegate to the three-arg overload, which formats the template itself. + return executeQueryTemplate(queryTemplate, index, 4); + } + protected String executeFetchLessQuery(String query, String requestType) throws IOException { String endpoint = "/_plugins/_sql?format=" + requestType; diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java index cca7833d66..b1fcbf7d1b 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java @@ -41,28 +41,29 @@ import org.opensearch.sql.executor.QueryManager; import org.opensearch.sql.executor.QueryService; import org.opensearch.sql.executor.execution.QueryPlanFactory; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.monitor.AlwaysHealthyMonitor; import org.opensearch.sql.monitor.ResourceMonitor; -import org.opensearch.sql.opensearch.client.OpenSearchClient; -import org.opensearch.sql.opensearch.client.OpenSearchRestClient; import 
org.opensearch.sql.opensearch.executor.OpenSearchExecutionEngine; import org.opensearch.sql.opensearch.executor.protector.ExecutionProtector; import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; -import org.opensearch.sql.opensearch.security.SecurityAccess; -import org.opensearch.sql.opensearch.storage.OpenSearchDataSourceFactory; import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; import org.opensearch.sql.planner.Planner; import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; -import org.opensearch.sql.ppl.domain.PPLQueryRequest; -import org.opensearch.sql.protocol.response.QueryResult; -import org.opensearch.sql.protocol.response.format.SimpleJsonResponseFormatter; import org.opensearch.sql.sql.SQLService; import org.opensearch.sql.sql.antlr.SQLSyntaxParser; -import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.util.ExecuteOnCallerThreadQueryManager; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.client.OpenSearchRestClient; +import org.opensearch.sql.opensearch.security.SecurityAccess; +import org.opensearch.sql.opensearch.storage.OpenSearchDataSourceFactory; +import org.opensearch.sql.ppl.domain.PPLQueryRequest; +import org.opensearch.sql.protocol.response.QueryResult; +import org.opensearch.sql.protocol.response.format.SimpleJsonResponseFormatter; +import org.opensearch.sql.storage.DataSourceFactory; /** * Run PPL with query engine outside OpenSearch cluster. 
This IT doesn't require our plugin @@ -71,13 +72,11 @@ */ public class StandaloneIT extends PPLIntegTestCase { - private RestHighLevelClient restClient; - private PPLService pplService; @Override public void init() { - restClient = new InternalRestHighLevelClient(client()); + RestHighLevelClient restClient = new InternalRestHighLevelClient(client()); OpenSearchClient client = new OpenSearchRestClient(restClient); DataSourceService dataSourceService = new DataSourceServiceImpl( new ImmutableSet.Builder() @@ -198,8 +197,9 @@ public StorageEngine storageEngine(OpenSearchClient client) { } @Provides - public ExecutionEngine executionEngine(OpenSearchClient client, ExecutionProtector protector) { - return new OpenSearchExecutionEngine(client, protector); + public ExecutionEngine executionEngine(OpenSearchClient client, ExecutionProtector protector, + PlanSerializer planSerializer) { + return new OpenSearchExecutionEngine(client, protector, planSerializer); } @Provides @@ -228,18 +228,23 @@ public SQLService sqlService(QueryManager queryManager, QueryPlanFactory queryPl return new SQLService(new SQLSyntaxParser(), queryManager, queryPlanFactory); } + @Provides + public PlanSerializer planSerializer(StorageEngine storageEngine) { + return new PlanSerializer(storageEngine); + } + @Provides public QueryPlanFactory queryPlanFactory(ExecutionEngine executionEngine) { Analyzer analyzer = new Analyzer( new ExpressionAnalyzer(functionRepository), dataSourceService, functionRepository); Planner planner = new Planner(LogicalPlanOptimizer.create()); - return new QueryPlanFactory(new QueryService(analyzer, executionEngine, planner)); + QueryService queryService = new QueryService(analyzer, executionEngine, planner); + return new QueryPlanFactory(queryService); } } - - private DataSourceMetadataStorage getDataSourceMetadataStorage() { + public static DataSourceMetadataStorage getDataSourceMetadataStorage() { return new DataSourceMetadataStorage() { @Override public List 
getDataSourceMetadata() { @@ -268,7 +273,7 @@ public void deleteDataSourceMetadata(String datasourceName) { }; } - private DataSourceUserAuthorizationHelper getDataSourceUserRoleHelper() { + public static DataSourceUserAuthorizationHelper getDataSourceUserRoleHelper() { return new DataSourceUserAuthorizationHelper() { @Override public void authorizeDataSource(DataSourceMetadata dataSourceMetadata) { @@ -276,5 +281,4 @@ public void authorizeDataSource(DataSourceMetadata dataSourceMetadata) { } }; } - } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java index 809e2dc7c5..0ab6d5c70f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java @@ -64,7 +64,7 @@ public void highlight_multiple_optional_arguments_test() { schema("highlight(Body, pre_tags='', " + "post_tags='')", null, "nested")); - assertEquals(1, response.getInt("total")); + assertEquals(1, response.getInt("size")); verifyDataRows(response, rows(new JSONArray(List.of("What are the differences between an IPA" + " and its variants?")), diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationBlackboxIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationBlackboxIT.java new file mode 100644 index 0000000000..2a34dabd79 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationBlackboxIT.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.sql; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.SneakyThrows; +import org.json.JSONArray; +import org.json.JSONObject; +import 
org.junit.Test; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.util.TestUtils; + +// This class has only one test case, because it is parametrized and takes significant time +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class PaginationBlackboxIT extends SQLIntegTestCase { + + private final Index index; + private final Integer pageSize; + + public PaginationBlackboxIT(@Name("index") Index index, + @Name("pageSize") Integer pageSize) { + this.index = index; + this.pageSize = pageSize; + } + + @Override + protected void init() throws IOException { + loadIndex(index); + } + + @ParametersFactory(argumentFormatting = "index = %1$s, page_size = %2$d") + public static Iterable indexAndPageSizes() { + var indices = List.of(Index.ACCOUNT, Index.BEER, Index.BANK); + var pageSizes = List.of(5, 10, 100, 1000); + var testData = new ArrayList(); + for (var index : indices) { + for (var pageSize : pageSizes) { + testData.add(new Object[] { index, pageSize }); + } + } + return testData; + } + + @Test + @SneakyThrows + public void test_pagination_blackbox() { + var response = executeJdbcRequest(String.format("select * from %s", index.getName())); + var indexSize = response.getInt("total"); + var rows = response.getJSONArray("datarows"); + var schema = response.getJSONArray("schema"); + var testReportPrefix = String.format("index: %s, page size: %d || ", index.getName(), pageSize); + var rowsPaged = new JSONArray(); + var rowsReturned = 0; + + var responseCounter = 1; + this.logger.info(testReportPrefix + "first response"); + response = new JSONObject(executeFetchQuery( + String.format("select * from %s", index.getName()), pageSize, "jdbc")); + + var cursor = response.has("cursor")? 
response.getString("cursor") : ""; + do { + this.logger.info(testReportPrefix + + String.format("subsequent response %d/%d", responseCounter++, (indexSize / pageSize) + 1)); + assertTrue("Paged response schema doesn't match to non-paged", + schema.similar(response.getJSONArray("schema"))); + + rowsReturned += response.getInt("size"); + var datarows = response.getJSONArray("datarows"); + for (int i = 0; i < datarows.length(); i++) { + rowsPaged.put(datarows.get(i)); + } + + if (response.has("cursor")) { + TestUtils.verifyIsV2Cursor(response); + cursor = response.getString("cursor"); + response = executeCursorQuery(cursor); + } else { + cursor = ""; + } + + } while(!cursor.isEmpty()); + assertTrue("Paged response schema doesn't match to non-paged", + schema.similar(response.getJSONArray("schema"))); + + assertEquals(testReportPrefix + "Paged responses return another row count that non-paged", + indexSize, rowsReturned); + assertTrue(testReportPrefix + "Paged accumulated result has other rows than non-paged", + rows.similar(rowsPaged)); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationFallbackIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationFallbackIT.java new file mode 100644 index 0000000000..33d9c5f6a8 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationFallbackIT.java @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ONLINE; +import static org.opensearch.sql.util.TestUtils.verifyIsV1Cursor; +import static org.opensearch.sql.util.TestUtils.verifyIsV2Cursor; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.util.TestUtils; + +public class PaginationFallbackIT 
extends SQLIntegTestCase { + @Override + public void init() throws IOException { + loadIndex(Index.PHRASE); + loadIndex(Index.ONLINE); + } + + @Test + public void testWhereClause() throws IOException { + var response = executeQueryTemplate("SELECT * FROM %s WHERE 1 = 1", TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testSelectAll() throws IOException { + var response = executeQueryTemplate("SELECT * FROM %s", TEST_INDEX_ONLINE); + verifyIsV2Cursor(response); + } + + @Test + public void testSelectWithOpenSearchFuncInFilter() throws IOException { + var response = executeQueryTemplate( + "SELECT * FROM %s WHERE `11` = match_phrase('96')", TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testSelectWithHighlight() throws IOException { + var response = executeQueryTemplate( + "SELECT highlight(`11`) FROM %s WHERE match_query(`11`, '96')", TEST_INDEX_ONLINE); + // As of 2023-03-08, WHERE clause sends the query to legacy engine and legacy engine + // does not support highlight as an expression. 
+ assertTrue(response.has("error")); + } + + @Test + public void testSelectWithFullTextSearch() throws IOException { + var response = executeQueryTemplate( + "SELECT * FROM %s WHERE match_phrase(`11`, '96')", TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testSelectFromIndexWildcard() throws IOException { + var response = executeQueryTemplate("SELECT * FROM %s*", TEST_INDEX); + verifyIsV2Cursor(response); + } + + @Test + public void testSelectFromDataSource() throws IOException { + var response = executeQueryTemplate("SELECT * FROM @opensearch.%s", + TEST_INDEX_ONLINE); + verifyIsV2Cursor(response); + } + + @Test + public void testSelectColumnReference() throws IOException { + var response = executeQueryTemplate("SELECT `107` from %s", TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testSubquery() throws IOException { + var response = executeQueryTemplate("SELECT `107` from (SELECT * FROM %s)", + TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testSelectExpression() throws IOException { + var response = executeQueryTemplate("SELECT 1 + 1 - `107` from %s", + TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testGroupBy() throws IOException { + // GROUP BY is not paged by either engine. + var response = executeQueryTemplate("SELECT * FROM %s GROUP BY `107`", + TEST_INDEX_ONLINE); + TestUtils.verifyNoCursor(response); + } + + @Test + public void testGroupByHaving() throws IOException { + // GROUP BY is not paged by either engine. 
+ var response = executeQueryTemplate("SELECT * FROM %s GROUP BY `107` HAVING `107` > 400", + TEST_INDEX_ONLINE); + TestUtils.verifyNoCursor(response); + } + + @Test + public void testLimit() throws IOException { + var response = executeQueryTemplate("SELECT * FROM %s LIMIT 8", TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testLimitOffset() throws IOException { + var response = executeQueryTemplate("SELECT * FROM %s LIMIT 8 OFFSET 4", + TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testOrderBy() throws IOException { + var response = executeQueryTemplate("SELECT * FROM %s ORDER By `107`", + TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationIT.java new file mode 100644 index 0000000000..72ec20c679 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationIT.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_CALCS; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ONLINE; +import static org.opensearch.sql.legacy.plugin.RestSqlAction.EXPLAIN_API_ENDPOINT; + +import java.io.IOException; + +import lombok.SneakyThrows; +import org.json.JSONObject; +import org.junit.Ignore; +import org.junit.Test; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.ResponseException; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.util.TestUtils; + +public class PaginationIT extends SQLIntegTestCase { + @Override + public void init() throws IOException { + loadIndex(Index.CALCS); + loadIndex(Index.ONLINE); + } + + @Test + public void 
testSmallDataSet() throws IOException { + var query = "SELECT * from " + TEST_INDEX_CALCS; + var response = new JSONObject(executeFetchQuery(query, 4, "jdbc")); + assertTrue(response.has("cursor")); + assertEquals(4, response.getInt("size")); + TestUtils.verifyIsV2Cursor(response); + } + + @Test + public void testLargeDataSetV1() throws IOException { + var v1query = "SELECT * from " + TEST_INDEX_ONLINE + " WHERE 1 = 1"; + var v1response = new JSONObject(executeFetchQuery(v1query, 4, "jdbc")); + assertEquals(4, v1response.getInt("size")); + TestUtils.verifyIsV1Cursor(v1response); + } + + @Test + public void testLargeDataSetV2() throws IOException { + var query = "SELECT * from " + TEST_INDEX_ONLINE; + var response = new JSONObject(executeFetchQuery(query, 4, "jdbc")); + assertEquals(4, response.getInt("size")); + TestUtils.verifyIsV2Cursor(response); + } + + @Ignore("Scroll may not expire after timeout") + // Scroll keep alive parameter guarantees that scroll context would be kept for that time, + // but doesn't define how fast it will be expired after time out. + // With KA = 1s scroll may be kept up to 30 sec or more. We can't test exact expiration. + // I disable the test to prevent it waiting for a minute and delay all CI. 
+ public void testCursorTimeout() throws IOException, InterruptedException { + updateClusterSettings( + new ClusterSetting(PERSISTENT, Settings.Key.SQL_CURSOR_KEEP_ALIVE.getKeyValue(), "1s")); + + var query = "SELECT * from " + TEST_INDEX_CALCS; + var response = new JSONObject(executeFetchQuery(query, 4, "jdbc")); + assertTrue(response.has("cursor")); + var cursor = response.getString("cursor"); + Thread.sleep(2222L); // > 1s + + ResponseException exception = + expectThrows(ResponseException.class, () -> executeCursorQuery(cursor)); + response = new JSONObject(TestUtils.getResponseBody(exception.getResponse())); + assertEquals(response.getJSONObject("error").getString("reason"), + "Error occurred in OpenSearch engine: all shards failed"); + assertTrue(response.getJSONObject("error").getString("details") + .contains("SearchContextMissingException[No search context found for id")); + assertEquals(response.getJSONObject("error").getString("type"), + "SearchPhaseExecutionException"); + + wipeAllClusterSettings(); + } + + @Test + @SneakyThrows + public void testCloseCursor() { + // Initial page request to get cursor + var query = "SELECT * from " + TEST_INDEX_CALCS; + var response = new JSONObject(executeFetchQuery(query, 4, "jdbc")); + assertTrue(response.has("cursor")); + var cursor = response.getString("cursor"); + + // Close the cursor + Request closeCursorRequest = new Request("POST", "_plugins/_sql/close"); + closeCursorRequest.setJsonEntity(String.format("{ \"cursor\" : \"%s\" } ", cursor)); + RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder(); + restOptionsBuilder.addHeader("Content-Type", "application/json"); + closeCursorRequest.setOptions(restOptionsBuilder); + response = new JSONObject(executeRequest(closeCursorRequest)); + assertTrue(response.has("succeeded")); + assertTrue(response.getBoolean("succeeded")); + + // Test that cursor is no longer available + ResponseException exception = + expectThrows(ResponseException.class, () 
-> executeCursorQuery(cursor)); + response = new JSONObject(TestUtils.getResponseBody(exception.getResponse())); + assertEquals(response.getJSONObject("error").getString("reason"), + "Error occurred in OpenSearch engine: all shards failed"); + assertTrue(response.getJSONObject("error").getString("details") + .contains("SearchContextMissingException[No search context found for id")); + assertEquals(response.getJSONObject("error").getString("type"), + "SearchPhaseExecutionException"); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationWindowIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationWindowIT.java new file mode 100644 index 0000000000..be208cd137 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationWindowIT.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_PHRASE; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.After; +import org.junit.Test; +import org.opensearch.client.ResponseException; +import org.opensearch.sql.legacy.SQLIntegTestCase; + +public class PaginationWindowIT extends SQLIntegTestCase { + @Override + public void init() throws IOException { + loadIndex(Index.PHRASE); + } + + @After + public void resetParams() throws IOException { + resetMaxResultWindow(TEST_INDEX_PHRASE); + resetQuerySizeLimit(); + } + + @Test + public void testFetchSizeLessThanMaxResultWindow() throws IOException { + setMaxResultWindow(TEST_INDEX_PHRASE, 6); + JSONObject response = executeQueryTemplate("SELECT * FROM %s", TEST_INDEX_PHRASE, 5); + + int numRows = 0; + do { + // Process response + String cursor = response.getString("cursor"); + numRows += response.getJSONArray("datarows").length(); + response = executeCursorQuery(cursor); + } while (response.has("cursor")); + numRows += 
response.getJSONArray("datarows").length(); + + var countRows = executeJdbcRequest("SELECT COUNT(*) FROM " + TEST_INDEX_PHRASE) + .getJSONArray("datarows") + .getJSONArray(0) + .get(0); + assertEquals(countRows, numRows); + } + + @Test + public void testQuerySizeLimitDoesNotAffectTotalRowsReturned() throws IOException { + int querySizeLimit = 4; + setQuerySizeLimit(querySizeLimit); + JSONObject response = executeQueryTemplate("SELECT * FROM %s", TEST_INDEX_PHRASE, 5); + assertTrue(response.getInt("size") > querySizeLimit); + + int numRows = 0; + do { + // Process response + String cursor = response.getString("cursor"); + numRows += response.getJSONArray("datarows").length(); + response = executeCursorQuery(cursor); + } while (response.has("cursor")); + numRows += response.getJSONArray("datarows").length(); + var countRows = executeJdbcRequest("SELECT COUNT(*) FROM " + TEST_INDEX_PHRASE) + .getJSONArray("datarows") + .getJSONArray(0) + .get(0); + assertEquals(countRows, numRows); + assertTrue(numRows > querySizeLimit); + } + + @Test + public void testQuerySizeLimitDoesNotAffectPageSize() throws IOException { + setQuerySizeLimit(3); + setMaxResultWindow(TEST_INDEX_PHRASE, 4); + var response + = executeQueryTemplate("SELECT * FROM %s", TEST_INDEX_PHRASE, 4); + assertEquals(4, response.getInt("size")); + + var response2 + = executeQueryTemplate("SELECT * FROM %s", TEST_INDEX_PHRASE, 2); + assertEquals(2, response2.getInt("size")); + } + + @Test + public void testFetchSizeLargerThanResultWindowFails() throws IOException { + final int window = 2; + setMaxResultWindow(TEST_INDEX_PHRASE, 2); + assertThrows(ResponseException.class, + () -> executeQueryTemplate("SELECT * FROM %s", + TEST_INDEX_PHRASE, window + 1)); + resetMaxResultWindow(TEST_INDEX_PHRASE); + } + + +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java new file mode 100644 index 
0000000000..aad39c4074 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java @@ -0,0 +1,172 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; +import static org.opensearch.sql.ppl.StandaloneIT.getDataSourceMetadataStorage; +import static org.opensearch.sql.ppl.StandaloneIT.getDataSourceUserRoleHelper; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import lombok.Getter; +import lombok.SneakyThrows; +import org.json.JSONObject; +import org.junit.Test; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.opensearch.client.Request; +import org.opensearch.client.ResponseException; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.common.inject.Injector; +import org.opensearch.common.inject.ModulesBuilder; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.sql.ast.tree.FetchCursor; +import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.client.OpenSearchRestClient; +import org.opensearch.sql.executor.pagination.Cursor; +import 
org.opensearch.sql.opensearch.storage.OpenSearchDataSourceFactory; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.planner.PlanContext; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.DataSourceFactory; +import org.opensearch.sql.util.InternalRestHighLevelClient; +import org.opensearch.sql.util.StandaloneModule; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class StandalonePaginationIT extends SQLIntegTestCase { + + private QueryService queryService; + + private PlanSerializer planSerializer; + + private OpenSearchClient client; + + @Override + @SneakyThrows + public void init() { + RestHighLevelClient restClient = new InternalRestHighLevelClient(client()); + client = new OpenSearchRestClient(restClient); + DataSourceService dataSourceService = new DataSourceServiceImpl( + new ImmutableSet.Builder() + .add(new OpenSearchDataSourceFactory(client, defaultSettings())) + .build(), + getDataSourceMetadataStorage(), + getDataSourceUserRoleHelper() + ); + dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata()); + + ModulesBuilder modules = new ModulesBuilder(); + modules.add(new StandaloneModule(new InternalRestHighLevelClient(client()), defaultSettings(), dataSourceService)); + Injector injector = modules.createInjector(); + + queryService = injector.getInstance(QueryService.class); + planSerializer = injector.getInstance(PlanSerializer.class); + } + + @Test + public void test_pagination_whitebox() throws IOException { + class TestResponder + implements ResponseListener { + @Getter + Cursor cursor = Cursor.None; + @Override + public void onResponse(ExecutionEngine.QueryResponse response) { + cursor = 
response.getCursor(); + } + + @Override + public void onFailure(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // arrange + { + Request request1 = new Request("PUT", "/test/_doc/1?refresh=true"); + request1.setJsonEntity("{\"name\": \"hello\", \"age\": 20}"); + client().performRequest(request1); + Request request2 = new Request("PUT", "/test/_doc/2?refresh=true"); + request2.setJsonEntity("{\"name\": \"world\", \"age\": 30}"); + client().performRequest(request2); + } + + // act 1, asserts in firstResponder + var t = new OpenSearchIndex(client, defaultSettings(), "test"); + LogicalPlan p = new LogicalPaginate(1, List.of( + new LogicalProject( + new LogicalRelation("test", t), List.of( + DSL.named("name", DSL.ref("name", ExprCoreType.STRING)), + DSL.named("age", DSL.ref("age", ExprCoreType.LONG))), + List.of() + ))); + var firstResponder = new TestResponder(); + queryService.executePlan(p, PlanContext.emptyPlanContext(), firstResponder); + + // act 2, asserts in secondResponder + + PhysicalPlan plan = planSerializer.convertToPlan(firstResponder.getCursor().toString()); + var secondResponder = new TestResponder(); + queryService.execute(new FetchCursor(firstResponder.getCursor().toString()), secondResponder); + + // act 3: confirm that there's no cursor. 
+ } + + @Test + @SneakyThrows + public void test_explain_not_supported() { + var request = new Request("POST", "_plugins/_sql/_explain"); + // Request should be rejected before index names are resolved + request.setJsonEntity("{ \"query\": \"select * from something\", \"fetch_size\": 10 }"); + var exception = assertThrows(ResponseException.class, () -> client().performRequest(request)); + var response = new JSONObject(new String(exception.getResponse().getEntity().getContent().readAllBytes())); + assertEquals("`explain` feature for paginated requests is not implemented yet.", + response.getJSONObject("error").getString("details")); + + // Request should be rejected before cursor parsed + request.setJsonEntity("{ \"cursor\" : \"n:0000\" }"); + exception = assertThrows(ResponseException.class, () -> client().performRequest(request)); + response = new JSONObject(new String(exception.getResponse().getEntity().getContent().readAllBytes())); + assertEquals("Explain of a paged query continuation is not supported. 
Use `explain` for the initial query request.", + response.getJSONObject("error").getString("details")); + } + + private Settings defaultSettings() { + return new Settings() { + private final Map defaultSettings = new ImmutableMap.Builder() + .put(Key.QUERY_SIZE_LIMIT, 200) + .put(Key.SQL_CURSOR_KEEP_ALIVE, TimeValue.timeValueMinutes(1)) + .build(); + + @Override + public T getSettingValue(Key key) { + return (T) defaultSettings.get(key); + } + + @Override + public List getSettings() { + return (List) defaultSettings; + } + }; + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/util/InternalRestHighLevelClient.java b/integ-test/src/test/java/org/opensearch/sql/util/InternalRestHighLevelClient.java new file mode 100644 index 0000000000..57726089ae --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/util/InternalRestHighLevelClient.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.util; + +import java.util.Collections; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestHighLevelClient; + +/** + * Internal RestHighLevelClient only for testing purpose. 
+ */ +public class InternalRestHighLevelClient extends RestHighLevelClient { + public InternalRestHighLevelClient(RestClient restClient) { + super(restClient, RestClient::close, Collections.emptyList()); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java b/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java new file mode 100644 index 0000000000..c347ea5244 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.util; + +import lombok.RequiredArgsConstructor; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.common.inject.AbstractModule; +import org.opensearch.common.inject.Provides; +import org.opensearch.common.inject.Singleton; +import org.opensearch.sql.analysis.Analyzer; +import org.opensearch.sql.analysis.ExpressionAnalyzer; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.executor.QueryManager; +import org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.executor.execution.QueryPlanFactory; +import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.monitor.AlwaysHealthyMonitor; +import org.opensearch.sql.monitor.ResourceMonitor; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.client.OpenSearchRestClient; +import org.opensearch.sql.opensearch.executor.OpenSearchExecutionEngine; +import org.opensearch.sql.opensearch.executor.protector.ExecutionProtector; +import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; +import 
org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; +import org.opensearch.sql.planner.Planner; +import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; +import org.opensearch.sql.ppl.PPLService; +import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; +import org.opensearch.sql.sql.SQLService; +import org.opensearch.sql.sql.antlr.SQLSyntaxParser; +import org.opensearch.sql.storage.StorageEngine; + +/** + * A utility class which registers SQL engine singletons as `OpenSearchPluginModule` does. + * It is needed to get access to those instances in test and validate their behavior. + */ +@RequiredArgsConstructor +public class StandaloneModule extends AbstractModule { + + private final RestHighLevelClient client; + + private final Settings settings; + + private final DataSourceService dataSourceService; + + private final BuiltinFunctionRepository functionRepository = + BuiltinFunctionRepository.getInstance(); + + @Override + protected void configure() { + } + + @Provides + public OpenSearchClient openSearchClient() { + return new OpenSearchRestClient(client); + } + + @Provides + public StorageEngine storageEngine(OpenSearchClient client) { + return new OpenSearchStorageEngine(client, settings); + } + + @Provides + public ExecutionEngine executionEngine(OpenSearchClient client, ExecutionProtector protector, + PlanSerializer planSerializer) { + return new OpenSearchExecutionEngine(client, protector, planSerializer); + } + + @Provides + public ResourceMonitor resourceMonitor() { + return new AlwaysHealthyMonitor(); + } + + @Provides + public ExecutionProtector protector(ResourceMonitor resourceMonitor) { + return new OpenSearchExecutionProtector(resourceMonitor); + } + + @Provides + @Singleton + public QueryManager queryManager() { + return new ExecuteOnCallerThreadQueryManager(); + } + + @Provides + public PPLService pplService(QueryManager queryManager, QueryPlanFactory queryPlanFactory) { + return new PPLService(new PPLSyntaxParser(), queryManager, 
queryPlanFactory); + } + + @Provides + public SQLService sqlService(QueryManager queryManager, QueryPlanFactory queryPlanFactory) { + return new SQLService(new SQLSyntaxParser(), queryManager, queryPlanFactory); + } + + @Provides + public PlanSerializer planSerializer(StorageEngine storageEngine) { + return new PlanSerializer(storageEngine); + } + + @Provides + public QueryPlanFactory queryPlanFactory(QueryService qs) { + + return new QueryPlanFactory(qs); + } + + @Provides + public QueryService queryService(ExecutionEngine executionEngine) { + Analyzer analyzer = + new Analyzer( + new ExpressionAnalyzer(functionRepository), dataSourceService, functionRepository); + Planner planner = new Planner(LogicalPlanOptimizer.create()); + return new QueryService(analyzer, executionEngine, planner); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java b/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java index bd75ead43b..69f1649190 100644 --- a/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java +++ b/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java @@ -7,6 +7,8 @@ package org.opensearch.sql.util; import static com.google.common.base.Strings.isNullOrEmpty; +import static org.junit.Assert.assertTrue; +import static org.opensearch.sql.executor.pagination.PlanSerializer.CURSOR_PREFIX; import java.io.BufferedReader; import java.io.File; @@ -20,22 +22,21 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.Arrays; import java.util.LinkedList; import java.util.List; import java.util.Locale; import java.util.stream.Collectors; import org.json.JSONObject; -import org.junit.Assert; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.client.Client; import org.opensearch.client.Request; -import org.opensearch.client.RequestOptions; import 
org.opensearch.client.Response; import org.opensearch.client.RestClient; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.rest.RestStatus; +import org.opensearch.sql.legacy.cursor.CursorType; public class TestUtils { @@ -839,4 +840,28 @@ public static List> getPermutations(final List items) { return result; } + + public static void verifyIsV1Cursor(JSONObject response) { + var legacyCursorPrefixes = Arrays.stream(CursorType.values()) + .map(c -> c.getId() + ":").collect(Collectors.toList()); + verifyCursor(response, legacyCursorPrefixes, "v1"); + } + + + public static void verifyIsV2Cursor(JSONObject response) { + verifyCursor(response, List.of(CURSOR_PREFIX), "v2"); + } + + private static void verifyCursor(JSONObject response, List validCursorPrefix, String engineName) { + assertTrue("'cursor' property does not exist", response.has("cursor")); + + var cursor = response.getString("cursor"); + assertTrue("'cursor' property is empty", !cursor.isEmpty()); + assertTrue("The cursor '" + cursor + "' is not from " + engineName + " engine.", + validCursorPrefix.stream().anyMatch(cursor::startsWith)); + } + + public static void verifyNoCursor(JSONObject response) { + assertTrue(!response.has("cursor")); + } } diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_agg_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_agg_push.json index 2d7f5f8c08..568b397f07 100644 --- a/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_agg_push.json +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_filter_agg_push.json @@ -8,7 +8,7 @@ { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, 
sourceBuilder\u003d{\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone\u003dfalse)" + "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone\u003dfalse)" }, "children": [] } diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_output.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_output.json index 45988e35c7..8d45714283 100644 --- a/integ-test/src/test/resources/expectedOutput/ppl/explain_output.json +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_output.json @@ -31,7 +31,7 @@ { "name": "OpenSearchIndexScan", "description": { - "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, 
sourceBuilder\u003d{\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone\u003dfalse)" + "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, searchDone\u003dfalse)" }, "children": [] } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java index bc97f71b47..37cbba4adf 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java @@ -24,6 +24,7 @@ import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.common.utils.QueryContext; +import 
org.opensearch.sql.exception.UnsupportedCursorRequestException; import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.metrics.Metrics; @@ -33,6 +34,7 @@ import org.opensearch.sql.protocol.response.format.Format; import org.opensearch.sql.protocol.response.format.JdbcResponseFormatter; import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; +import org.opensearch.sql.protocol.response.format.CommandResponseFormatter; import org.opensearch.sql.protocol.response.format.RawResponseFormatter; import org.opensearch.sql.protocol.response.format.ResponseFormatter; import org.opensearch.sql.sql.SQLService; @@ -101,7 +103,9 @@ public RestChannelConsumer prepareRequest( channel, createExplainResponseListener(channel, executionErrorHandler), fallbackHandler)); - } else { + } + // If close request, sqlService.closeCursor + else { return channel -> sqlService.execute( request, @@ -119,14 +123,14 @@ private ResponseListener fallBackListener( return new ResponseListener() { @Override public void onResponse(T response) { - LOG.error("[{}] Request is handled by new SQL query engine", + LOG.info("[{}] Request is handled by new SQL query engine", QueryContext.getRequestId()); next.onResponse(response); } @Override public void onFailure(Exception e) { - if (e instanceof SyntaxCheckException) { + if (e instanceof SyntaxCheckException || e instanceof UnsupportedCursorRequestException) { fallBackHandler.accept(channel, e); } else { next.onFailure(e); @@ -161,7 +165,10 @@ private ResponseListener createQueryResponseListener( BiConsumer errorHandler) { Format format = request.format(); ResponseFormatter formatter; - if (format.equals(Format.CSV)) { + + if (request.isCursorCloseRequest()) { + formatter = new CommandResponseFormatter(); + } else if (format.equals(Format.CSV)) { formatter = new CsvResponseFormatter(request.sanitize()); } else if (format.equals(Format.RAW)) { 
formatter = new RawResponseFormatter(); @@ -172,7 +179,8 @@ private ResponseListener createQueryResponseListener( @Override public void onResponse(QueryResponse response) { sendResponse(channel, OK, - formatter.format(new QueryResult(response.getSchema(), response.getResults()))); + formatter.format(new QueryResult(response.getSchema(), response.getResults(), + response.getCursor()))); } @Override diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java index 5249d2d5d0..0408a61342 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java @@ -42,6 +42,7 @@ import org.opensearch.sql.legacy.antlr.SqlAnalysisConfig; import org.opensearch.sql.legacy.antlr.SqlAnalysisException; import org.opensearch.sql.legacy.antlr.semantic.types.Type; +import org.opensearch.sql.legacy.cursor.CursorType; import org.opensearch.sql.legacy.domain.ColumnTypeProvider; import org.opensearch.sql.legacy.domain.QueryActionRequest; import org.opensearch.sql.legacy.esdomain.LocalClusterState; @@ -132,7 +133,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } final SqlRequest sqlRequest = SqlRequestFactory.getSqlRequest(request); - if (sqlRequest.cursor() != null) { + if (isLegacyCursor(sqlRequest)) { if (isExplainRequest(request)) { throw new IllegalArgumentException("Invalid request. 
Cannot explain cursor"); } else { @@ -147,14 +148,14 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli // Route request to new query engine if it's supported already SQLQueryRequest newSqlRequest = new SQLQueryRequest(sqlRequest.getJsonContent(), - sqlRequest.getSql(), request.path(), request.params()); + sqlRequest.getSql(), request.path(), request.params(), sqlRequest.cursor()); return newSqlQueryHandler.prepareRequest(newSqlRequest, (restChannel, exception) -> { try{ if (newSqlRequest.isExplainRequest()) { LOG.info("Request is falling back to old SQL engine due to: " + exception.getMessage()); } - LOG.debug("[{}] Request {} is not supported and falling back to old SQL engine", + LOG.info("[{}] Request {} is not supported and falling back to old SQL engine", QueryContext.getRequestId(), newSqlRequest); LOG.info("Request Query: {}", QueryDataAnonymizer.anonymizeData(sqlRequest.getSql())); QueryAction queryAction = explainRequest(client, sqlRequest, format); @@ -175,6 +176,17 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } } + + /** + * @param sqlRequest client request + * @return true if this cursor was generated by the legacy engine, false otherwise. 
+ */ + private static boolean isLegacyCursor(SqlRequest sqlRequest) { + String cursor = sqlRequest.cursor(); + return cursor != null + && CursorType.getById(cursor.substring(0, 1)) != CursorType.NULL; + } + @Override protected Set responseParams() { Set responseParams = new HashSet<>(super.responseParams()); diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionCursorFallbackTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionCursorFallbackTest.java new file mode 100644 index 0000000000..a11f4c47d7 --- /dev/null +++ b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionCursorFallbackTest.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.legacy.plugin; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.opensearch.sql.legacy.plugin.RestSqlAction.QUERY_API_ENDPOINT; + +import java.io.IOException; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import org.json.JSONObject; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.Strings; +import org.opensearch.common.inject.Injector; +import org.opensearch.common.inject.ModulesBuilder; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.executor.QueryManager; +import org.opensearch.sql.executor.execution.QueryPlanFactory; 
+import org.opensearch.sql.sql.SQLService; +import org.opensearch.sql.sql.antlr.SQLSyntaxParser; +import org.opensearch.sql.sql.domain.SQLQueryRequest; +import org.opensearch.threadpool.ThreadPool; + +/** + * A test suite that verifies fallback behaviour of cursor queries. + */ +@RunWith(MockitoJUnitRunner.class) +public class RestSQLQueryActionCursorFallbackTest extends BaseRestHandler { + + private NodeClient nodeClient; + + @Mock + private ThreadPool threadPool; + + @Mock + private QueryManager queryManager; + + @Mock + private QueryPlanFactory factory; + + @Mock + private RestChannel restChannel; + + private Injector injector; + + @Before + public void setup() { + nodeClient = new NodeClient(org.opensearch.common.settings.Settings.EMPTY, threadPool); + ModulesBuilder modules = new ModulesBuilder(); + modules.add(b -> { + b.bind(SQLService.class).toInstance(new SQLService(new SQLSyntaxParser(), queryManager, factory)); + }); + injector = modules.createInjector(); + Mockito.lenient().when(threadPool.getThreadContext()) + .thenReturn(new ThreadContext(org.opensearch.common.settings.Settings.EMPTY)); + } + + // Initial page request test cases + + @Test + public void no_fallback_with_column_reference() throws Exception { + String query = "SELECT name FROM test1"; + SQLQueryRequest request = createSqlQueryRequest(query, Optional.empty(), + Optional.of(5)); + + assertFalse(doesQueryFallback(request)); + } + + private static SQLQueryRequest createSqlQueryRequest(String query, Optional cursorId, + Optional fetchSize) throws IOException { + var builder = XContentFactory.jsonBuilder() + .startObject() + .field("query").value(query); + if (cursorId.isPresent()) { + builder.field("cursor").value(cursorId.get()); + } + + if (fetchSize.isPresent()) { + builder.field("fetch_size").value(fetchSize.get()); + } + builder.endObject(); + JSONObject jsonContent = new JSONObject(Strings.toString(builder)); + + return new SQLQueryRequest(jsonContent, query, QUERY_API_ENDPOINT, + 
Map.of("format", "jdbc"), cursorId.orElse("")); + } + + boolean doesQueryFallback(SQLQueryRequest request) throws Exception { + AtomicBoolean fallback = new AtomicBoolean(false); + RestSQLQueryAction queryAction = new RestSQLQueryAction(injector); + queryAction.prepareRequest(request, (channel, exception) -> { + fallback.set(true); + }, (channel, exception) -> { + }).accept(restChannel); + return fallback.get(); + } + + @Override + public String getName() { + // do nothing, RestChannelConsumer is protected which required to extend BaseRestHandler + return null; + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient) + { + // do nothing, RestChannelConsumer is protected which required to extend BaseRestHandler + return null; + } +} diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java index 1bc34edf50..be572f3dfb 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java @@ -74,7 +74,7 @@ public void handleQueryThatCanSupport() throws Exception { new JSONObject("{\"query\": \"SELECT -123\"}"), "SELECT -123", QUERY_API_ENDPOINT, - ""); + "jdbc"); RestSQLQueryAction queryAction = new RestSQLQueryAction(injector); queryAction.prepareRequest(request, (channel, exception) -> { @@ -90,7 +90,7 @@ public void handleExplainThatCanSupport() throws Exception { new JSONObject("{\"query\": \"SELECT -123\"}"), "SELECT -123", EXPLAIN_API_ENDPOINT, - ""); + "jdbc"); RestSQLQueryAction queryAction = new RestSQLQueryAction(injector); queryAction.prepareRequest(request, (channel, exception) -> { @@ -107,7 +107,7 @@ public void queryThatNotSupportIsHandledByFallbackHandler() throws Exception { "{\"query\": \"SELECT name FROM test1 JOIN test2 ON test1.name = 
test2.name\"}"), "SELECT name FROM test1 JOIN test2 ON test1.name = test2.name", QUERY_API_ENDPOINT, - ""); + "jdbc"); AtomicBoolean fallback = new AtomicBoolean(false); RestSQLQueryAction queryAction = new RestSQLQueryAction(injector); @@ -128,7 +128,7 @@ public void queryExecutionFailedIsHandledByExecutionErrorHandler() throws Except "{\"query\": \"SELECT -123\"}"), "SELECT -123", QUERY_API_ENDPOINT, - ""); + "jdbc"); doThrow(new IllegalStateException("execution exception")) .when(queryManager) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java index b26680b3ba..b6ca5471e5 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java @@ -42,7 +42,7 @@ public class OpenSearchNodeClient implements OpenSearchClient { private final NodeClient client; /** - * Constructor of ElasticsearchNodeClient. + * Constructor of OpenSearchNodeClient. 
*/ public OpenSearchNodeClient(NodeClient client) { this.client = client; @@ -171,7 +171,14 @@ public Map meta() { @Override public void cleanup(OpenSearchRequest request) { - request.clean(scrollId -> client.prepareClearScroll().addScrollId(scrollId).get()); + request.clean(scrollId -> { + try { + client.prepareClearScroll().addScrollId(scrollId).get(); + } catch (Exception e) { + throw new IllegalStateException( + "Failed to clean up resources for search request " + request, e); + } + }); } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java index 66cc067541..c27c4bbc30 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java @@ -183,7 +183,6 @@ public void cleanup(OpenSearchRequest request) { "Failed to clean up resources for search request " + request, e); } }); - } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java index 9a136a3bec..31e5c7f957 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java @@ -6,15 +6,16 @@ package org.opensearch.sql.opensearch.executor; -import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.List; +import java.util.Map; import lombok.RequiredArgsConstructor; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.executor.ExecutionContext; import org.opensearch.sql.executor.ExecutionEngine; import 
org.opensearch.sql.executor.Explain; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.executor.protector.ExecutionProtector; import org.opensearch.sql.planner.physical.PhysicalPlan; @@ -27,6 +28,7 @@ public class OpenSearchExecutionEngine implements ExecutionEngine { private final OpenSearchClient client; private final ExecutionProtector executionProtector; + private final PlanSerializer planSerializer; @Override public void execute(PhysicalPlan physicalPlan, ResponseListener listener) { @@ -49,7 +51,8 @@ public void execute(PhysicalPlan physicalPlan, ExecutionContext context, result.add(plan.next()); } - QueryResponse response = new QueryResponse(physicalPlan.schema(), result); + QueryResponse response = new QueryResponse(physicalPlan.schema(), result, + planSerializer.convertToCursor(plan)); listener.onResponse(response); } catch (Exception e) { listener.onFailure(e); @@ -67,7 +70,7 @@ public void explain(PhysicalPlan plan, ResponseListener listene @Override public ExplainResponseNode visitTableScan(TableScanOperator node, Object context) { return explain(node, context, explainNode -> { - explainNode.setDescription(ImmutableMap.of("request", node.explain())); + explainNode.setDescription(Map.of("request", node.explain())); }); } }; @@ -78,5 +81,4 @@ public ExplainResponseNode visitTableScan(TableScanOperator node, Object context } }); } - } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java index 9d71cee8c9..dff5545785 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java @@ -12,6 +12,7 @@ import 
org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.planner.physical.MLOperator; import org.opensearch.sql.planner.physical.AggregationOperator; +import org.opensearch.sql.planner.physical.CursorCloseOperator; import org.opensearch.sql.planner.physical.DedupeOperator; import org.opensearch.sql.planner.physical.EvalOperator; import org.opensearch.sql.planner.physical.FilterOperator; @@ -42,6 +43,15 @@ public PhysicalPlan protect(PhysicalPlan physicalPlan) { return physicalPlan.accept(this, null); } + /** + * Don't protect {@link CursorCloseOperator} and entire nested tree, because + * {@link CursorCloseOperator} is designed as a no-op. + */ + @Override + public PhysicalPlan visitCursorClose(CursorCloseOperator node, Object context) { + return node; + } + @Override public PhysicalPlan visitFilter(FilterOperator node, Object context) { return new FilterOperator(visitInput(node.getInput(), context), node.getConditions()); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/ResourceMonitorPlan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/ResourceMonitorPlan.java index 8fc7480dd1..4c02affc5e 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/ResourceMonitorPlan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/ResourceMonitorPlan.java @@ -6,12 +6,16 @@ package org.opensearch.sql.opensearch.executor.protector; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.List; import lombok.EqualsAndHashCode; import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.monitor.ResourceMonitor; +import org.opensearch.sql.planner.SerializablePlan; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; 
@@ -21,7 +25,7 @@ @ToString @RequiredArgsConstructor @EqualsAndHashCode(callSuper = false) -public class ResourceMonitorPlan extends PhysicalPlan { +public class ResourceMonitorPlan extends PhysicalPlan implements SerializablePlan { /** * How many method calls to delegate's next() to perform resource check once. @@ -82,4 +86,23 @@ public ExprValue next() { } return delegate.next(); } + + @Override + public SerializablePlan getPlanForSerialization() { + return (SerializablePlan) delegate; + } + + /** + * These two methods should never be called. They are called only if a plan higher in the tree failed to + * call {@link #getPlanForSerialization}. + */ + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + throw new UnsupportedOperationException(); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + throw new UnsupportedOperationException(); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java index 3976f854fd..45954a3871 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java @@ -6,7 +6,7 @@ package org.opensearch.sql.opensearch.request; -import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.function.Consumer; @@ -17,7 +17,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; -import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; import 
org.opensearch.search.fetch.subphase.FetchSourceContext; @@ -35,11 +35,6 @@ @ToString public class OpenSearchQueryRequest implements OpenSearchRequest { - /** - * Default query timeout in minutes. - */ - public static final TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); - /** * {@link OpenSearchRequest.IndexName}. */ @@ -51,6 +46,7 @@ public class OpenSearchQueryRequest implements OpenSearchRequest { private final SearchSourceBuilder sourceBuilder; + /** * OpenSearchExprValueFactory. */ @@ -106,7 +102,9 @@ public OpenSearchResponse search(Function searchA } else { searchDone = true; return new OpenSearchResponse( - searchAction.apply(searchRequest()), exprValueFactory, includes); + searchAction.apply(new SearchRequest() + .indices(indexName.getIndexNames()) + .source(sourceBuilder)), exprValueFactory, includes); } } @@ -115,15 +113,14 @@ public void clean(Consumer cleanAction) { //do nothing. } - /** - * Generate OpenSearch search request. - * - * @return search request - */ - @VisibleForTesting - protected SearchRequest searchRequest() { - return new SearchRequest() - .indices(indexName.getIndexNames()) - .source(sourceBuilder); + @Override + public boolean hasAnotherBatch() { + return false; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("OpenSearchQueryRequest serialization " + + "is not implemented."); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java index ce990780c1..ee9da5b53b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java @@ -6,20 +6,29 @@ package org.opensearch.sql.opensearch.request; +import java.io.IOException; import java.util.function.Consumer; import java.util.function.Function; 
import lombok.EqualsAndHashCode; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; -import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; /** * OpenSearch search request. */ -public interface OpenSearchRequest { +public interface OpenSearchRequest extends Writeable { + /** + * Default query timeout in minutes. + */ + TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); + /** * Apply the search action or scroll action on request based on context. * @@ -37,29 +46,33 @@ OpenSearchResponse search(Function searchAction, */ void clean(Consumer cleanAction); - /** - * Get the SearchSourceBuilder. - * - * @return SearchSourceBuilder. - */ - SearchSourceBuilder getSourceBuilder(); - /** * Get the ElasticsearchExprValueFactory. * @return ElasticsearchExprValueFactory. */ OpenSearchExprValueFactory getExprValueFactory(); + /** + * Check if there is more data to get from OpenSearch. + * @return True if calling {@ref OpenSearchClient.search} with this request will + * return non-empty response. + */ + boolean hasAnotherBatch(); + /** * OpenSearch Index Name. - * Indices are seperated by ",". + * Indices are separated by ",". 
*/ @EqualsAndHashCode - class IndexName { + class IndexName implements Writeable { private static final String COMMA = ","; private final String[] indexNames; + public IndexName(StreamInput si) throws IOException { + indexNames = si.readStringArray(); + } + public IndexName(String indexName) { this.indexNames = indexName.split(COMMA); } @@ -72,5 +85,10 @@ public String[] getIndexNames() { public String toString() { return String.join(COMMA, indexNames); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeStringArray(indexNames); + } } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java index 6f0618d961..bec133f834 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java @@ -14,7 +14,6 @@ import static org.opensearch.search.sort.SortOrder.ASC; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Map; @@ -40,7 +39,6 @@ import org.opensearch.search.sort.SortBuilder; import org.opensearch.search.sort.SortBuilders; import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.ReferenceExpression; @@ -57,24 +55,19 @@ public class OpenSearchRequestBuilder { /** - * Default query timeout in minutes. - */ - public static final TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); - - /** - * {@link OpenSearchRequest.IndexName}. + * Search request source builder. 
*/ - private final OpenSearchRequest.IndexName indexName; + private final SearchSourceBuilder sourceBuilder; /** - * Index max result window. + * Query size of the request -- how many rows will be returned. */ - private final Integer maxResultWindow; + private int requestedTotalSize; /** - * Search request source builder. + * Size of each page request to return. */ - private final SearchSourceBuilder sourceBuilder; + private Integer pageSize = null; /** * OpenSearchExprValueFactory. @@ -82,35 +75,19 @@ public class OpenSearchRequestBuilder { @EqualsAndHashCode.Exclude @ToString.Exclude private final OpenSearchExprValueFactory exprValueFactory; - - /** - * Query size of the request. - */ - private Integer querySize; - - public OpenSearchRequestBuilder(String indexName, - Integer maxResultWindow, - Settings settings, - OpenSearchExprValueFactory exprValueFactory) { - this(new OpenSearchRequest.IndexName(indexName), maxResultWindow, settings, exprValueFactory); - } + private int startFrom = 0; /** * Constructor. 
*/ - public OpenSearchRequestBuilder(OpenSearchRequest.IndexName indexName, - Integer maxResultWindow, - Settings settings, + public OpenSearchRequestBuilder(int requestedTotalSize, OpenSearchExprValueFactory exprValueFactory) { - this.indexName = indexName; - this.maxResultWindow = maxResultWindow; - this.sourceBuilder = new SearchSourceBuilder(); + this.requestedTotalSize = requestedTotalSize; + this.sourceBuilder = new SearchSourceBuilder() + .from(startFrom) + .timeout(OpenSearchRequest.DEFAULT_QUERY_TIMEOUT) + .trackScores(false); this.exprValueFactory = exprValueFactory; - this.querySize = settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT); - sourceBuilder.from(0); - sourceBuilder.size(querySize); - sourceBuilder.timeout(DEFAULT_QUERY_TIMEOUT); - sourceBuilder.trackScores(false); } /** @@ -118,24 +95,40 @@ public OpenSearchRequestBuilder(OpenSearchRequest.IndexName indexName, * * @return query request or scroll request */ - public OpenSearchRequest build() { - Integer from = sourceBuilder.from(); - Integer size = sourceBuilder.size(); - - if (from + size <= maxResultWindow) { - return new OpenSearchQueryRequest(indexName, sourceBuilder, exprValueFactory); + public OpenSearchRequest build(OpenSearchRequest.IndexName indexName, + int maxResultWindow, TimeValue scrollTimeout) { + int size = requestedTotalSize; + if (pageSize == null) { + if (startFrom + size > maxResultWindow) { + sourceBuilder.size(maxResultWindow - startFrom); + return new OpenSearchScrollRequest( + indexName, scrollTimeout, sourceBuilder, exprValueFactory); + } else { + sourceBuilder.from(startFrom); + sourceBuilder.size(requestedTotalSize); + return new OpenSearchQueryRequest(indexName, sourceBuilder, exprValueFactory); + } } else { - sourceBuilder.size(maxResultWindow - from); - return new OpenSearchScrollRequest(indexName, sourceBuilder, exprValueFactory); + if (startFrom != 0) { + throw new UnsupportedOperationException("Non-zero offset is not supported with pagination"); + } + 
sourceBuilder.size(pageSize); + return new OpenSearchScrollRequest(indexName, scrollTimeout, + sourceBuilder, exprValueFactory); } } + + boolean isBoolFilterQuery(QueryBuilder current) { + return (current instanceof BoolQueryBuilder); + } + /** * Push down query to DSL request. * * @param query query request */ - public void pushDown(QueryBuilder query) { + public void pushDownFilter(QueryBuilder query) { QueryBuilder current = sourceBuilder.query(); if (current == null) { @@ -162,7 +155,7 @@ public void pushDown(QueryBuilder query) { */ public void pushDownAggregation( Pair, OpenSearchAggregationResponseParser> aggregationBuilder) { - aggregationBuilder.getLeft().forEach(builder -> sourceBuilder.aggregation(builder)); + aggregationBuilder.getLeft().forEach(sourceBuilder::aggregation); sourceBuilder.size(0); exprValueFactory.setParser(aggregationBuilder.getRight()); } @@ -184,10 +177,11 @@ public void pushDownSort(List> sortBuilders) { } /** - * Push down size (limit) and from (offset) to DSL request. + * Pushdown size (limit) and from (offset) to DSL request. */ public void pushDownLimit(Integer limit, Integer offset) { - querySize = limit; + requestedTotalSize = limit; + startFrom = offset; sourceBuilder.from(offset).size(limit); } @@ -195,6 +189,10 @@ public void pushDownTrackedScore(boolean trackScores) { sourceBuilder.trackScores(trackScores); } + public void pushDownPageSize(int pageSize) { + this.pageSize = pageSize; + } + /** * Add highlight to DSL requests. * @param field name of the field to highlight @@ -229,26 +227,22 @@ public void pushDownHighlight(String field, Map arguments) { } /** - * Push down project list to DSL requets. + * Push down project list to DSL requests. 
*/ public void pushDownProjects(Set projects) { - final Set projectsSet = - projects.stream().map(ReferenceExpression::getAttr).collect(Collectors.toSet()); - sourceBuilder.fetchSource(projectsSet.toArray(new String[0]), new String[0]); + sourceBuilder.fetchSource( + projects.stream().map(ReferenceExpression::getAttr).distinct().toArray(String[]::new), + new String[0]); } public void pushTypeMapping(Map typeMapping) { exprValueFactory.extendTypeMapping(typeMapping); } - private boolean isBoolFilterQuery(QueryBuilder current) { - return (current instanceof BoolQueryBuilder); - } - private boolean isSortByDocOnly() { List> sorts = sourceBuilder.sorts(); if (sorts != null) { - return sorts.equals(Arrays.asList(SortBuilders.fieldSort(DOC_FIELD_NAME))); + return sorts.equals(List.of(SortBuilders.fieldSort(DOC_FIELD_NAME))); } return false; } @@ -286,6 +280,10 @@ private List extractNestedQueries(QueryBuilder query) { return result; } + public int getMaxResponseSize() { + return pageSize == null ? requestedTotalSize : pageSize; + } + /** * Initialize bool query for push down. 
*/ diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java index 9b0d6ca074..403626c610 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.opensearch.request; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -18,11 +19,14 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.unit.TimeValue; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; /** * OpenSearch scroll search request. This has to be stateful because it needs to: @@ -34,9 +38,9 @@ @Getter @ToString public class OpenSearchScrollRequest implements OpenSearchRequest { - - /** Default scroll context timeout in minutes. */ - public static final TimeValue DEFAULT_SCROLL_TIMEOUT = TimeValue.timeValueMinutes(1L); + private final SearchRequest initialSearchRequest; + /** Scroll context timeout. */ + private final TimeValue scrollTimeout; /** * {@link OpenSearchRequest.IndexName}. 
@@ -49,83 +53,80 @@ public class OpenSearchScrollRequest implements OpenSearchRequest { private final OpenSearchExprValueFactory exprValueFactory; /** - * Scroll id which is set after first request issued. Because ElasticsearchClient is shared by - * multi-thread so this state has to be maintained here. + * Scroll id which is set after first request issued. Because OpenSearchClient is shared by + * multiple threads so this state has to be maintained here. */ @Setter - private String scrollId; + @Getter + private String scrollId = NO_SCROLL_ID; - /** Search request source builder. */ - private final SearchSourceBuilder sourceBuilder; + public static final String NO_SCROLL_ID = ""; - /** Constructor. */ - public OpenSearchScrollRequest(IndexName indexName, OpenSearchExprValueFactory exprValueFactory) { - this.indexName = indexName; - this.sourceBuilder = new SearchSourceBuilder(); - this.exprValueFactory = exprValueFactory; - } + @EqualsAndHashCode.Exclude + private boolean needClean = true; - public OpenSearchScrollRequest(String indexName, OpenSearchExprValueFactory exprValueFactory) { - this(new IndexName(indexName), exprValueFactory); - } + @Getter + private final List includes; /** Constructor. */ public OpenSearchScrollRequest(IndexName indexName, + TimeValue scrollTimeout, SearchSourceBuilder sourceBuilder, OpenSearchExprValueFactory exprValueFactory) { this.indexName = indexName; - this.sourceBuilder = sourceBuilder; + this.scrollTimeout = scrollTimeout; this.exprValueFactory = exprValueFactory; + this.initialSearchRequest = new SearchRequest() + .indices(indexName.getIndexNames()) + .scroll(scrollTimeout) + .source(sourceBuilder); + + includes = sourceBuilder.fetchSource() == null + ? List.of() + : Arrays.asList(sourceBuilder.fetchSource().includes()); } - /** Constructor. */ + + /** Executes request using either {@param searchAction} or {@param scrollAction} as appropriate. 
+ */ @Override public OpenSearchResponse search(Function searchAction, Function scrollAction) { SearchResponse openSearchResponse; - if (isScrollStarted()) { + if (isScroll()) { openSearchResponse = scrollAction.apply(scrollRequest()); } else { - openSearchResponse = searchAction.apply(searchRequest()); + openSearchResponse = searchAction.apply(initialSearchRequest); + } + + var response = new OpenSearchResponse(openSearchResponse, exprValueFactory, includes); + needClean = response.isEmpty(); + if (!needClean) { + setScrollId(openSearchResponse.getScrollId()); } - setScrollId(openSearchResponse.getScrollId()); - FetchSourceContext fetchSource = this.sourceBuilder.fetchSource(); - List includes = fetchSource != null && fetchSource.includes() != null - ? Arrays.asList(this.sourceBuilder.fetchSource().includes()) - : List.of(); - return new OpenSearchResponse(openSearchResponse, exprValueFactory, includes); + return response; } @Override public void clean(Consumer cleanAction) { try { - if (isScrollStarted()) { + // clean on the last page only, to prevent closing the scroll/cursor in the middle of paging. + if (needClean && isScroll()) { cleanAction.accept(getScrollId()); + setScrollId(NO_SCROLL_ID); } } finally { reset(); } } - /** - * Generate OpenSearch search request. - * - * @return search request - */ - public SearchRequest searchRequest() { - return new SearchRequest() - .indices(indexName.getIndexNames()) - .scroll(DEFAULT_SCROLL_TIMEOUT) - .source(sourceBuilder); - } - /** * Is scroll started which means pages after first is being requested. 
* * @return true if scroll started */ - public boolean isScrollStarted() { - return (scrollId != null); + public boolean isScroll() { + return !scrollId.equals(NO_SCROLL_ID); } /** @@ -135,7 +136,7 @@ public boolean isScrollStarted() { */ public SearchScrollRequest scrollRequest() { Objects.requireNonNull(scrollId, "Scroll id cannot be null"); - return new SearchScrollRequest().scroll(DEFAULT_SCROLL_TIMEOUT).scrollId(scrollId); + return new SearchScrollRequest().scroll(scrollTimeout).scrollId(scrollId); } /** @@ -143,6 +144,37 @@ public SearchScrollRequest scrollRequest() { * to be reused across different physical plan. */ public void reset() { - scrollId = null; + scrollId = NO_SCROLL_ID; + } + + @Override + public boolean hasAnotherBatch() { + return !needClean && !scrollId.equals(NO_SCROLL_ID); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + initialSearchRequest.writeTo(out); + out.writeTimeValue(scrollTimeout); + out.writeString(scrollId); + out.writeStringCollection(includes); + indexName.writeTo(out); + } + + /** + * Constructs OpenSearchScrollRequest from serialized representation. + * @param in stream to read data from. + * @param engine OpenSearchSqlEngine to get node-specific context. + * @throws IOException thrown if reading from input {@param in} fails. 
+ */ + public OpenSearchScrollRequest(StreamInput in, OpenSearchStorageEngine engine) + throws IOException { + initialSearchRequest = new SearchRequest(in); + scrollTimeout = in.readTimeValue(); + scrollId = in.readString(); + includes = in.readStringList(); + indexName = new IndexName(in); + OpenSearchIndex index = (OpenSearchIndex) engine.getTable(null, indexName.toString()); + exprValueFactory = new OpenSearchExprValueFactory(index.getFieldOpenSearchTypes()); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java index 204a6bca22..733fad6203 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java @@ -57,13 +57,13 @@ public class OpenSearchResponse implements Iterable { private final List includes; /** - * ElasticsearchExprValueFactory used to build ExprValue from search result. + * OpenSearchExprValueFactory used to build ExprValue from search result. */ @EqualsAndHashCode.Exclude private final OpenSearchExprValueFactory exprValueFactory; /** - * Constructor of ElasticsearchResponse. + * Constructor of OpenSearchResponse. */ public OpenSearchResponse(SearchResponse searchResponse, OpenSearchExprValueFactory exprValueFactory, @@ -75,7 +75,7 @@ public OpenSearchResponse(SearchResponse searchResponse, } /** - * Constructor of ElasticsearchResponse with SearchHits. + * Constructor of OpenSearchResponse with SearchHits. 
*/ public OpenSearchResponse(SearchHits hits, OpenSearchExprValueFactory exprValueFactory, diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index ae5174d678..accd356041 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -99,8 +99,8 @@ public class OpenSearchSettings extends Settings { Setting.Property.Dynamic); /** - * Construct ElasticsearchSetting. - * The ElasticsearchSetting must be singleton. + * Construct OpenSearchSetting. + * The OpenSearchSetting must be singleton. */ @SuppressWarnings("unchecked") public OpenSearchSettings(ClusterSettings clusterSettings) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index cf09b32de9..532d62333d 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -11,6 +11,7 @@ import java.util.LinkedHashMap; import java.util.Map; import lombok.RequiredArgsConstructor; +import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; @@ -21,7 +22,9 @@ import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.planner.physical.MLOperator; import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; +import 
org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanBuilder; import org.opensearch.sql.planner.DefaultImplementor; import org.opensearch.sql.planner.logical.LogicalAD; @@ -30,6 +33,7 @@ import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.TableScanOperator; import org.opensearch.sql.storage.read.TableScanBuilder; /** OpenSearch table (index) implementation. */ @@ -164,19 +168,29 @@ public PhysicalPlan implement(LogicalPlan plan) { } @Override - public LogicalPlan optimize(LogicalPlan plan) { - // No-op because optimization already done in Planner - return plan; + public TableScanBuilder createScanBuilder() { + final int querySizeLimit = settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT); + + var builder = new OpenSearchRequestBuilder( + querySizeLimit, + createExprValueFactory()); + + return new OpenSearchIndexScanBuilder(builder) { + @Override + protected TableScanOperator createScan(OpenSearchRequestBuilder requestBuilder) { + final TimeValue cursorKeepAlive = + settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); + return new OpenSearchIndexScan(client, requestBuilder.getMaxResponseSize(), + requestBuilder.build(indexName, getMaxResultWindow(), cursorKeepAlive)); + } + }; } - @Override - public TableScanBuilder createScanBuilder() { + private OpenSearchExprValueFactory createExprValueFactory() { Map allFields = new HashMap<>(); getReservedFieldTypes().forEach((k, v) -> allFields.put(k, OpenSearchDataType.of(v))); allFields.putAll(getFieldOpenSearchTypes()); - OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, indexName, - getMaxResultWindow(), new OpenSearchExprValueFactory(allFields)); - return new OpenSearchIndexScanBuilder(indexScan); + return new OpenSearchExprValueFactory(allFields); } @VisibleForTesting diff 
--git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java deleted file mode 100644 index a26e64a809..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.storage; - -import java.util.Collections; -import java.util.Iterator; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; -import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.opensearch.client.OpenSearchClient; -import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.request.OpenSearchRequest; -import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; -import org.opensearch.sql.opensearch.response.OpenSearchResponse; -import org.opensearch.sql.storage.TableScanOperator; - -/** - * OpenSearch index scan operator. - */ -@EqualsAndHashCode(onlyExplicitlyIncluded = true, callSuper = false) -@ToString(onlyExplicitlyIncluded = true) -public class OpenSearchIndexScan extends TableScanOperator { - - /** OpenSearch client. */ - private final OpenSearchClient client; - - /** Search request builder. */ - @EqualsAndHashCode.Include - @Getter - @ToString.Include - private final OpenSearchRequestBuilder requestBuilder; - - /** Search request. */ - @EqualsAndHashCode.Include - @ToString.Include - private OpenSearchRequest request; - - /** Total query size. */ - @EqualsAndHashCode.Include - @ToString.Include - private Integer querySize; - - /** Number of rows returned. */ - private Integer queryCount; - - /** Search response for current batch. */ - private Iterator iterator; - - /** - * Constructor. 
- */ - public OpenSearchIndexScan(OpenSearchClient client, Settings settings, - String indexName, Integer maxResultWindow, - OpenSearchExprValueFactory exprValueFactory) { - this( - client, - settings, - new OpenSearchRequest.IndexName(indexName), - maxResultWindow, - exprValueFactory - ); - } - - /** - * Constructor. - */ - public OpenSearchIndexScan(OpenSearchClient client, Settings settings, - OpenSearchRequest.IndexName indexName, Integer maxResultWindow, - OpenSearchExprValueFactory exprValueFactory) { - this.client = client; - this.requestBuilder = new OpenSearchRequestBuilder( - indexName, maxResultWindow, settings, exprValueFactory); - } - - @Override - public void open() { - super.open(); - querySize = requestBuilder.getQuerySize(); - request = requestBuilder.build(); - iterator = Collections.emptyIterator(); - queryCount = 0; - fetchNextBatch(); - } - - @Override - public boolean hasNext() { - if (queryCount >= querySize) { - iterator = Collections.emptyIterator(); - } else if (!iterator.hasNext()) { - fetchNextBatch(); - } - return iterator.hasNext(); - } - - @Override - public ExprValue next() { - queryCount++; - return iterator.next(); - } - - private void fetchNextBatch() { - OpenSearchResponse response = client.search(request); - if (!response.isEmpty()) { - iterator = response.iterator(); - } - } - - @Override - public void close() { - super.close(); - - client.cleanup(request); - } - - @Override - public String explain() { - return getRequestBuilder().build().toString(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java index 4a3393abc9..c915fa549b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java @@ -8,6 +8,7 @@ import static 
org.opensearch.sql.utils.SystemIndexUtils.isSystemIndex; +import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.sql.DataSourceSchemaName; import org.opensearch.sql.common.setting.Settings; @@ -21,8 +22,9 @@ public class OpenSearchStorageEngine implements StorageEngine { /** OpenSearch client connection. */ + @Getter private final OpenSearchClient client; - + @Getter private final Settings settings; @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java new file mode 100644 index 0000000000..e216e1e2fe --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java @@ -0,0 +1,154 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.opensearch.storage.scan; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.Collections; +import java.util.Iterator; +import lombok.EqualsAndHashCode; +import lombok.ToString; +import org.opensearch.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; +import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; +import org.opensearch.sql.planner.SerializablePlan; +import org.opensearch.sql.storage.TableScanOperator; + +/** + * OpenSearch index scan operator. 
+ */ +@EqualsAndHashCode(onlyExplicitlyIncluded = true, callSuper = false) +@ToString(onlyExplicitlyIncluded = true) +public class OpenSearchIndexScan extends TableScanOperator implements SerializablePlan { + + /** OpenSearch client. */ + private OpenSearchClient client; + + /** Search request. */ + @EqualsAndHashCode.Include + @ToString.Include + private OpenSearchRequest request; + + /** Largest number of rows allowed in the response. */ + @EqualsAndHashCode.Include + @ToString.Include + private int maxResponseSize; + + /** Number of rows returned. */ + private Integer queryCount; + + /** Search response for current batch. */ + private Iterator iterator; + + /** + * Creates index scan based on a provided OpenSearchRequestBuilder. + */ + public OpenSearchIndexScan(OpenSearchClient client, + int maxResponseSize, + OpenSearchRequest request) { + this.client = client; + this.maxResponseSize = maxResponseSize; + this.request = request; + } + + @Override + public void open() { + super.open(); + iterator = Collections.emptyIterator(); + queryCount = 0; + fetchNextBatch(); + } + + @Override + public boolean hasNext() { + if (queryCount >= maxResponseSize) { + iterator = Collections.emptyIterator(); + } else if (!iterator.hasNext()) { + fetchNextBatch(); + } + return iterator.hasNext(); + } + + @Override + public ExprValue next() { + queryCount++; + return iterator.next(); + } + + private void fetchNextBatch() { + OpenSearchResponse response = client.search(request); + if (!response.isEmpty()) { + iterator = response.iterator(); + } + } + + @Override + public void close() { + super.close(); + + client.cleanup(request); + } + + @Override + public String explain() { + return request.toString(); + } + + /** No-args constructor. + * @deprecated Exists only to satisfy Java serialization API. 
+ */ + @Deprecated(since = "introduction") + public OpenSearchIndexScan() { + } + + @Override + public void readExternal(ObjectInput in) throws IOException { + int reqSize = in.readInt(); + byte[] requestStream = new byte[reqSize]; + in.read(requestStream); + + var engine = (OpenSearchStorageEngine) ((PlanSerializer.CursorDeserializationStream) in) + .resolveObject("engine"); + + try (BytesStreamInput bsi = new BytesStreamInput(requestStream)) { + request = new OpenSearchScrollRequest(bsi, engine); + } + maxResponseSize = in.readInt(); + + client = engine.getClient(); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + if (!request.hasAnotherBatch()) { + throw new NoCursorException(); + } + // request is not directly Serializable so.. + // 1. Serialize request to an opensearch byte stream. + BytesStreamOutput reqOut = new BytesStreamOutput(); + request.writeTo(reqOut); + reqOut.flush(); + + // 2. Extract byte[] from the opensearch byte stream + var reqAsBytes = reqOut.bytes().toBytesRef().bytes; + + // 3. Write out the byte[] to object output stream. 
+ out.writeInt(reqAsBytes.length); + out.write(reqAsBytes); + + out.writeInt(maxResponseSize); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java index e52fc566cd..84883b5209 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.sql.ast.tree.Sort; @@ -15,58 +16,60 @@ import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; -import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalSort; -import 
org.opensearch.sql.storage.TableScanOperator; -import org.opensearch.sql.storage.read.TableScanBuilder; /** * Index scan builder for aggregate query used by {@link OpenSearchIndexScanBuilder} internally. */ -class OpenSearchIndexScanAggregationBuilder extends TableScanBuilder { +@EqualsAndHashCode +class OpenSearchIndexScanAggregationBuilder implements PushDownQueryBuilder { /** OpenSearch index scan to be optimized. */ - private final OpenSearchIndexScan indexScan; + private final OpenSearchRequestBuilder requestBuilder; /** Aggregators pushed down. */ - private List aggregatorList; + private final List aggregatorList; /** Grouping items pushed down. */ - private List groupByList; + private final List groupByList; /** Sorting items pushed down. */ private List> sortList; - /** - * Initialize with given index scan and perform push-down optimization later. - * - * @param indexScan index scan not fully optimized yet - */ - OpenSearchIndexScanAggregationBuilder(OpenSearchIndexScan indexScan) { - this.indexScan = indexScan; + + OpenSearchIndexScanAggregationBuilder(OpenSearchRequestBuilder requestBuilder, + LogicalAggregation aggregation) { + this.requestBuilder = requestBuilder; + aggregatorList = aggregation.getAggregatorList(); + groupByList = aggregation.getGroupByList(); } @Override - public TableScanOperator build() { + public OpenSearchRequestBuilder build() { AggregationQueryBuilder builder = new AggregationQueryBuilder(new DefaultExpressionSerializer()); Pair, OpenSearchAggregationResponseParser> aggregationBuilder = builder.buildAggregationBuilder(aggregatorList, groupByList, sortList); - indexScan.getRequestBuilder().pushDownAggregation(aggregationBuilder); - indexScan.getRequestBuilder().pushTypeMapping( + requestBuilder.pushDownAggregation(aggregationBuilder); + requestBuilder.pushTypeMapping( builder.buildTypeMapping(aggregatorList, groupByList)); - return indexScan; + return requestBuilder; } @Override - public boolean 
pushDownAggregation(LogicalAggregation aggregation) { - aggregatorList = aggregation.getAggregatorList(); - groupByList = aggregation.getGroupByList(); - return true; + public boolean pushDownFilter(LogicalFilter filter) { + return false; } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java index 8e6c57d7d5..c6df692095 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java @@ -5,15 +5,15 @@ package org.opensearch.sql.opensearch.storage.scan; -import com.google.common.annotations.VisibleForTesting; import lombok.EqualsAndHashCode; import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalSort; import org.opensearch.sql.storage.TableScanOperator; @@ -24,36 +24,39 @@ * by delegated builder internally. This is to avoid conditional check of different push down logic * for non-aggregate and aggregate query everywhere. */ -public class OpenSearchIndexScanBuilder extends TableScanBuilder { +public abstract class OpenSearchIndexScanBuilder extends TableScanBuilder { /** * Delegated index scan builder for non-aggregate or aggregate query. 
*/ @EqualsAndHashCode.Include - private TableScanBuilder delegate; + private PushDownQueryBuilder delegate; /** Is limit operator pushed down. */ private boolean isLimitPushedDown = false; - @VisibleForTesting - OpenSearchIndexScanBuilder(TableScanBuilder delegate) { - this.delegate = delegate; + /** + * Constructor used during query execution. + */ + protected OpenSearchIndexScanBuilder(OpenSearchRequestBuilder requestBuilder) { + this.delegate = new OpenSearchIndexScanQueryBuilder(requestBuilder); + } /** - * Initialize with given index scan. - * - * @param indexScan index scan to optimize + * Constructor used for unit tests. */ - public OpenSearchIndexScanBuilder(OpenSearchIndexScan indexScan) { - this.delegate = new OpenSearchIndexScanQueryBuilder(indexScan); + protected OpenSearchIndexScanBuilder(PushDownQueryBuilder translator) { + this.delegate = translator; } @Override public TableScanOperator build() { - return delegate.build(); + return createScan(delegate.build()); } + protected abstract TableScanOperator createScan(OpenSearchRequestBuilder requestBuilder); + @Override public boolean pushDownFilter(LogicalFilter filter) { return delegate.pushDownFilter(filter); @@ -67,10 +70,13 @@ public boolean pushDownAggregation(LogicalAggregation aggregation) { // Switch to builder for aggregate query which has different push down logic // for later filter, sort and limit operator. 
- delegate = new OpenSearchIndexScanAggregationBuilder( - (OpenSearchIndexScan) delegate.build()); + delegate = new OpenSearchIndexScanAggregationBuilder(delegate.build(), aggregation); + return true; + } - return delegate.pushDownAggregation(aggregation); + @Override + public boolean pushDownPageSize(LogicalPaginate paginate) { + return delegate.pushDownPageSize(paginate); } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java index f20556ccc5..590272a9f1 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java @@ -22,7 +22,7 @@ import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.function.OpenSearchFunctions; -import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; @@ -30,34 +30,22 @@ import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalSort; -import org.opensearch.sql.storage.TableScanOperator; -import org.opensearch.sql.storage.read.TableScanBuilder; /** * Index scan builder for simple non-aggregate query used by * 
{@link OpenSearchIndexScanBuilder} internally. */ @VisibleForTesting -class OpenSearchIndexScanQueryBuilder extends TableScanBuilder { +@EqualsAndHashCode +class OpenSearchIndexScanQueryBuilder implements PushDownQueryBuilder { - /** OpenSearch index scan to be optimized. */ - @EqualsAndHashCode.Include - private final OpenSearchIndexScan indexScan; + OpenSearchRequestBuilder requestBuilder; - /** - * Initialize with given index scan and perform push-down optimization later. - * - * @param indexScan index scan not optimized yet - */ - OpenSearchIndexScanQueryBuilder(OpenSearchIndexScan indexScan) { - this.indexScan = indexScan; - } - - @Override - public TableScanOperator build() { - return indexScan; + public OpenSearchIndexScanQueryBuilder(OpenSearchRequestBuilder requestBuilder) { + this.requestBuilder = requestBuilder; } @Override @@ -66,8 +54,8 @@ public boolean pushDownFilter(LogicalFilter filter) { new DefaultExpressionSerializer()); Expression queryCondition = filter.getCondition(); QueryBuilder query = queryBuilder.build(queryCondition); - indexScan.getRequestBuilder().pushDown(query); - indexScan.getRequestBuilder().pushDownTrackedScore( + requestBuilder.pushDownFilter(query); + requestBuilder.pushDownTrackedScore( trackScoresFromOpenSearchFunction(queryCondition)); return true; } @@ -76,7 +64,7 @@ public boolean pushDownFilter(LogicalFilter filter) { public boolean pushDownSort(LogicalSort sort) { List> sortList = sort.getSortList(); final SortQueryBuilder builder = new SortQueryBuilder(); - indexScan.getRequestBuilder().pushDownSort(sortList.stream() + requestBuilder.pushDownSort(sortList.stream() .map(sortItem -> builder.build(sortItem.getValue(), sortItem.getKey())) .collect(Collectors.toList())); return true; @@ -84,13 +72,13 @@ public boolean pushDownSort(LogicalSort sort) { @Override public boolean pushDownLimit(LogicalLimit limit) { - indexScan.getRequestBuilder().pushDownLimit(limit.getLimit(), limit.getOffset()); + 
requestBuilder.pushDownLimit(limit.getLimit(), limit.getOffset()); return true; } @Override public boolean pushDownProject(LogicalProject project) { - indexScan.getRequestBuilder().pushDownProjects( + requestBuilder.pushDownProjects( findReferenceExpressions(project.getProjectList())); // Return false intentionally to keep the original project operator @@ -99,12 +87,18 @@ public boolean pushDownProject(LogicalProject project) { @Override public boolean pushDownHighlight(LogicalHighlight highlight) { - indexScan.getRequestBuilder().pushDownHighlight( + requestBuilder.pushDownHighlight( StringUtils.unquoteText(highlight.getHighlightField().toString()), highlight.getArguments()); return true; } + @Override + public boolean pushDownPageSize(LogicalPaginate paginate) { + requestBuilder.pushDownPageSize(paginate.getPageSize()); + return true; + } + private boolean trackScoresFromOpenSearchFunction(Expression condition) { if (condition instanceof OpenSearchFunctions.OpenSearchFunction && ((OpenSearchFunctions.OpenSearchFunction) condition).isScoreTracked()) { @@ -119,8 +113,8 @@ private boolean trackScoresFromOpenSearchFunction(Expression condition) { @Override public boolean pushDownNested(LogicalNested nested) { - indexScan.getRequestBuilder().pushDownNested(nested.getFields()); - indexScan.getRequestBuilder().pushDownProjects( + requestBuilder.pushDownNested(nested.getFields()); + requestBuilder.pushDownProjects( findReferenceExpressions(nested.getProjectList())); // Return false intentionally to keep the original nested operator // Since we return false we need to pushDownProject here as it won't be @@ -129,11 +123,16 @@ public boolean pushDownNested(LogicalNested nested) { return false; } + @Override + public OpenSearchRequestBuilder build() { + return requestBuilder; + } + /** * Find reference expression from expression. * @param expressions a list of expression. 
* - * @return a list of ReferenceExpression + * @return a set of ReferenceExpression */ public static Set findReferenceExpressions( List expressions) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownQueryBuilder.java new file mode 100644 index 0000000000..274bc4647d --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/PushDownQueryBuilder.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; + +/** + * Translates a logical query plan into OpenSearch DSL and an appropriate request. 
+ */ +public interface PushDownQueryBuilder { + default boolean pushDownFilter(LogicalFilter filter) { + return false; + } + + default boolean pushDownSort(LogicalSort sort) { + return false; + } + + default boolean pushDownLimit(LogicalLimit limit) { + return false; + } + + default boolean pushDownProject(LogicalProject project) { + return false; + } + + default boolean pushDownHighlight(LogicalHighlight highlight) { + return false; + } + + default boolean pushDownPageSize(LogicalPaginate paginate) { + return false; + } + + default boolean pushDownNested(LogicalNested nested) { + return false; + } + + OpenSearchRequestBuilder build(); +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java index 1efa5b65d5..8b1cb08cfa 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java @@ -24,8 +24,6 @@ import org.opensearch.search.aggregations.bucket.missing.MissingOrder; import org.opensearch.search.sort.SortOrder; import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.data.type.ExprCoreType; -import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; import org.opensearch.sql.expression.NamedExpression; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java index 560bd52da9..51b10d2c41 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java +++ 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java @@ -14,18 +14,15 @@ import java.util.Map; import java.util.function.BiFunction; import lombok.RequiredArgsConstructor; -import org.apache.lucene.search.join.ScoreMode; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.ScriptQueryBuilder; import org.opensearch.script.Script; -import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.opensearch.storage.script.filter.lucene.LikeQuery; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexScan.java index eb4cb865e2..ee377263c1 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexScan.java @@ -22,7 +22,7 @@ @ToString(onlyExplicitlyIncluded = true) public class OpenSearchSystemIndexScan extends TableScanOperator { /** - * OpenSearch client. + * OpenSearch request. 
*/ private final OpenSearchSystemRequest request; @@ -33,7 +33,8 @@ public class OpenSearchSystemIndexScan extends TableScanOperator { @Override public void open() { - iterator = request.search().iterator(); + var response = request.search(); + iterator = response.iterator(); } @Override diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java index 8af9a4bbfa..b378fae297 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java @@ -29,13 +29,16 @@ import com.google.common.io.Resources; import java.io.IOException; import java.net.URL; -import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import lombok.SneakyThrows; +import org.apache.commons.lang3.reflect.FieldUtils; import org.apache.lucene.search.TotalHits; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.InOrder; @@ -56,6 +59,7 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.DeprecationHandler; @@ -64,6 +68,7 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; import 
org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; @@ -71,10 +76,12 @@ import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.mapping.IndexMapping; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchNodeClientTest { private static final String TEST_MAPPING_FILE = "mappings/accounts.json"; @@ -89,14 +96,11 @@ class OpenSearchNodeClientTest { @Mock private SearchHit searchHit; - @Mock - private ThreadContext threadContext; - @Mock private GetIndexResponse indexResponse; - private ExprTupleValue exprTupleValue = ExprTupleValue.fromExprValueMap(ImmutableMap.of("id", - new ExprIntegerValue(1))); + private final ExprTupleValue exprTupleValue = ExprTupleValue.fromExprValueMap( + Map.of("id", new ExprIntegerValue(1))); private OpenSearchClient client; @@ -106,7 +110,7 @@ void setUp() { } @Test - void isIndexExist() { + void is_index_exist() { when(nodeClient.admin().indices() .exists(any(IndicesExistsRequest.class)).actionGet()) .thenReturn(new IndicesExistsResponse(true)); @@ -115,7 +119,7 @@ void isIndexExist() { } @Test - void isIndexNotExist() { + void is_index_not_exist() { String indexName = "test"; when(nodeClient.admin().indices() .exists(any(IndicesExistsRequest.class)).actionGet()) @@ -125,14 +129,14 @@ void isIndexNotExist() { } @Test - void isIndexExistWithException() { + void is_index_exist_with_exception() { when(nodeClient.admin().indices().exists(any())).thenThrow(RuntimeException.class); assertThrows(IllegalStateException.class, () -> client.exists("test")); } @Test - 
void createIndex() { + void create_index() { String indexName = "test"; Map mappings = ImmutableMap.of( "properties", @@ -145,7 +149,7 @@ void createIndex() { } @Test - void createIndexWithException() { + void create_index_with_exception() { when(nodeClient.admin().indices().create(any())).thenThrow(RuntimeException.class); assertThrows(IllegalStateException.class, @@ -153,7 +157,7 @@ void createIndexWithException() { } @Test - void getIndexMappings() throws IOException { + void get_index_mappings() throws IOException { URL url = Resources.getResource(TEST_MAPPING_FILE); String mappings = Resources.toString(url, Charsets.UTF_8); String indexName = "test"; @@ -224,7 +228,7 @@ void getIndexMappings() throws IOException { } @Test - void getIndexMappingsWithEmptyMapping() { + void get_index_mappings_with_empty_mapping() { String indexName = "test"; mockNodeClientIndicesMappings(indexName, ""); Map indexMappings = client.getIndexMappings(indexName); @@ -235,7 +239,7 @@ void getIndexMappingsWithEmptyMapping() { } @Test - void getIndexMappingsWithIOException() { + void get_index_mappings_with_IOException() { String indexName = "test"; when(nodeClient.admin().indices()).thenThrow(RuntimeException.class); @@ -243,7 +247,7 @@ void getIndexMappingsWithIOException() { } @Test - void getIndexMappingsWithNonExistIndex() { + void get_index_mappings_with_non_exist_index() { when(nodeClient.admin().indices() .prepareGetMappings(any()) .setLocal(anyBoolean()) @@ -254,7 +258,7 @@ void getIndexMappingsWithNonExistIndex() { } @Test - void getIndexMaxResultWindows() throws IOException { + void get_index_max_result_windows() throws IOException { URL url = Resources.getResource(TEST_MAPPING_SETTINGS_FILE); String indexMetadata = Resources.toString(url, Charsets.UTF_8); String indexName = "accounts"; @@ -268,7 +272,7 @@ void getIndexMaxResultWindows() throws IOException { } @Test - void getIndexMaxResultWindowsWithDefaultSettings() throws IOException { + void 
get_index_max_result_windows_with_default_settings() throws IOException { URL url = Resources.getResource(TEST_MAPPING_FILE); String indexMetadata = Resources.toString(url, Charsets.UTF_8); String indexName = "accounts"; @@ -282,7 +286,7 @@ void getIndexMaxResultWindowsWithDefaultSettings() throws IOException { } @Test - void getIndexMaxResultWindowsWithIOException() { + void get_index_max_result_windows_with_IOException() { String indexName = "test"; when(nodeClient.admin().indices()).thenThrow(RuntimeException.class); @@ -291,7 +295,7 @@ void getIndexMaxResultWindowsWithIOException() { /** Jacoco enforce this constant lambda be tested. */ @Test - void testAllFieldsPredicate() { + void test_all_fields_predicate() { assertTrue(OpenSearchNodeClient.ALL_FIELDS.apply("any_index").test("any_field")); } @@ -314,11 +318,12 @@ void search() { // Mock second scroll request followed SearchResponse scrollResponse = mock(SearchResponse.class); when(nodeClient.searchScroll(any()).actionGet()).thenReturn(scrollResponse); - when(scrollResponse.getScrollId()).thenReturn("scroll456"); when(scrollResponse.getHits()).thenReturn(SearchHits.empty()); // Verify response for first scroll request - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); OpenSearchResponse response1 = client.search(request); assertFalse(response1.isEmpty()); @@ -328,6 +333,7 @@ void search() { assertFalse(hits.hasNext()); // Verify response for second scroll request + request.setScrollId("scroll123"); OpenSearchResponse response2 = client.search(request); assertTrue(response2.isEmpty()); } @@ -343,16 +349,21 @@ void schedule() { } @Test + @SneakyThrows void cleanup() { ClearScrollRequestBuilder requestBuilder = mock(ClearScrollRequestBuilder.class); 
when(nodeClient.prepareClearScroll()).thenReturn(requestBuilder); when(requestBuilder.addScrollId(any())).thenReturn(requestBuilder); when(requestBuilder.get()).thenReturn(null); - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); request.setScrollId("scroll123"); + // Enforce cleaning by setting a private field. + FieldUtils.writeField(request, "needClean", true, true); client.cleanup(request); - assertFalse(request.isScrollStarted()); + assertFalse(request.isScroll()); InOrder inOrder = Mockito.inOrder(nodeClient, requestBuilder); inOrder.verify(nodeClient).prepareClearScroll(); @@ -361,14 +372,30 @@ void cleanup() { } @Test - void cleanupWithoutScrollId() { - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + void cleanup_without_scrollId() { + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); client.cleanup(request); verify(nodeClient, never()).prepareClearScroll(); } @Test - void getIndices() { + @SneakyThrows + void cleanup_rethrows_exception() { + when(nodeClient.prepareClearScroll()).thenThrow(new RuntimeException()); + + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); + request.setScrollId("scroll123"); + // Enforce cleaning by setting a private field. 
+ FieldUtils.writeField(request, "needClean", true, true); + assertThrows(IllegalStateException.class, () -> client.cleanup(request)); + } + + @Test + void get_indices() { AliasMetadata aliasMetadata = mock(AliasMetadata.class); final var openMap = Map.of("index", List.of(aliasMetadata)); when(aliasMetadata.alias()).thenReturn("index_alias"); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java index 141e21c38a..2958fa1100 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java @@ -30,8 +30,12 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import lombok.SneakyThrows; +import org.apache.commons.lang3.reflect.FieldUtils; import org.apache.lucene.search.TotalHits; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -50,12 +54,14 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprTupleValue; import 
org.opensearch.sql.data.model.ExprValue; @@ -63,14 +69,15 @@ import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.mapping.IndexMapping; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchRestClientTest { private static final String TEST_MAPPING_FILE = "mappings/accounts.json"; - @Mock(answer = RETURNS_DEEP_STUBS) private RestHighLevelClient restClient; @@ -85,8 +92,8 @@ class OpenSearchRestClientTest { @Mock private GetIndexResponse getIndexResponse; - private ExprTupleValue exprTupleValue = ExprTupleValue.fromExprValueMap(ImmutableMap.of("id", - new ExprIntegerValue(1))); + private final ExprTupleValue exprTupleValue = ExprTupleValue.fromExprValueMap( + Map.of("id", new ExprIntegerValue(1))); @BeforeEach void setUp() { @@ -94,7 +101,7 @@ void setUp() { } @Test - void isIndexExist() throws IOException { + void is_index_exist() throws IOException { when(restClient.indices() .exists(any(), any())) // use any() because missing equals() in GetIndexRequest .thenReturn(true); @@ -103,7 +110,7 @@ void isIndexExist() throws IOException { } @Test - void isIndexNotExist() throws IOException { + void is_index_not_exist() throws IOException { when(restClient.indices() .exists(any(), any())) // use any() because missing equals() in GetIndexRequest .thenReturn(false); @@ -112,14 +119,14 @@ void isIndexNotExist() throws IOException { } @Test - void isIndexExistWithException() throws IOException { + void is_index_exist_with_exception() throws IOException { when(restClient.indices().exists(any(), any())).thenThrow(IOException.class); assertThrows(IllegalStateException.class, () -> 
client.exists("test")); } @Test - void createIndex() throws IOException { + void create_index() throws IOException { String indexName = "test"; Map mappings = ImmutableMap.of( "properties", @@ -132,7 +139,7 @@ void createIndex() throws IOException { } @Test - void createIndexWithIOException() throws IOException { + void create_index_with_IOException() throws IOException { when(restClient.indices().create(any(), any())).thenThrow(IOException.class); assertThrows(IllegalStateException.class, @@ -140,7 +147,7 @@ void createIndexWithIOException() throws IOException { } @Test - void getIndexMappings() throws IOException { + void get_index_mappings() throws IOException { URL url = Resources.getResource(TEST_MAPPING_FILE); String mappings = Resources.toString(url, Charsets.UTF_8); String indexName = "test"; @@ -215,14 +222,14 @@ void getIndexMappings() throws IOException { } @Test - void getIndexMappingsWithIOException() throws IOException { + void get_index_mappings_with_IOException() throws IOException { when(restClient.indices().getMapping(any(GetMappingsRequest.class), any())) .thenThrow(new IOException()); assertThrows(IllegalStateException.class, () -> client.getIndexMappings("test")); } @Test - void getIndexMaxResultWindowsSettings() throws IOException { + void get_index_max_result_windows_settings() throws IOException { String indexName = "test"; Integer maxResultWindow = 1000; @@ -246,7 +253,7 @@ void getIndexMaxResultWindowsSettings() throws IOException { } @Test - void getIndexMaxResultWindowsDefaultSettings() throws IOException { + void get_index_max_result_windows_default_settings() throws IOException { String indexName = "test"; Integer maxResultWindow = 10000; @@ -270,7 +277,7 @@ void getIndexMaxResultWindowsDefaultSettings() throws IOException { } @Test - void getIndexMaxResultWindowsWithIOException() throws IOException { + void get_index_max_result_windows_with_IOException() throws IOException { 
when(restClient.indices().getSettings(any(GetSettingsRequest.class), any())) .thenThrow(new IOException()); assertThrows(IllegalStateException.class, () -> client.getIndexMaxResultWindows("test")); @@ -295,11 +302,12 @@ void search() throws IOException { // Mock second scroll request followed SearchResponse scrollResponse = mock(SearchResponse.class); when(restClient.scroll(any(), any())).thenReturn(scrollResponse); - when(scrollResponse.getScrollId()).thenReturn("scroll456"); when(scrollResponse.getHits()).thenReturn(SearchHits.empty()); // Verify response for first scroll request - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); OpenSearchResponse response1 = client.search(request); assertFalse(response1.isEmpty()); @@ -309,20 +317,23 @@ void search() throws IOException { assertFalse(hits.hasNext()); // Verify response for second scroll request + request.setScrollId("scroll123"); OpenSearchResponse response2 = client.search(request); assertTrue(response2.isEmpty()); } @Test - void searchWithIOException() throws IOException { + void search_with_IOException() throws IOException { when(restClient.search(any(), any())).thenThrow(new IOException()); assertThrows( IllegalStateException.class, - () -> client.search(new OpenSearchScrollRequest("test", factory))); + () -> client.search(new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory))); } @Test - void scrollWithIOException() throws IOException { + void scroll_with_IOException() throws IOException { // Mock first scroll request SearchResponse searchResponse = mock(SearchResponse.class); when(restClient.search(any(), any())).thenReturn(searchResponse); @@ -338,7 +349,9 @@ void scrollWithIOException() throws IOException { 
when(restClient.scroll(any(), any())).thenThrow(new IOException()); // First request run successfully - OpenSearchScrollRequest scrollRequest = new OpenSearchScrollRequest("test", factory); + OpenSearchScrollRequest scrollRequest = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); client.search(scrollRequest); assertThrows( IllegalStateException.class, () -> client.search(scrollRequest)); @@ -348,39 +361,49 @@ void scrollWithIOException() throws IOException { void schedule() { AtomicBoolean isRun = new AtomicBoolean(false); client.schedule( - () -> { - isRun.set(true); - }); + () -> isRun.set(true)); assertTrue(isRun.get()); } @Test - void cleanup() throws IOException { - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + @SneakyThrows + void cleanup() { + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); + // Enforce cleaning by setting a private field. 
+ FieldUtils.writeField(request, "needClean", true, true); request.setScrollId("scroll123"); client.cleanup(request); verify(restClient).clearScroll(any(), any()); - assertFalse(request.isScrollStarted()); + assertFalse(request.isScroll()); } @Test - void cleanupWithoutScrollId() throws IOException { - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + void cleanup_without_scrollId() throws IOException { + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); client.cleanup(request); verify(restClient, never()).clearScroll(any(), any()); } @Test - void cleanupWithIOException() throws IOException { + @SneakyThrows + void cleanup_with_IOException() { when(restClient.clearScroll(any(), any())).thenThrow(new IOException()); - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); + // Enforce cleaning by setting a private field. 
+ FieldUtils.writeField(request, "needClean", true, true); request.setScrollId("scroll123"); assertThrows(IllegalStateException.class, () -> client.cleanup(request)); } @Test - void getIndices() throws IOException { + void get_indices() throws IOException { when(restClient.indices().get(any(GetIndexRequest.class), any(RequestOptions.class))) .thenReturn(getIndexResponse); when(getIndexResponse.getIndices()).thenReturn(new String[] {"index"}); @@ -390,7 +413,7 @@ void getIndices() throws IOException { } @Test - void getIndicesWithIOException() throws IOException { + void get_indices_with_IOException() throws IOException { when(restClient.indices().get(any(GetIndexRequest.class), any(RequestOptions.class))) .thenThrow(new IOException()); assertThrows(IllegalStateException.class, () -> client.indices()); @@ -409,7 +432,7 @@ void meta() throws IOException { } @Test - void metaWithIOException() throws IOException { + void meta_with_IOException() throws IOException { when(restClient.cluster().getSettings(any(), any(RequestOptions.class))) .thenThrow(new IOException()); @@ -417,7 +440,7 @@ void metaWithIOException() throws IOException { } @Test - void mlWithException() { + void ml_with_exception() { assertThrows(UnsupportedOperationException.class, () -> client.getNodeClient()); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java index 4a0c6e24f1..330793a5d6 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java @@ -17,38 +17,47 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.sql.common.setting.Settings.Key.QUERY_SIZE_LIMIT; +import static 
org.opensearch.sql.common.setting.Settings.Key.SQL_CURSOR_KEEP_ALIVE; import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; import static org.opensearch.sql.executor.ExecutionEngine.QueryResponse; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; import lombok.RequiredArgsConstructor; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.executor.ExecutionContext; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; -import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; +import org.opensearch.sql.planner.SerializablePlan; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.storage.TableScanOperator; import org.opensearch.sql.storage.split.Split; 
@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchExecutionEngineTest { @Mock private OpenSearchClient client; @@ -75,28 +84,29 @@ void setUp() { } @Test - void executeSuccessfully() { + void execute_successfully() { List expected = Arrays.asList( tupleValue(of("name", "John", "age", 20)), tupleValue(of("name", "Allen", "age", 30))); FakePhysicalPlan plan = new FakePhysicalPlan(expected.iterator()); when(protector.protect(plan)).thenReturn(plan); - OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector); + OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector, + new PlanSerializer(null)); List actual = new ArrayList<>(); executor.execute( plan, - new ResponseListener() { - @Override - public void onResponse(QueryResponse response) { - actual.addAll(response.getResults()); - } - - @Override - public void onFailure(Exception e) { - fail("Error occurred during execution", e); - } - }); + new ResponseListener<>() { + @Override + public void onResponse(QueryResponse response) { + actual.addAll(response.getResults()); + } + + @Override + public void onFailure(Exception e) { + fail("Error occurred during execution", e); + } + }); assertTrue(plan.hasOpen); assertEquals(expected, actual); @@ -104,41 +114,80 @@ public void onFailure(Exception e) { } @Test - void executeWithFailure() { + void execute_with_cursor() { + List expected = + Arrays.asList( + tupleValue(of("name", "John", "age", 20)), tupleValue(of("name", "Allen", "age", 30))); + var plan = new FakePhysicalPlan(expected.iterator()); + when(protector.protect(plan)).thenReturn(plan); + + OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector, + new PlanSerializer(null)); + List actual = new ArrayList<>(); + executor.execute( + plan, + new ResponseListener<>() { + @Override + public void onResponse(QueryResponse response) { + 
actual.addAll(response.getResults()); + assertTrue(response.getCursor().toString().startsWith("n:")); + } + + @Override + public void onFailure(Exception e) { + fail("Error occurred during execution", e); + } + }); + + assertEquals(expected, actual); + } + + @Test + void execute_with_failure() { PhysicalPlan plan = mock(PhysicalPlan.class); RuntimeException expected = new RuntimeException("Execution error"); when(plan.hasNext()).thenThrow(expected); when(protector.protect(plan)).thenReturn(plan); - OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector); + OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector, + new PlanSerializer(null)); AtomicReference actual = new AtomicReference<>(); executor.execute( plan, - new ResponseListener() { - @Override - public void onResponse(QueryResponse response) { - fail("Expected error didn't happen"); - } - - @Override - public void onFailure(Exception e) { - actual.set(e); - } - }); + new ResponseListener<>() { + @Override + public void onResponse(QueryResponse response) { + fail("Expected error didn't happen"); + } + + @Override + public void onFailure(Exception e) { + actual.set(e); + } + }); assertEquals(expected, actual.get()); verify(plan).close(); } @Test - void explainSuccessfully() { - OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector); + void explain_successfully() { + OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector, + new PlanSerializer(null)); Settings settings = mock(Settings.class); - when(settings.getSettingValue(QUERY_SIZE_LIMIT)).thenReturn(100); + when(settings.getSettingValue(SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); + + OpenSearchExprValueFactory exprValueFactory = mock(OpenSearchExprValueFactory.class); + final var name = new OpenSearchRequest.IndexName("test"); + final int defaultQuerySize = 100; + final int maxResultWindow = 10000; + final var 
requestBuilder = new OpenSearchRequestBuilder(defaultQuerySize, exprValueFactory); PhysicalPlan plan = new OpenSearchIndexScan(mock(OpenSearchClient.class), - settings, "test", 10000, mock(OpenSearchExprValueFactory.class)); + maxResultWindow, requestBuilder.build(name, maxResultWindow, + settings.getSettingValue(SQL_CURSOR_KEEP_ALIVE))); AtomicReference result = new AtomicReference<>(); - executor.explain(plan, new ResponseListener() { + executor.explain(plan, new ResponseListener<>() { @Override public void onResponse(ExplainResponse response) { result.set(response); @@ -154,13 +203,14 @@ public void onFailure(Exception e) { } @Test - void explainWithFailure() { - OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector); + void explain_with_failure() { + OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector, + new PlanSerializer(null)); PhysicalPlan plan = mock(PhysicalPlan.class); when(plan.accept(any(), any())).thenThrow(IllegalStateException.class); AtomicReference result = new AtomicReference<>(); - executor.explain(plan, new ResponseListener() { + executor.explain(plan, new ResponseListener<>() { @Override public void onResponse(ExplainResponse response) { fail("Should fail as expected"); @@ -176,7 +226,7 @@ public void onFailure(Exception e) { } @Test - void callAddSplitAndOpenInOrder() { + void call_add_split_and_open_in_order() { List expected = Arrays.asList( tupleValue(of("name", "John", "age", 20)), tupleValue(of("name", "Allen", "age", 30))); @@ -184,7 +234,8 @@ void callAddSplitAndOpenInOrder() { when(protector.protect(plan)).thenReturn(plan); when(executionContext.getSplit()).thenReturn(Optional.of(split)); - OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector); + OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector, + new PlanSerializer(null)); List actual = new ArrayList<>(); executor.execute( plan, @@ -208,12 +259,20 @@ 
public void onFailure(Exception e) { } @RequiredArgsConstructor - private static class FakePhysicalPlan extends TableScanOperator { + private static class FakePhysicalPlan extends TableScanOperator implements SerializablePlan { private final Iterator it; private boolean hasOpen; private boolean hasClosed; private boolean hasSplit; + @Override + public void readExternal(ObjectInput in) { + } + + @Override + public void writeExternal(ObjectOutput out) { + } + @Override public void open() { super.open(); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/ResourceMonitorPlanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/ResourceMonitorPlanTest.java index d4d987a7df..96e85a8173 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/ResourceMonitorPlanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/ResourceMonitorPlanTest.java @@ -8,9 +8,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -19,6 +21,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.monitor.ResourceMonitor; import org.opensearch.sql.opensearch.executor.protector.ResourceMonitorPlan; +import org.opensearch.sql.planner.SerializablePlan; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; @@ -107,4 +110,18 @@ void acceptSuccess() { monitorPlan.accept(visitor, context); verify(plan, times(1)).accept(visitor, context); } + + @Test + void getPlanForSerialization() { + plan = mock(PhysicalPlan.class, 
withSettings().extraInterfaces(SerializablePlan.class)); + monitorPlan = new ResourceMonitorPlan(plan, resourceMonitor); + assertEquals(plan, monitorPlan.getPlanForSerialization()); + } + + @Test + void notSerializable() { + // ResourceMonitorPlan shouldn't be serialized, attempt should throw an exception + assertThrows(UnsupportedOperationException.class, () -> monitorPlan.writeExternal(null)); + assertThrows(UnsupportedOperationException.class, () -> monitorPlan.readExternal(null)); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index f1fcaf677f..fd5e747b5f 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -8,8 +8,10 @@ import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; @@ -24,7 +26,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -32,6 +33,8 @@ import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import 
org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -58,13 +61,17 @@ import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.planner.physical.MLOperator; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; -import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; +import org.opensearch.sql.planner.physical.CursorCloseOperator; import org.opensearch.sql.planner.physical.NestedOperator; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchExecutionProtectorTest { @Mock @@ -87,21 +94,20 @@ public void setup() { } @Test - public void testProtectIndexScan() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - + void test_protect_indexScan() { String indexName = "test"; - Integer maxResultWindow = 10000; + final int maxResultWindow = 10000; + final int querySizeLimit = 200; NamedExpression include = named("age", ref("age", INTEGER)); ReferenceExpression exclude = ref("name", STRING); ReferenceExpression dedupeField = ref("name", STRING); ReferenceExpression topField = ref("name", STRING); - List topExprs = Arrays.asList(ref("age", INTEGER)); + List topExprs = List.of(ref("age", INTEGER)); Expression filterExpr = literal(ExprBooleanValue.of(true)); - List groupByExprs = Arrays.asList(named("age", ref("age", INTEGER))); + List groupByExprs = List.of(named("age", ref("age", INTEGER))); List aggregators = - 
Arrays.asList(named("avg(age)", new AvgAggregator(Arrays.asList(ref("age", INTEGER)), - DOUBLE))); + List.of(named("avg(age)", new AvgAggregator(List.of(ref("age", INTEGER)), + DOUBLE))); Map mappings = ImmutableMap.of(ref("name", STRING), ref("lastname", STRING)); Pair newEvalField = @@ -111,6 +117,10 @@ public void testProtectIndexScan() { Integer limit = 10; Integer offset = 10; + final var name = new OpenSearchRequest.IndexName(indexName); + final var request = new OpenSearchRequestBuilder(querySizeLimit, exprValueFactory) + .build(name, maxResultWindow, + settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)); assertEquals( PhysicalPlanDSL.project( PhysicalPlanDSL.limit( @@ -124,9 +134,8 @@ public void testProtectIndexScan() { PhysicalPlanDSL.agg( filter( resourceMonitor( - new OpenSearchIndexScan( - client, settings, indexName, - maxResultWindow, exprValueFactory)), + new OpenSearchIndexScan(client, + maxResultWindow, request)), filterExpr), aggregators, groupByExprs), @@ -152,9 +161,8 @@ public void testProtectIndexScan() { PhysicalPlanDSL.rename( PhysicalPlanDSL.agg( filter( - new OpenSearchIndexScan( - client, settings, indexName, - maxResultWindow, exprValueFactory), + new OpenSearchIndexScan(client, + maxResultWindow, request), filterExpr), aggregators, groupByExprs), @@ -173,7 +181,7 @@ public void testProtectIndexScan() { @SuppressWarnings("unchecked") @Test - public void testProtectSortForWindowOperator() { + void test_protect_sort_for_windowOperator() { NamedExpression rank = named(mock(RankFunction.class)); Pair sortItem = ImmutablePair.of(DEFAULT_ASC, DSL.ref("age", INTEGER)); @@ -199,7 +207,7 @@ public void testProtectSortForWindowOperator() { } @Test - public void testProtectWindowOperatorInput() { + void test_protect_windowOperator_input() { NamedExpression avg = named(mock(AggregateWindowFunction.class)); WindowDefinition windowDefinition = mock(WindowDefinition.class); @@ -218,7 +226,7 @@ public void testProtectWindowOperatorInput() { 
@SuppressWarnings("unchecked") @Test - public void testNotProtectWindowOperatorInputIfAlreadyProtected() { + void test_not_protect_windowOperator_input_if_already_protected() { NamedExpression avg = named(mock(AggregateWindowFunction.class)); Pair sortItem = ImmutablePair.of(DEFAULT_ASC, DSL.ref("age", INTEGER)); @@ -243,7 +251,7 @@ public void testNotProtectWindowOperatorInputIfAlreadyProtected() { } @Test - public void testWithoutProtection() { + void test_without_protection() { Expression filterExpr = literal(ExprBooleanValue.of(true)); assertEquals( @@ -259,7 +267,7 @@ public void testWithoutProtection() { } @Test - public void testVisitMlCommons() { + void test_visitMLcommons() { NodeClient nodeClient = mock(NodeClient.class); MLCommonsOperator mlCommonsOperator = new MLCommonsOperator( @@ -277,7 +285,7 @@ public void testVisitMlCommons() { } @Test - public void testVisitAD() { + void test_visitAD() { NodeClient nodeClient = mock(NodeClient.class); ADOperator adOperator = new ADOperator( @@ -295,7 +303,7 @@ public void testVisitAD() { } @Test - public void testVisitML() { + void test_visitML() { NodeClient nodeClient = mock(NodeClient.class); MLOperator mlOperator = new MLOperator( @@ -315,7 +323,7 @@ public void testVisitML() { } @Test - public void testVisitNested() { + void test_visitNested() { Set args = Set.of("message.info"); Map> groupedFieldsByPath = Map.of("message", List.of("message.info")); @@ -329,6 +337,14 @@ public void testVisitNested() { executionProtector.visitNested(nestedOperator, values(emptyList()))); } + @Test + void do_nothing_with_CursorCloseOperator_and_children() { + var child = mock(PhysicalPlan.class); + var plan = new CursorCloseOperator(child); + assertSame(plan, executionProtector.protect(plan)); + verify(child, never()).accept(executionProtector, null); + } + PhysicalPlan resourceMonitor(PhysicalPlan input) { return new ResourceMonitorPlan(input, resourceMonitor); } diff --git 
a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java index 2e1ded6322..e188bd7c5c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java @@ -8,16 +8,19 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.sql.opensearch.request.OpenSearchRequest.DEFAULT_QUERY_TIMEOUT; -import java.util.Iterator; import java.util.function.Consumer; import java.util.function.Function; +import org.apache.lucene.search.TotalHits; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -25,12 +28,12 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; -import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @@ -104,10 +107,10 @@ void 
search_withoutContext() { when(searchAction.apply(any())).thenReturn(searchResponse); when(searchResponse.getHits()).thenReturn(searchHits); when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); - OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); verify(sourceBuilder, times(1)).fetchSource(); assertFalse(searchResponse.isEmpty()); + assertFalse(request.hasAnotherBatch()); } @Test @@ -145,29 +148,45 @@ void clean() { void searchRequest() { request.getSourceBuilder().query(QueryBuilders.termQuery("name", "John")); - assertEquals( - new SearchRequest() - .indices("test") - .source(new SearchSourceBuilder() - .timeout(OpenSearchQueryRequest.DEFAULT_QUERY_TIMEOUT) - .from(0) - .size(200) - .query(QueryBuilders.termQuery("name", "John"))), - request.searchRequest()); + assertSearchRequest(new SearchRequest() + .indices("test") + .source(new SearchSourceBuilder() + .timeout(DEFAULT_QUERY_TIMEOUT) + .from(0) + .size(200) + .query(QueryBuilders.termQuery("name", "John"))), + request); } @Test void searchCrossClusterRequest() { remoteRequest.getSourceBuilder().query(QueryBuilders.termQuery("name", "John")); - assertEquals( + assertSearchRequest( new SearchRequest() .indices("ccs:test") .source(new SearchSourceBuilder() - .timeout(OpenSearchQueryRequest.DEFAULT_QUERY_TIMEOUT) + .timeout(DEFAULT_QUERY_TIMEOUT) .from(0) .size(200) .query(QueryBuilders.termQuery("name", "John"))), - remoteRequest.searchRequest()); + remoteRequest); + } + + @Test + void writeTo_unsupported() { + assertThrows(UnsupportedOperationException.class, + () -> request.writeTo(mock(StreamOutput.class))); + } + + private void assertSearchRequest(SearchRequest expected, OpenSearchQueryRequest request) { + Function querySearch = searchRequest -> { + assertEquals(expected, searchRequest); + return when(mock(SearchResponse.class).getHits()) + .thenReturn(new SearchHits(new SearchHit[0], + new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f)) + .getMock(); + }; + 
request.search(querySearch, searchScrollRequest -> null); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java index 4685609f58..e8d15bd0bb 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java @@ -6,7 +6,9 @@ package org.opensearch.sql.opensearch.request; +import static org.junit.Assert.assertThrows; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.index.query.QueryBuilders.matchAllQuery; @@ -20,18 +22,27 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Function; import org.apache.commons.lang3.tuple.Pair; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.join.ScoreMode; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchScrollRequest; import org.opensearch.common.unit.TimeValue; import org.opensearch.index.query.InnerHitBuilder; import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.AggregationBuilder; 
import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; @@ -40,7 +51,7 @@ import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.ScoreSortBuilder; import org.opensearch.search.sort.SortBuilders; -import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; @@ -52,15 +63,16 @@ import org.opensearch.sql.planner.logical.LogicalNested; @ExtendWith(MockitoExtension.class) -public class OpenSearchRequestBuilderTest { +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +class OpenSearchRequestBuilderTest { private static final TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); private static final Integer DEFAULT_OFFSET = 0; private static final Integer DEFAULT_LIMIT = 200; private static final Integer MAX_RESULT_WINDOW = 500; - @Mock - private Settings settings; + private static final OpenSearchRequest.IndexName indexName + = new OpenSearchRequest.IndexName("test"); @Mock private OpenSearchExprValueFactory exprValueFactory; @@ -69,14 +81,11 @@ public class OpenSearchRequestBuilderTest { @BeforeEach void setup() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - - requestBuilder = new OpenSearchRequestBuilder( - "test", MAX_RESULT_WINDOW, settings, exprValueFactory); + requestBuilder = new OpenSearchRequestBuilder(DEFAULT_LIMIT, exprValueFactory); } @Test - void buildQueryRequest() { + void build_query_request() { Integer limit = 200; Integer offset = 0; requestBuilder.pushDownLimit(limit, offset); @@ -91,44 +100,53 @@ void buildQueryRequest() { .timeout(DEFAULT_QUERY_TIMEOUT) .trackScores(true), exprValueFactory), - requestBuilder.build()); + requestBuilder.build(indexName, 
MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT)); } @Test - void buildScrollRequestWithCorrectSize() { + void build_scroll_request_with_correct_size() { Integer limit = 800; Integer offset = 10; requestBuilder.pushDownLimit(limit, offset); assertEquals( new OpenSearchScrollRequest( - new OpenSearchRequest.IndexName("test"), + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), new SearchSourceBuilder() .from(offset) .size(MAX_RESULT_WINDOW - offset) .timeout(DEFAULT_QUERY_TIMEOUT), exprValueFactory), - requestBuilder.build()); + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT)); } @Test - void testPushDownQuery() { + void test_push_down_query() { QueryBuilder query = QueryBuilders.termQuery("intA", 1); - requestBuilder.pushDown(query); + requestBuilder.pushDownFilter(query); - assertEquals( + var r = requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT); + Function querySearch = searchRequest -> { + assertEquals( new SearchSourceBuilder() - .from(DEFAULT_OFFSET) - .size(DEFAULT_LIMIT) - .timeout(DEFAULT_QUERY_TIMEOUT) - .query(query) - .sort(DOC_FIELD_NAME, ASC), - requestBuilder.getSourceBuilder() - ); + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .query(query) + .sort(DOC_FIELD_NAME, ASC), + searchRequest.source() + ); + return mock(); + }; + Function scrollSearch = searchScrollRequest -> { + throw new UnsupportedOperationException(); + }; + r.search(querySearch, scrollSearch); + } @Test - void testPushDownAggregation() { + void test_push_down_aggregation() { AggregationBuilder aggBuilder = AggregationBuilders.composite( "composite_buckets", Collections.singletonList(new TermsValuesSourceBuilder("longA"))); @@ -149,83 +167,100 @@ void testPushDownAggregation() { } @Test - void testPushDownQueryAndSort() { + void test_push_down_query_and_sort() { QueryBuilder query = QueryBuilders.termQuery("intA", 1); - requestBuilder.pushDown(query); + 
requestBuilder.pushDownFilter(query); FieldSortBuilder sortBuilder = SortBuilders.fieldSort("intA"); requestBuilder.pushDownSort(List.of(sortBuilder)); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT) .query(query) .sort(sortBuilder), - requestBuilder.getSourceBuilder()); + requestBuilder); + } + + void assertSearchSourceBuilder(SearchSourceBuilder expected, + OpenSearchRequestBuilder requestBuilder) + throws UnsupportedOperationException { + Function querySearch = searchRequest -> { + assertEquals(expected, searchRequest.source()); + return when(mock(SearchResponse.class).getHits()) + .thenReturn(new SearchHits(new SearchHit[0], new TotalHits(0, + TotalHits.Relation.EQUAL_TO), 0.0f)) + .getMock(); + }; + Function scrollSearch = searchScrollRequest -> { + throw new UnsupportedOperationException(); + }; + requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT).search( + querySearch, scrollSearch); } @Test - void testPushDownSort() { + void test_push_down_sort() { FieldSortBuilder sortBuilder = SortBuilders.fieldSort("intA"); requestBuilder.pushDownSort(List.of(sortBuilder)); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT) .sort(sortBuilder), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test - void testPushDownNonFieldSort() { + void test_push_down_non_field_sort() { ScoreSortBuilder sortBuilder = SortBuilders.scoreSort(); requestBuilder.pushDownSort(List.of(sortBuilder)); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT) .sort(sortBuilder), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test - void testPushDownMultipleSort() { + void test_push_down_multiple_sort() { requestBuilder.pushDownSort(List.of( SortBuilders.fieldSort("intA"), 
SortBuilders.fieldSort("intB"))); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT) .sort(SortBuilders.fieldSort("intA")) .sort(SortBuilders.fieldSort("intB")), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test - void testPushDownProject() { + void test_push_down_project() { Set references = Set.of(DSL.ref("intA", INTEGER)); requestBuilder.pushDownProjects(references); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT) .fetchSource(new String[]{"intA"}, new String[0]), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test - void testPushDownNested() { + void test_push_down_nested() { List> args = List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), @@ -245,17 +280,17 @@ void testPushDownNested() { .innerHit(new InnerHitBuilder().setFetchSourceContext( new FetchSourceContext(true, new String[]{"message.info"}, null))); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .query(QueryBuilders.boolQuery().filter(QueryBuilders.boolQuery().must(nestedQuery))) .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test - void testPushDownMultipleNestedWithSamePath() { + void test_push_down_multiple_nested_with_same_path() { List> args = List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), @@ -278,17 +313,17 @@ void testPushDownMultipleNestedWithSamePath() { NestedQueryBuilder nestedQuery = nestedQuery("message", matchAllQuery(), ScoreMode.None) .innerHit(new InnerHitBuilder().setFetchSourceContext( new FetchSourceContext(true, new String[]{"message.info", "message.from"}, null))); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() 
.query(QueryBuilders.boolQuery().filter(QueryBuilders.boolQuery().must(nestedQuery))) .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test - void testPushDownNestedWithFilter() { + void test_push_down_nested_with_filter() { List> args = List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), @@ -309,7 +344,7 @@ void testPushDownNestedWithFilter() { .innerHit(new InnerHitBuilder().setFetchSourceContext( new FetchSourceContext(true, new String[]{"message.info"}, null))); - assertEquals( + assertSearchSourceBuilder( new SearchSourceBuilder() .query( QueryBuilders.boolQuery().filter( @@ -321,7 +356,7 @@ void testPushDownNestedWithFilter() { .from(DEFAULT_OFFSET) .size(DEFAULT_LIMIT) .timeout(DEFAULT_QUERY_TIMEOUT), - requestBuilder.getSourceBuilder()); + requestBuilder); } @Test @@ -349,25 +384,62 @@ void testPushDownNestedWithNestedFilter() { .innerHit(new InnerHitBuilder().setFetchSourceContext( new FetchSourceContext(true, new String[]{"message.info"}, null))); - assertEquals( - new SearchSourceBuilder() - .query( - QueryBuilders.boolQuery().filter( - QueryBuilders.boolQuery() - .must(filterQuery) - ) - ) - .from(DEFAULT_OFFSET) - .size(DEFAULT_LIMIT) - .timeout(DEFAULT_QUERY_TIMEOUT), - requestBuilder.getSourceBuilder()); + assertSearchSourceBuilder(new SearchSourceBuilder() + .query( + QueryBuilders.boolQuery().filter( + QueryBuilders.boolQuery() + .must(filterQuery) + ) + ) + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT), requestBuilder); } @Test - void testPushTypeMapping() { + void test_push_type_mapping() { Map typeMapping = Map.of("intA", OpenSearchDataType.of(INTEGER)); requestBuilder.pushTypeMapping(typeMapping); verify(exprValueFactory).extendTypeMapping(typeMapping); } + + @Test + void push_down_highlight_with_repeating_fields() { + requestBuilder.pushDownHighlight("name", Map.of()); + var exception = 
assertThrows(SemanticCheckException.class, () -> + requestBuilder.pushDownHighlight("name", Map.of())); + assertEquals("Duplicate field name in highlight", exception.getMessage()); + } + + @Test + void push_down_page_size() { + requestBuilder.pushDownPageSize(3); + assertSearchSourceBuilder( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(3) + .timeout(DEFAULT_QUERY_TIMEOUT), + requestBuilder); + } + + @Test + void exception_when_non_zero_offset_and_page_size() { + requestBuilder.pushDownPageSize(3); + requestBuilder.pushDownLimit(300, 2); + assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.build(indexName, MAX_RESULT_WINDOW, DEFAULT_QUERY_TIMEOUT)); + } + + @Test + void maxResponseSize_is_page_size() { + requestBuilder.pushDownPageSize(4); + assertEquals(4, requestBuilder.getMaxResponseSize()); + } + + @Test + void maxResponseSize_is_limit() { + requestBuilder.pushDownLimit(100, 0); + assertEquals(100, requestBuilder.getMaxResponseSize()); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java index b3c049ce03..69f38ee7f2 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java @@ -8,13 +8,25 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static 
org.mockito.Mockito.when; +import static org.opensearch.sql.opensearch.request.OpenSearchScrollRequest.NO_SCROLL_ID; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; +import lombok.SneakyThrows; +import org.apache.commons.lang3.reflect.FieldUtils; +import org.apache.lucene.search.TotalHits; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -22,17 +34,25 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.unit.TimeValue; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchScrollRequestTest { + public static final OpenSearchRequest.IndexName INDEX_NAME + = new OpenSearchRequest.IndexName("test"); + public static final TimeValue SCROLL_TIMEOUT = TimeValue.timeValueMinutes(1); @Mock private Function searchAction; @@ -51,32 +71,55 @@ class OpenSearchScrollRequestTest { @Mock private SearchSourceBuilder sourceBuilder; - @Mock - private FetchSourceContext fetchSourceContext; @Mock private 
OpenSearchExprValueFactory factory; - private final OpenSearchScrollRequest request = - new OpenSearchScrollRequest("test", factory); + private final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + private final OpenSearchScrollRequest request = new OpenSearchScrollRequest( + INDEX_NAME, SCROLL_TIMEOUT, + searchSourceBuilder, factory); @Test - void searchRequest() { - request.getSourceBuilder().query(QueryBuilders.termQuery("name", "John")); + void constructor() { + searchSourceBuilder.fetchSource(new String[] {"test"}, null); + var request = new OpenSearchScrollRequest(INDEX_NAME, SCROLL_TIMEOUT, + searchSourceBuilder, factory); + assertNotEquals(List.of(), request.getIncludes()); + } - assertEquals( + @Test + void constructor2() { + searchSourceBuilder.fetchSource(new String[]{"test"}, null); + var request = new OpenSearchScrollRequest(INDEX_NAME, SCROLL_TIMEOUT, searchSourceBuilder, + factory); + assertNotEquals(List.of(), request.getIncludes()); + } + + @Test + void searchRequest() { + searchSourceBuilder.query(QueryBuilders.termQuery("name", "John")); + request.search(searchRequest -> { + assertEquals( new SearchRequest() - .indices("test") - .scroll(OpenSearchScrollRequest.DEFAULT_SCROLL_TIMEOUT) - .source(new SearchSourceBuilder().query(QueryBuilders.termQuery("name", "John"))), - request.searchRequest()); + .indices("test") + .scroll(TimeValue.timeValueMinutes(1)) + .source(new SearchSourceBuilder().query(QueryBuilders.termQuery("name", "John"))), + searchRequest); + SearchHits searchHitsMock = when(mock(SearchHits.class).getHits()) + .thenReturn(new SearchHit[0]).getMock(); + return when(mock(SearchResponse.class).getHits()).thenReturn(searchHitsMock).getMock(); + }, searchScrollRequest -> null); } @Test void isScrollStarted() { - assertFalse(request.isScrollStarted()); + assertFalse(request.isScroll()); request.setScrollId("scroll123"); - assertTrue(request.isScrollStarted()); + assertTrue(request.isScroll()); + + request.reset(); + 
assertFalse(request.isScroll()); } @Test @@ -84,7 +127,7 @@ void scrollRequest() { request.setScrollId("scroll123"); assertEquals( new SearchScrollRequest() - .scroll(OpenSearchScrollRequest.DEFAULT_SCROLL_TIMEOUT) + .scroll(TimeValue.timeValueMinutes(1)) .scrollId("scroll123"), request.scrollRequest()); } @@ -93,31 +136,32 @@ void scrollRequest() { void search() { OpenSearchScrollRequest request = new OpenSearchScrollRequest( new OpenSearchRequest.IndexName("test"), + TimeValue.timeValueMinutes(1), sourceBuilder, factory ); - String[] includes = {"_id", "_index"}; - when(sourceBuilder.fetchSource()).thenReturn(fetchSourceContext); - when(fetchSourceContext.includes()).thenReturn(includes); - when(searchAction.apply(any())).thenReturn(searchResponse); when(searchResponse.getHits()).thenReturn(searchHits); when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); - OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); - verify(fetchSourceContext, times(2)).includes(); - assertFalse(searchResponse.isEmpty()); + Function scrollSearch = searchScrollRequest -> { + throw new AssertionError(); + }; + OpenSearchResponse openSearchResponse = request.search(searchRequest -> searchResponse, + scrollSearch); + + assertFalse(openSearchResponse.isEmpty()); } @Test void search_withoutContext() { OpenSearchScrollRequest request = new OpenSearchScrollRequest( new OpenSearchRequest.IndexName("test"), + TimeValue.timeValueMinutes(1), sourceBuilder, factory ); - when(sourceBuilder.fetchSource()).thenReturn(null); when(searchAction.apply(any())).thenReturn(searchResponse); when(searchResponse.getHits()).thenReturn(searchHits); when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); @@ -131,18 +175,152 @@ void search_withoutContext() { void search_withoutIncludes() { OpenSearchScrollRequest request = new OpenSearchScrollRequest( new OpenSearchRequest.IndexName("test"), + TimeValue.timeValueMinutes(1), sourceBuilder, factory ); - 
when(sourceBuilder.fetchSource()).thenReturn(fetchSourceContext); - when(fetchSourceContext.includes()).thenReturn(null); when(searchAction.apply(any())).thenReturn(searchResponse); when(searchResponse.getHits()).thenReturn(searchHits); when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); - verify(fetchSourceContext, times(1)).includes(); assertFalse(searchResponse.isEmpty()); } + + @Test + @SneakyThrows + void hasAnotherBatch() { + FieldUtils.writeField(request, "needClean", false, true); + request.setScrollId("scroll123"); + assertTrue(request.hasAnotherBatch()); + + request.reset(); + assertFalse(request.hasAnotherBatch()); + + request.setScrollId(""); + assertFalse(request.hasAnotherBatch()); + } + + @Test + void clean_on_empty_response() { + // This could happen on sequential search calls + SearchResponse searchResponse = mock(); + when(searchResponse.getScrollId()).thenReturn("scroll1", "scroll2"); + when(searchResponse.getHits()).thenReturn( + new SearchHits(new SearchHit[1], new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1F), + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1F)); + + request.search((x) -> searchResponse, (x) -> searchResponse); + assertEquals("scroll1", request.getScrollId()); + request.search((x) -> searchResponse, (x) -> searchResponse); + assertEquals("scroll1", request.getScrollId()); + + AtomicBoolean cleanCalled = new AtomicBoolean(false); + request.clean((s) -> cleanCalled.set(true)); + + assertEquals(NO_SCROLL_ID, request.getScrollId()); + assertTrue(cleanCalled.get()); + } + + @Test + void no_clean_on_non_empty_response() { + SearchResponse searchResponse = mock(); + when(searchResponse.getScrollId()).thenReturn("scroll"); + when(searchResponse.getHits()).thenReturn( + new SearchHits(new SearchHit[1], new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1F)); + + request.search((x) -> searchResponse, (x) -> 
searchResponse); + assertEquals("scroll", request.getScrollId()); + + request.clean((s) -> fail()); + assertEquals(NO_SCROLL_ID, request.getScrollId()); + } + + @Test + void no_cursor_on_empty_response() { + SearchResponse searchResponse = mock(); + when(searchResponse.getHits()).thenReturn( + new SearchHits(new SearchHit[0], null, 1f)); + + request.search((x) -> searchResponse, (x) -> searchResponse); + assertFalse(request.hasAnotherBatch()); + } + + @Test + void no_clean_if_no_scroll_in_response() { + SearchResponse searchResponse = mock(); + when(searchResponse.getHits()).thenReturn( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1F)); + + request.search((x) -> searchResponse, (x) -> searchResponse); + assertEquals(NO_SCROLL_ID, request.getScrollId()); + + request.clean((s) -> fail()); + } + + @Test + @SneakyThrows + void serialize_deserialize_no_needClean() { + var stream = new BytesStreamOutput(); + request.writeTo(stream); + stream.flush(); + assertTrue(stream.size() > 0); + + // deserialize + var inStream = new BytesStreamInput(stream.bytes().toBytesRef().bytes); + var indexMock = mock(OpenSearchIndex.class); + var engine = mock(OpenSearchStorageEngine.class); + when(engine.getTable(any(), any())).thenReturn(indexMock); + var newRequest = new OpenSearchScrollRequest(inStream, engine); + assertEquals(request.getInitialSearchRequest(), newRequest.getInitialSearchRequest()); + assertEquals("", newRequest.getScrollId()); + } + + @Test + @SneakyThrows + void serialize_deserialize_needClean() { + lenient().when(searchResponse.getHits()).thenReturn( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1F)); + lenient().when(searchResponse.getScrollId()).thenReturn(""); + + var stream = new BytesStreamOutput(); + request.search(searchRequest -> searchResponse, null); + request.writeTo(stream); + stream.flush(); + assertTrue(stream.size() > 0); + + // deserialize + var inStream = new 
BytesStreamInput(stream.bytes().toBytesRef().bytes); + var indexMock = mock(OpenSearchIndex.class); + var engine = mock(OpenSearchStorageEngine.class); + when(engine.getTable(any(), any())).thenReturn(indexMock); + var newRequest = new OpenSearchScrollRequest(inStream, engine); + assertEquals(request.getInitialSearchRequest(), newRequest.getInitialSearchRequest()); + assertEquals("", newRequest.getScrollId()); + } + + @Test + void setScrollId() { + request.setScrollId("test"); + assertEquals("test", request.getScrollId()); + } + + @Test + void includes() { + + assertIncludes(List.of(), searchSourceBuilder); + + searchSourceBuilder.fetchSource((String[])null, (String[])null); + assertIncludes(List.of(), searchSourceBuilder); + + searchSourceBuilder.fetchSource(new String[] {"test"}, null); + assertIncludes(List.of("test"), searchSourceBuilder); + + } + + void assertIncludes(List expected, SearchSourceBuilder sourceBuilder) { + assertEquals(expected, new OpenSearchScrollRequest( + INDEX_NAME, SCROLL_TIMEOUT, sourceBuilder, factory).getIncludes()); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java index 65568cf5f1..079a82b783 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java @@ -80,20 +80,25 @@ void isEmpty() { new TotalHits(2L, TotalHits.Relation.EQUAL_TO), 1.0F)); - assertFalse(new OpenSearchResponse(searchResponse, factory, includes).isEmpty()); + var response = new OpenSearchResponse(searchResponse, factory, includes); + assertFalse(response.isEmpty()); when(searchResponse.getHits()).thenReturn(SearchHits.empty()); when(searchResponse.getAggregations()).thenReturn(null); - assertTrue(new OpenSearchResponse(searchResponse, factory, includes).isEmpty()); + + 
response = new OpenSearchResponse(searchResponse, factory, includes); + assertTrue(response.isEmpty()); when(searchResponse.getHits()) .thenReturn(new SearchHits(null, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0)); - OpenSearchResponse response3 = new OpenSearchResponse(searchResponse, factory, includes); - assertTrue(response3.isEmpty()); + response = new OpenSearchResponse(searchResponse, factory, includes); + assertTrue(response.isEmpty()); when(searchResponse.getHits()).thenReturn(SearchHits.empty()); when(searchResponse.getAggregations()).thenReturn(new Aggregations(emptyList())); - assertFalse(new OpenSearchResponse(searchResponse, factory, includes).isEmpty()); + + response = new OpenSearchResponse(searchResponse, factory, includes); + assertFalse(response.isEmpty()); } @Test diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java deleted file mode 100644 index 8aec6a7d13..0000000000 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java +++ /dev/null @@ -1,327 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.storage; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.search.sort.FieldSortBuilder.DOC_FIELD_NAME; -import static org.opensearch.search.sort.SortOrder.ASC; -import static org.opensearch.sql.data.type.ExprCoreType.STRING; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import org.junit.jupiter.api.BeforeEach; 
-import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.mockito.stubbing.Answer; -import org.opensearch.common.bytes.BytesArray; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.search.SearchHit; -import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; -import org.opensearch.sql.ast.expression.DataType; -import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.data.model.ExprValueUtils; -import org.opensearch.sql.exception.SemanticCheckException; -import org.opensearch.sql.opensearch.client.OpenSearchClient; -import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; -import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; -import org.opensearch.sql.opensearch.request.OpenSearchRequest; -import org.opensearch.sql.opensearch.response.OpenSearchResponse; - -@ExtendWith(MockitoExtension.class) -class OpenSearchIndexScanTest { - - @Mock - private OpenSearchClient client; - - @Mock - private Settings settings; - - private OpenSearchExprValueFactory exprValueFactory = new OpenSearchExprValueFactory( - Map.of("name", OpenSearchDataType.of(STRING), - "department", OpenSearchDataType.of(STRING))); - - @BeforeEach - void setup() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - } - - @Test - void queryEmptyResult() { - mockResponse(); - try (OpenSearchIndexScan indexScan = - new OpenSearchIndexScan(client, settings, "test", 3, exprValueFactory)) { - indexScan.open(); - assertFalse(indexScan.hasNext()); - } - verify(client).cleanup(any()); - } - - @Test - void 
queryAllResultsWithQuery() { - mockResponse(new ExprValue[]{ - employee(1, "John", "IT"), - employee(2, "Smith", "HR"), - employee(3, "Allen", "IT")}); - - try (OpenSearchIndexScan indexScan = - new OpenSearchIndexScan(client, settings, "employees", 10, exprValueFactory)) { - indexScan.open(); - - assertTrue(indexScan.hasNext()); - assertEquals(employee(1, "John", "IT"), indexScan.next()); - - assertTrue(indexScan.hasNext()); - assertEquals(employee(2, "Smith", "HR"), indexScan.next()); - - assertTrue(indexScan.hasNext()); - assertEquals(employee(3, "Allen", "IT"), indexScan.next()); - - assertFalse(indexScan.hasNext()); - } - verify(client).cleanup(any()); - } - - @Test - void queryAllResultsWithScroll() { - mockResponse( - new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, - new ExprValue[]{employee(3, "Allen", "IT")}); - - try (OpenSearchIndexScan indexScan = - new OpenSearchIndexScan(client, settings, "employees", 2, exprValueFactory)) { - indexScan.open(); - - assertTrue(indexScan.hasNext()); - assertEquals(employee(1, "John", "IT"), indexScan.next()); - - assertTrue(indexScan.hasNext()); - assertEquals(employee(2, "Smith", "HR"), indexScan.next()); - - assertTrue(indexScan.hasNext()); - assertEquals(employee(3, "Allen", "IT"), indexScan.next()); - - assertFalse(indexScan.hasNext()); - } - verify(client).cleanup(any()); - } - - @Test - void querySomeResultsWithQuery() { - mockResponse(new ExprValue[]{ - employee(1, "John", "IT"), - employee(2, "Smith", "HR"), - employee(3, "Allen", "IT"), - employee(4, "Bob", "HR")}); - - try (OpenSearchIndexScan indexScan = - new OpenSearchIndexScan(client, settings, "employees", 10, exprValueFactory)) { - indexScan.getRequestBuilder().pushDownLimit(3, 0); - indexScan.open(); - - assertTrue(indexScan.hasNext()); - assertEquals(employee(1, "John", "IT"), indexScan.next()); - - assertTrue(indexScan.hasNext()); - assertEquals(employee(2, "Smith", "HR"), indexScan.next()); - - assertTrue(indexScan.hasNext()); 
- assertEquals(employee(3, "Allen", "IT"), indexScan.next()); - - assertFalse(indexScan.hasNext()); - } - verify(client).cleanup(any()); - } - - @Test - void querySomeResultsWithScroll() { - mockResponse( - new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, - new ExprValue[]{employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); - - try (OpenSearchIndexScan indexScan = - new OpenSearchIndexScan(client, settings, "employees", 2, exprValueFactory)) { - indexScan.getRequestBuilder().pushDownLimit(3, 0); - indexScan.open(); - - assertTrue(indexScan.hasNext()); - assertEquals(employee(1, "John", "IT"), indexScan.next()); - - assertTrue(indexScan.hasNext()); - assertEquals(employee(2, "Smith", "HR"), indexScan.next()); - - assertTrue(indexScan.hasNext()); - assertEquals(employee(3, "Allen", "IT"), indexScan.next()); - - assertFalse(indexScan.hasNext()); - } - verify(client).cleanup(any()); - } - - @Test - void pushDownFilters() { - assertThat() - .pushDown(QueryBuilders.termQuery("name", "John")) - .shouldQuery(QueryBuilders.termQuery("name", "John")) - .pushDown(QueryBuilders.termQuery("age", 30)) - .shouldQuery( - QueryBuilders.boolQuery() - .filter(QueryBuilders.termQuery("name", "John")) - .filter(QueryBuilders.termQuery("age", 30))) - .pushDown(QueryBuilders.rangeQuery("balance").gte(10000)) - .shouldQuery( - QueryBuilders.boolQuery() - .filter(QueryBuilders.termQuery("name", "John")) - .filter(QueryBuilders.termQuery("age", 30)) - .filter(QueryBuilders.rangeQuery("balance").gte(10000))); - } - - @Test - void pushDownHighlight() { - Map args = new HashMap<>(); - assertThat() - .pushDown(QueryBuilders.termQuery("name", "John")) - .pushDownHighlight("Title", args) - .pushDownHighlight("Body", args) - .shouldQueryHighlight(QueryBuilders.termQuery("name", "John"), - new HighlightBuilder().field("Title").field("Body")); - } - - @Test - void pushDownHighlightWithArguments() { - Map args = new HashMap<>(); - args.put("pre_tags", new Literal("", 
DataType.STRING)); - args.put("post_tags", new Literal("", DataType.STRING)); - HighlightBuilder highlightBuilder = new HighlightBuilder() - .field("Title"); - highlightBuilder.fields().get(0).preTags("").postTags(""); - assertThat() - .pushDown(QueryBuilders.termQuery("name", "John")) - .pushDownHighlight("Title", args) - .shouldQueryHighlight(QueryBuilders.termQuery("name", "John"), - highlightBuilder); - } - - @Test - void pushDownHighlightWithRepeatingFields() { - mockResponse( - new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, - new ExprValue[]{employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); - - try (OpenSearchIndexScan indexScan = - new OpenSearchIndexScan(client, settings, "test", 2, exprValueFactory)) { - indexScan.getRequestBuilder().pushDownLimit(3, 0); - indexScan.open(); - Map args = new HashMap<>(); - indexScan.getRequestBuilder().pushDownHighlight("name", args); - indexScan.getRequestBuilder().pushDownHighlight("name", args); - } catch (SemanticCheckException e) { - assertTrue(e.getClass().equals(SemanticCheckException.class)); - } - verify(client).cleanup(any()); - } - - private PushDownAssertion assertThat() { - return new PushDownAssertion(client, exprValueFactory, settings); - } - - private static class PushDownAssertion { - private final OpenSearchClient client; - private final OpenSearchIndexScan indexScan; - private final OpenSearchResponse response; - private final OpenSearchExprValueFactory factory; - - public PushDownAssertion(OpenSearchClient client, - OpenSearchExprValueFactory valueFactory, - Settings settings) { - this.client = client; - this.indexScan = new OpenSearchIndexScan(client, settings, "test", 10000, valueFactory); - this.response = mock(OpenSearchResponse.class); - this.factory = valueFactory; - when(response.isEmpty()).thenReturn(true); - } - - PushDownAssertion pushDown(QueryBuilder query) { - indexScan.getRequestBuilder().pushDown(query); - return this; - } - - PushDownAssertion 
pushDownHighlight(String query, Map arguments) { - indexScan.getRequestBuilder().pushDownHighlight(query, arguments); - return this; - } - - PushDownAssertion shouldQueryHighlight(QueryBuilder query, HighlightBuilder highlight) { - OpenSearchRequest request = new OpenSearchQueryRequest("test", 200, factory); - request.getSourceBuilder() - .query(query) - .highlighter(highlight) - .sort(DOC_FIELD_NAME, ASC); - when(client.search(request)).thenReturn(response); - indexScan.open(); - return this; - } - - PushDownAssertion shouldQuery(QueryBuilder expected) { - OpenSearchRequest request = new OpenSearchQueryRequest("test", 200, factory); - request.getSourceBuilder() - .query(expected) - .sort(DOC_FIELD_NAME, ASC); - when(client.search(request)).thenReturn(response); - indexScan.open(); - return this; - } - } - - private void mockResponse(ExprValue[]... searchHitBatches) { - when(client.search(any())) - .thenAnswer( - new Answer() { - private int batchNum; - - @Override - public OpenSearchResponse answer(InvocationOnMock invocation) { - OpenSearchResponse response = mock(OpenSearchResponse.class); - int totalBatch = searchHitBatches.length; - if (batchNum < totalBatch) { - when(response.isEmpty()).thenReturn(false); - ExprValue[] searchHit = searchHitBatches[batchNum]; - when(response.iterator()).thenReturn(Arrays.asList(searchHit).iterator()); - } else { - when(response.isEmpty()).thenReturn(true); - } - - batchNum++; - return response; - } - }); - } - - protected ExprValue employee(int docId, String name, String department) { - SearchHit hit = new SearchHit(docId); - hit.sourceRef( - new BytesArray("{\"name\":\"" + name + "\",\"department\":\"" + department + "\"}")); - return tupleValue(hit); - } - - private ExprValue tupleValue(SearchHit hit) { - return ExprValueUtils.tupleValue(hit.getSourceAsMap()); - } -} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java 
b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java index 3d856cb1e2..11694813cc 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.opensearch.storage; import static org.hamcrest.MatcherAssert.assertThat; @@ -12,13 +11,13 @@ import static org.hamcrest.Matchers.hasEntry; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; import static org.opensearch.sql.opensearch.data.type.OpenSearchDataType.MappingType; @@ -29,9 +28,7 @@ import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; import com.google.common.collect.ImmutableMap; -import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -41,30 +38,33 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.data.model.ExprBooleanValue; import 
org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.expression.aggregation.AvgAggregator; -import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.mapping.IndexMapping; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; -import org.opensearch.sql.storage.Table; @ExtendWith(MockitoExtension.class) class OpenSearchIndexTest { - private final String indexName = "test"; + public static final int QUERY_SIZE_LIMIT = 200; + public static final TimeValue SCROLL_TIMEOUT = new TimeValue(1); + public static final OpenSearchRequest.IndexName INDEX_NAME + = new OpenSearchRequest.IndexName("test"); @Mock private OpenSearchClient client; @@ -75,9 +75,6 @@ class OpenSearchIndexTest { @Mock private Settings settings; - @Mock - private Table table; - @Mock private IndexMapping mapping; @@ -85,30 +82,31 @@ class OpenSearchIndexTest { @BeforeEach void setUp() { - this.index = new OpenSearchIndex(client, settings, indexName); + this.index = new OpenSearchIndex(client, settings, "test"); } @Test void isExist() { - when(client.exists(indexName)).thenReturn(true); + when(client.exists("test")).thenReturn(true); 
assertTrue(index.exists()); } @Test void createIndex() { - Map mappings = ImmutableMap.of( + Map mappings = Map.of( "properties", - ImmutableMap.of( + Map.of( "name", "text", "age", "integer")); - doNothing().when(client).createIndex(indexName, mappings); + doNothing().when(client).createIndex("test", mappings); Map schema = new HashMap<>(); schema.put("name", OpenSearchTextType.of(Map.of("keyword", OpenSearchDataType.of(MappingType.Keyword)))); schema.put("age", INTEGER); index.create(schema); + verify(client).createIndex(any(), any()); } @Test @@ -129,7 +127,7 @@ void getFieldTypes() { .put("id2", MappingType.Short) .put("blob", MappingType.Binary) .build().entrySet().stream().collect(Collectors.toMap( - e -> e.getKey(), e -> OpenSearchDataType.of(e.getValue()) + Map.Entry::getKey, e -> OpenSearchDataType.of(e.getValue()) ))); when(client.getIndexMappings("test")).thenReturn(ImmutableMap.of("test", mapping)); @@ -200,46 +198,38 @@ void getReservedFieldTypes() { @Test void implementRelationOperatorOnly() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - + when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); LogicalPlan plan = index.createScanBuilder(); Integer maxResultWindow = index.getMaxResultWindow(); - assertEquals( - new OpenSearchIndexScan(client, settings, indexName, maxResultWindow, exprValueFactory), - index.implement(plan)); + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory); + assertEquals(new OpenSearchIndexScan(client, + 200, requestBuilder.build(INDEX_NAME, maxResultWindow, SCROLL_TIMEOUT)), + index.implement(index.optimize(plan))); } @Test void implementRelationOperatorWithOptimization() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - + 
when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); LogicalPlan plan = index.createScanBuilder(); Integer maxResultWindow = index.getMaxResultWindow(); - assertEquals( - new OpenSearchIndexScan(client, settings, indexName, maxResultWindow, exprValueFactory), - index.implement(index.optimize(plan))); + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory); + assertEquals(new OpenSearchIndexScan(client, 200, + requestBuilder.build(INDEX_NAME, maxResultWindow, SCROLL_TIMEOUT)), index.implement(plan)); } @Test void implementOtherLogicalOperators() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - + when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); NamedExpression include = named("age", ref("age", INTEGER)); ReferenceExpression exclude = ref("name", STRING); ReferenceExpression dedupeField = ref("name", STRING); - Expression filterExpr = literal(ExprBooleanValue.of(true)); - List groupByExprs = Arrays.asList(named("age", ref("age", INTEGER))); - List aggregators = - Arrays.asList(named("avg(age)", new AvgAggregator(Arrays.asList(ref("age", INTEGER)), - DOUBLE))); Map mappings = ImmutableMap.of(ref("name", STRING), ref("lastname", STRING)); Pair newEvalField = ImmutablePair.of(ref("name1", STRING), ref("name", STRING)); - Integer sortCount = 100; Pair sortField = ImmutablePair.of(Sort.SortOption.DEFAULT_ASC, ref("name1", STRING)); @@ -259,6 +249,7 @@ void implementOtherLogicalOperators() { include); Integer maxResultWindow = index.getMaxResultWindow(); + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory); assertEquals( PhysicalPlanDSL.project( PhysicalPlanDSL.dedupe( @@ -266,8 +257,9 @@ void implementOtherLogicalOperators() { PhysicalPlanDSL.eval( PhysicalPlanDSL.remove( PhysicalPlanDSL.rename( - new 
OpenSearchIndexScan(client, settings, indexName, - maxResultWindow, exprValueFactory), + new OpenSearchIndexScan(client, + QUERY_SIZE_LIMIT, requestBuilder.build(INDEX_NAME, maxResultWindow, + SCROLL_TIMEOUT)), mappings), exclude), newEvalField), diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java index ab87f4531c..1089e7e252 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.opensearch.storage; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; @@ -35,7 +36,10 @@ public void getTable() { OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings); Table table = engine.getTable(new DataSourceSchemaName(DEFAULT_DATASOURCE_NAME, "default"), "test"); - assertNotNull(table); + assertAll( + () -> assertNotNull(table), + () -> assertTrue(table instanceof OpenSearchIndex) + ); } @Test @@ -43,7 +47,9 @@ public void getSystemTable() { OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings); Table table = engine.getTable(new DataSourceSchemaName(DEFAULT_DATASOURCE_NAME, "default"), TABLE_INFO); - assertNotNull(table); - assertTrue(table instanceof OpenSearchSystemIndex); + assertAll( + () -> assertNotNull(table), + () -> assertTrue(table instanceof OpenSearchSystemIndex) + ); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilderTest.java 
b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilderTest.java new file mode 100644 index 0000000000..5a510fefec --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilderTest.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; + +@ExtendWith(MockitoExtension.class) +class OpenSearchIndexScanAggregationBuilderTest { + @Mock + OpenSearchRequestBuilder requestBuilder; + @Mock + LogicalAggregation logicalAggregation; + OpenSearchIndexScanAggregationBuilder builder; + + @BeforeEach + void setup() { + builder = new OpenSearchIndexScanAggregationBuilder(requestBuilder, logicalAggregation); + } + + @Test + void pushDownFilter() { + assertFalse(builder.pushDownFilter(mock(LogicalFilter.class))); + } + + @Test + void pushDownSort() { + assertTrue(builder.pushDownSort(mock(LogicalSort.class))); + } + + @Test + void pushDownLimit() { + 
assertFalse(builder.pushDownLimit(mock(LogicalLimit.class))); + } + + @Test + void pushDownProject() { + assertFalse(builder.pushDownProject(mock(LogicalProject.class))); + } + + @Test + void pushDownHighlight() { + assertFalse(builder.pushDownHighlight(mock(LogicalHighlight.class))); + } + + @Test + void pushDownPageSize() { + assertFalse(builder.pushDownPageSize(mock(LogicalPaginate.class))); + } + + @Test + void pushDownNested() { + assertFalse(builder.pushDownNested(mock(LogicalNested.class))); + } + +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java index fa98f5a3b9..6bf9002a67 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java @@ -7,7 +7,8 @@ package org.opensearch.sql.opensearch.storage.scan; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.eq; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -24,6 +25,7 @@ import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.nested; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.paginate; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; @@ -53,7 +55,6 @@ import 
org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.SpanOrQueryBuilder; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; @@ -65,7 +66,6 @@ import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValueUtils; -import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.FunctionExpression; @@ -78,15 +78,15 @@ import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.response.agg.SingleValueParser; -import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; -import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalNested; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; +import org.opensearch.sql.planner.optimizer.PushDownPageSize; import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; import org.opensearch.sql.storage.Table; - +import org.opensearch.sql.storage.TableScanOperator; @ExtendWith(MockitoExtension.class) class OpenSearchIndexScanOptimizationTest { @@ -106,16 +106,20 @@ class OpenSearchIndexScanOptimizationTest { @BeforeEach void setUp() { - indexScanBuilder = new OpenSearchIndexScanBuilder(indexScan); + indexScanBuilder = new 
OpenSearchIndexScanBuilder(requestBuilder) { + @Override + protected TableScanOperator createScan(OpenSearchRequestBuilder build) { + return indexScan; + } + }; when(table.createScanBuilder()).thenReturn(indexScanBuilder); - when(indexScan.getRequestBuilder()).thenReturn(requestBuilder); } @Test void test_project_push_down() { assertEqualsAfterOptimization( project( - indexScanAggBuilder( + indexScanBuilder( withProjectPushedDown(DSL.ref("intV", INTEGER))), DSL.named("i", DSL.ref("intV", INTEGER)) ), @@ -337,6 +341,21 @@ void test_sort_push_down() { ); } + @Test + void test_page_push_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder( + withPageSizePushDown(5)), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ), + paginate(project( + relation("schema", table), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ), 5 + )); + } + @Test void test_score_sort_push_down() { assertEqualsAfterOptimization( @@ -679,16 +698,28 @@ void project_literal_should_not_be_pushed_down() { private OpenSearchIndexScanBuilder indexScanBuilder(Runnable... verifyPushDownCalls) { this.verifyPushDownCalls = verifyPushDownCalls; - return new OpenSearchIndexScanBuilder(new OpenSearchIndexScanQueryBuilder(indexScan)); + return new OpenSearchIndexScanBuilder(new OpenSearchIndexScanQueryBuilder(requestBuilder)) { + @Override + protected TableScanOperator createScan(OpenSearchRequestBuilder build) { + return indexScan; + } + }; } private OpenSearchIndexScanBuilder indexScanAggBuilder(Runnable... 
verifyPushDownCalls) { this.verifyPushDownCalls = verifyPushDownCalls; - return new OpenSearchIndexScanBuilder(new OpenSearchIndexScanAggregationBuilder(indexScan)); + return new OpenSearchIndexScanBuilder(new OpenSearchIndexScanAggregationBuilder( + requestBuilder, mock(LogicalAggregation.class))) { + @Override + protected TableScanOperator createScan(OpenSearchRequestBuilder build) { + return indexScan; + } + }; } private void assertEqualsAfterOptimization(LogicalPlan expected, LogicalPlan actual) { - assertEquals(expected, optimize(actual)); + final var optimized = optimize(actual); + assertEquals(expected, optimized); // Trigger build to make sure all push down actually happened in scan builder indexScanBuilder.build(); @@ -702,7 +733,7 @@ private void assertEqualsAfterOptimization(LogicalPlan expected, LogicalPlan act } private Runnable withFilterPushedDown(QueryBuilder filteringCondition) { - return () -> verify(requestBuilder, times(1)).pushDown(filteringCondition); + return () -> verify(requestBuilder, times(1)).pushDownFilter(filteringCondition); } private Runnable withAggregationPushedDown( @@ -760,6 +791,10 @@ private Runnable withTrackedScoresPushedDown(boolean trackScores) { return () -> verify(requestBuilder, times(1)).pushDownTrackedScore(trackScores); } + private Runnable withPageSizePushDown(int pageSize) { + return () -> verify(requestBuilder, times(1)).pushDownPageSize(pageSize); + } + private static AggregationAssertHelper.AggregationAssertHelperBuilder aggregate(String aggName) { var aggBuilder = new AggregationAssertHelper.AggregationAssertHelperBuilder(); aggBuilder.aggregateName = aggName; @@ -785,6 +820,7 @@ private static class AggregationAssertHelper { private LogicalPlan optimize(LogicalPlan plan) { LogicalPlanOptimizer optimizer = new LogicalPlanOptimizer(List.of( new CreateTableScanBuilder(), + new PushDownPageSize(), PUSH_DOWN_FILTER, PUSH_DOWN_AGGREGATION, PUSH_DOWN_SORT, diff --git 
a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanPaginationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanPaginationTest.java new file mode 100644 index 0000000000..67f0869d6e --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanPaginationTest.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.CALLS_REAL_METHODS; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanTest.QUERY_SIZE; +import static org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanTest.mockResponse; + +import java.io.ByteArrayOutputStream; +import java.io.ObjectOutputStream; +import java.util.Map; +import lombok.SneakyThrows; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import 
org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.response.OpenSearchResponse; + +@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class OpenSearchIndexScanPaginationTest { + + public static final OpenSearchRequest.IndexName INDEX_NAME + = new OpenSearchRequest.IndexName("test"); + public static final int MAX_RESULT_WINDOW = 3; + public static final TimeValue SCROLL_TIMEOUT = TimeValue.timeValueMinutes(4); + @Mock + private Settings settings; + + @BeforeEach + void setup() { + lenient().when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(QUERY_SIZE); + lenient().when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); + } + + @Mock + private OpenSearchClient client; + + private final OpenSearchExprValueFactory exprValueFactory + = new OpenSearchExprValueFactory(Map.of( + "name", OpenSearchDataType.of(STRING), + "department", OpenSearchDataType.of(STRING))); + + @Test + void query_empty_result() { + mockResponse(client); + var builder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory); + try (var indexScan = new OpenSearchIndexScan(client, MAX_RESULT_WINDOW, + builder.build(INDEX_NAME, MAX_RESULT_WINDOW, SCROLL_TIMEOUT))) { + indexScan.open(); + assertFalse(indexScan.hasNext()); + } + verify(client).cleanup(any()); + } + + @Test + void explain_not_implemented() { + assertThrows(Throwable.class, () -> mock(OpenSearchIndexScan.class, + withSettings().defaultAnswer(CALLS_REAL_METHODS)).explain()); + } + + @Test + @SneakyThrows + void dont_serialize_if_no_cursor() { + OpenSearchRequestBuilder builder = mock(); + OpenSearchRequest request = mock(); + 
OpenSearchResponse response = mock(); + when(builder.build(any(), anyInt(), any())).thenReturn(request); + when(client.search(any())).thenReturn(response); + try (var indexScan + = new OpenSearchIndexScan(client, MAX_RESULT_WINDOW, + builder.build(INDEX_NAME, MAX_RESULT_WINDOW, SCROLL_TIMEOUT))) { + indexScan.open(); + + when(request.hasAnotherBatch()).thenReturn(false); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + assertThrows(NoCursorException.class, () -> objectOutput.writeObject(indexScan)); + } + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java new file mode 100644 index 0000000000..08590f8021 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java @@ -0,0 +1,444 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.search.sort.FieldSortBuilder.DOC_FIELD_NAME; +import static org.opensearch.search.sort.SortOrder.ASC; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.io.ByteArrayOutputStream; +import java.io.ObjectOutputStream; +import java.util.Arrays; +import java.util.HashMap; +import 
java.util.Map; +import lombok.SneakyThrows; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.stubbing.Answer; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.bytes.BytesArray; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; +import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; + 
+@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +class OpenSearchIndexScanTest { + + public static final int QUERY_SIZE = 200; + public static final OpenSearchRequest.IndexName INDEX_NAME + = new OpenSearchRequest.IndexName("employees"); + public static final int MAX_RESULT_WINDOW = 10000; + public static final TimeValue CURSOR_KEEP_ALIVE = TimeValue.timeValueMinutes(1); + @Mock + private OpenSearchClient client; + + private final OpenSearchExprValueFactory exprValueFactory = new OpenSearchExprValueFactory( + Map.of("name", OpenSearchDataType.of(STRING), + "department", OpenSearchDataType.of(STRING))); + + @BeforeEach + void setup() { + } + + @Test + void explain() { + var request = mock(OpenSearchRequest.class); + when(request.toString()).thenReturn("explain works!"); + try (var indexScan = new OpenSearchIndexScan(client, QUERY_SIZE, request)) { + assertEquals("explain works!", indexScan.explain()); + } + } + + @Test + @SneakyThrows + void throws_no_cursor_exception() { + var request = mock(OpenSearchRequest.class); + when(request.hasAnotherBatch()).thenReturn(false); + try (var indexScan = new OpenSearchIndexScan(client, QUERY_SIZE, request); + var byteStream = new ByteArrayOutputStream(); + var objectStream = new ObjectOutputStream(byteStream)) { + assertThrows(NoCursorException.class, () -> objectStream.writeObject(indexScan)); + } + } + + @Test + @SneakyThrows + void serialize() { + var searchSourceBuilder = new SearchSourceBuilder().size(4); + + var factory = mock(OpenSearchExprValueFactory.class); + var engine = mock(OpenSearchStorageEngine.class); + var index = mock(OpenSearchIndex.class); + when(engine.getClient()).thenReturn(client); + when(engine.getTable(any(), any())).thenReturn(index); + var request = new OpenSearchScrollRequest( + INDEX_NAME, CURSOR_KEEP_ALIVE, searchSourceBuilder, factory); + request.setScrollId("valid-id"); + // make a response, so OpenSearchResponse::isEmpty would return 
true and unset needClean + var response = mock(SearchResponse.class); + when(response.getAggregations()).thenReturn(mock()); + var hits = mock(SearchHits.class); + when(response.getHits()).thenReturn(hits); + when(response.getScrollId()).thenReturn("valid-id"); + when(hits.getHits()).thenReturn(new SearchHit[]{ mock() }); + request.search(null, (req) -> response); + + try (var indexScan = new OpenSearchIndexScan(client, QUERY_SIZE, request)) { + var planSerializer = new PlanSerializer(engine); + var cursor = planSerializer.convertToCursor(indexScan); + var newPlan = planSerializer.convertToPlan(cursor.toString()); + assertEquals(indexScan, newPlan); + } + } + + @Test + void plan_for_serialization() { + var request = mock(OpenSearchRequest.class); + try (var indexScan = new OpenSearchIndexScan(client, QUERY_SIZE, request)) { + assertEquals(indexScan, indexScan.getPlanForSerialization()); + } + } + + @Test + void query_empty_result() { + mockResponse(client); + final var name = new OpenSearchRequest.IndexName("test"); + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, + QUERY_SIZE, requestBuilder.build(name, MAX_RESULT_WINDOW, CURSOR_KEEP_ALIVE))) { + indexScan.open(); + assertFalse(indexScan.hasNext()); + } + verify(client).cleanup(any()); + } + + @Test + void query_all_results_with_query() { + mockResponse(client, new ExprValue[]{ + employee(1, "John", "IT"), + employee(2, "Smith", "HR"), + employee(3, "Allen", "IT")}); + + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, + 10, requestBuilder.build(INDEX_NAME, 10000, CURSOR_KEEP_ALIVE))) { + indexScan.open(); + + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> 
assertEquals(employee(2, "Smith", "HR"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), + + () -> assertFalse(indexScan.hasNext()) + ); + } + verify(client).cleanup(any()); + } + + static final OpenSearchRequest.IndexName EMPLOYEES_INDEX + = new OpenSearchRequest.IndexName("employees"); + + @Test + void query_all_results_with_scroll() { + mockResponse(client, + new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, + new ExprValue[]{employee(3, "Allen", "IT")}); + + final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, exprValueFactory); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, + 10, requestBuilder.build(INDEX_NAME, 10000, CURSOR_KEEP_ALIVE))) { + indexScan.open(); + + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), + + () -> assertFalse(indexScan.hasNext()) + ); + } + verify(client).cleanup(any()); + } + + @Test + void query_some_results_with_query() { + mockResponse(client, new ExprValue[]{ + employee(1, "John", "IT"), + employee(2, "Smith", "HR"), + employee(3, "Allen", "IT"), + employee(4, "Bob", "HR")}); + + final int limit = 3; + OpenSearchRequestBuilder builder = new OpenSearchRequestBuilder(0, exprValueFactory); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, + limit, builder.build(INDEX_NAME, MAX_RESULT_WINDOW, CURSOR_KEEP_ALIVE))) { + indexScan.open(); + + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), + + () -> 
assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), + + () -> assertFalse(indexScan.hasNext()) + ); + } + verify(client).cleanup(any()); + } + + @Test + void query_some_results_with_scroll() { + mockTwoPageResponse(client); + final var requestuilder = new OpenSearchRequestBuilder(10, exprValueFactory); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, + 3, requestuilder.build(INDEX_NAME, MAX_RESULT_WINDOW, CURSOR_KEEP_ALIVE))) { + indexScan.open(); + + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), + + () -> assertFalse(indexScan.hasNext()) + ); + } + verify(client).cleanup(any()); + } + + static void mockTwoPageResponse(OpenSearchClient client) { + mockResponse(client, + new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, + new ExprValue[]{employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); + } + + @Test + void query_results_limited_by_query_size() { + mockResponse(client, new ExprValue[]{ + employee(1, "John", "IT"), + employee(2, "Smith", "HR"), + employee(3, "Allen", "IT"), + employee(4, "Bob", "HR")}); + + final int defaultQuerySize = 2; + final var requestBuilder = new OpenSearchRequestBuilder(defaultQuerySize, exprValueFactory); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, + defaultQuerySize, requestBuilder.build(INDEX_NAME, QUERY_SIZE, CURSOR_KEEP_ALIVE))) { + indexScan.open(); + + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), + + () -> assertFalse(indexScan.hasNext()) + 
); + } + verify(client).cleanup(any()); + } + + @Test + void push_down_filters() { + assertThat() + .pushDown(QueryBuilders.termQuery("name", "John")) + .shouldQuery(QueryBuilders.termQuery("name", "John")) + .pushDown(QueryBuilders.termQuery("age", 30)) + .shouldQuery( + QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("name", "John")) + .filter(QueryBuilders.termQuery("age", 30))) + .pushDown(QueryBuilders.rangeQuery("balance").gte(10000)) + .shouldQuery( + QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("name", "John")) + .filter(QueryBuilders.termQuery("age", 30)) + .filter(QueryBuilders.rangeQuery("balance").gte(10000))); + } + + @Test + void push_down_highlight() { + Map args = new HashMap<>(); + assertThat() + .pushDown(QueryBuilders.termQuery("name", "John")) + .pushDownHighlight("Title", args) + .pushDownHighlight("Body", args) + .shouldQueryHighlight(QueryBuilders.termQuery("name", "John"), + new HighlightBuilder().field("Title").field("Body")); + } + + @Test + void push_down_highlight_with_arguments() { + Map args = new HashMap<>(); + args.put("pre_tags", new Literal("", DataType.STRING)); + args.put("post_tags", new Literal("", DataType.STRING)); + HighlightBuilder highlightBuilder = new HighlightBuilder() + .field("Title"); + highlightBuilder.fields().get(0).preTags("").postTags(""); + assertThat() + .pushDown(QueryBuilders.termQuery("name", "John")) + .pushDownHighlight("Title", args) + .shouldQueryHighlight(QueryBuilders.termQuery("name", "John"), + highlightBuilder); + } + + private PushDownAssertion assertThat() { + return new PushDownAssertion(client, exprValueFactory); + } + + private static class PushDownAssertion { + private final OpenSearchClient client; + private final OpenSearchRequestBuilder requestBuilder; + private final OpenSearchResponse response; + private final OpenSearchExprValueFactory factory; + + public PushDownAssertion(OpenSearchClient client, + OpenSearchExprValueFactory valueFactory) { + this.client = 
client; + this.requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE, valueFactory); + + this.response = mock(OpenSearchResponse.class); + this.factory = valueFactory; + when(response.isEmpty()).thenReturn(true); + } + + PushDownAssertion pushDown(QueryBuilder query) { + requestBuilder.pushDownFilter(query); + return this; + } + + PushDownAssertion pushDownHighlight(String query, Map arguments) { + requestBuilder.pushDownHighlight(query, arguments); + return this; + } + + PushDownAssertion shouldQueryHighlight(QueryBuilder query, HighlightBuilder highlight) { + var sourceBuilder = new SearchSourceBuilder() + .from(0) + .timeout(CURSOR_KEEP_ALIVE) + .query(query) + .size(QUERY_SIZE) + .highlighter(highlight) + .sort(DOC_FIELD_NAME, ASC); + OpenSearchRequest request = + new OpenSearchQueryRequest(EMPLOYEES_INDEX, sourceBuilder, factory); + + when(client.search(request)).thenReturn(response); + var indexScan = new OpenSearchIndexScan(client, + QUERY_SIZE, requestBuilder.build(EMPLOYEES_INDEX, 10000, CURSOR_KEEP_ALIVE)); + indexScan.open(); + return this; + } + + PushDownAssertion shouldQuery(QueryBuilder expected) { + var builder = new SearchSourceBuilder() + .from(0) + .query(expected) + .size(QUERY_SIZE) + .timeout(CURSOR_KEEP_ALIVE) + .sort(DOC_FIELD_NAME, ASC); + OpenSearchRequest request = new OpenSearchQueryRequest(EMPLOYEES_INDEX, builder, factory); + when(client.search(request)).thenReturn(response); + var indexScan = new OpenSearchIndexScan(client, + 10000, requestBuilder.build(EMPLOYEES_INDEX, 10000, CURSOR_KEEP_ALIVE)); + indexScan.open(); + return this; + } + } + + public static void mockResponse(OpenSearchClient client, ExprValue[]... 
searchHitBatches) { + when(client.search(any())) + .thenAnswer( + new Answer() { + private int batchNum; + + @Override + public OpenSearchResponse answer(InvocationOnMock invocation) { + OpenSearchResponse response = mock(OpenSearchResponse.class); + int totalBatch = searchHitBatches.length; + if (batchNum < totalBatch) { + when(response.isEmpty()).thenReturn(false); + ExprValue[] searchHit = searchHitBatches[batchNum]; + when(response.iterator()).thenReturn(Arrays.asList(searchHit).iterator()); + } else { + when(response.isEmpty()).thenReturn(true); + } + + batchNum++; + return response; + } + }); + } + + public static ExprValue employee(int docId, String name, String department) { + SearchHit hit = new SearchHit(docId); + hit.sourceRef( + new BytesArray("{\"name\":\"" + name + "\",\"department\":\"" + department + "\"}")); + return tupleValue(hit); + } + + private static ExprValue tupleValue(SearchHit hit) { + return ExprValueUtils.tupleValue(hit.getSourceAsMap()); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/PushDownQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/PushDownQueryBuilderTest.java new file mode 100644 index 0000000000..0b0568a6b7 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/PushDownQueryBuilderTest.java @@ -0,0 +1,42 @@ +package org.opensearch.sql.opensearch.storage.scan; + + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.mockito.Mockito.mock; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; 
+import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; + +@ExtendWith(MockitoExtension.class) +class PushDownQueryBuilderTest { + @Test + void default_implementations() { + var sample = new PushDownQueryBuilder() { + @Override + public OpenSearchRequestBuilder build() { + return null; + } + }; + assertAll( + () -> assertFalse(sample.pushDownFilter(mock(LogicalFilter.class))), + () -> assertFalse(sample.pushDownProject(mock(LogicalProject.class))), + () -> assertFalse(sample.pushDownHighlight(mock(LogicalHighlight.class))), + () -> assertFalse(sample.pushDownSort(mock(LogicalSort.class))), + () -> assertFalse(sample.pushDownNested(mock(LogicalNested.class))), + () -> assertFalse(sample.pushDownLimit(mock(LogicalLimit.class))), + () -> assertFalse(sample.pushDownPageSize(mock(LogicalPaginate.class))) + + ); + } + +} diff --git a/plugin/build.gradle b/plugin/build.gradle index e318103859..4a6c175d61 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -246,6 +246,7 @@ afterEvaluate { testClusters.integTest { plugin(project.tasks.bundlePlugin.archiveFile) + testDistribution = "ARCHIVE" // debug with command, ./gradlew opensearch-sql:run -DdebugJVM. --debug-jvm does not work with keystore. 
if (System.getProperty("debugJVM") != null) { diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java b/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java index 5ab4bbaecd..f301a242fb 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java @@ -18,6 +18,7 @@ import org.opensearch.sql.executor.QueryManager; import org.opensearch.sql.executor.QueryService; import org.opensearch.sql.executor.execution.QueryPlanFactory; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.monitor.ResourceMonitor; import org.opensearch.sql.opensearch.client.OpenSearchClient; @@ -58,8 +59,9 @@ public StorageEngine storageEngine(OpenSearchClient client, Settings settings) { } @Provides - public ExecutionEngine executionEngine(OpenSearchClient client, ExecutionProtector protector) { - return new OpenSearchExecutionEngine(client, protector); + public ExecutionEngine executionEngine(OpenSearchClient client, ExecutionProtector protector, + PlanSerializer planSerializer) { + return new OpenSearchExecutionEngine(client, protector, planSerializer); } @Provides @@ -72,6 +74,11 @@ public ExecutionProtector protector(ResourceMonitor resourceMonitor) { return new OpenSearchExecutionProtector(resourceMonitor); } + @Provides + public PlanSerializer planSerializer(StorageEngine storageEngine) { + return new PlanSerializer(storageEngine); + } + @Provides @Singleton public QueryManager queryManager(NodeClient nodeClient) { @@ -92,12 +99,14 @@ public SQLService sqlService(QueryManager queryManager, QueryPlanFactory queryPl * {@link QueryPlanFactory}. 
*/ @Provides - public QueryPlanFactory queryPlanFactory( - DataSourceService dataSourceService, ExecutionEngine executionEngine) { + public QueryPlanFactory queryPlanFactory(DataSourceService dataSourceService, + ExecutionEngine executionEngine) { Analyzer analyzer = new Analyzer( new ExpressionAnalyzer(functionRepository), dataSourceService, functionRepository); Planner planner = new Planner(LogicalPlanOptimizer.create()); - return new QueryPlanFactory(new QueryService(analyzer, executionEngine, planner)); + QueryService queryService = new QueryService( + analyzer, executionEngine, planner); + return new QueryPlanFactory(queryService); } } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java index a5c094e956..dbe5230abf 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java @@ -139,7 +139,8 @@ private ResponseListener createListener( @Override public void onResponse(ExecutionEngine.QueryResponse response) { String responseContent = - formatter.format(new QueryResult(response.getSchema(), response.getResults())); + formatter.format(new QueryResult(response.getSchema(), response.getResults(), + response.getCursor())); listener.onResponse(new TransportPPLQueryResponse(responseContent)); } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java b/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java index e11edc1646..40a7a85f78 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java @@ -90,6 +90,7 @@ private AbstractPlan plan( QueryContext.getRequestId(), anonymizer.anonymizeStatement(statement)); - return queryExecutionFactory.create(statement, queryListener, explainListener); + return 
queryExecutionFactory.create( + statement, queryListener, explainListener); } } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java index e4f40e9a11..3b7e5a78dd 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java @@ -33,7 +33,7 @@ public class AstStatementBuilder extends OpenSearchPPLParserBaseVisitor { ResponseListener listener = invocation.getArgument(1); - listener.onResponse(new QueryResponse(schema, Collections.emptyList())); + listener.onResponse(new QueryResponse(schema, Collections.emptyList(), Cursor.None)); return null; }).when(queryService).execute(any(), any()); @@ -87,7 +89,7 @@ public void onFailure(Exception e) { public void testExecuteCsvFormatShouldPass() { doAnswer(invocation -> { ResponseListener listener = invocation.getArgument(1); - listener.onResponse(new QueryResponse(schema, Collections.emptyList())); + listener.onResponse(new QueryResponse(schema, Collections.emptyList(), Cursor.None)); return null; }).when(queryService).execute(any(), any()); @@ -161,7 +163,7 @@ public void onFailure(Exception e) { public void testPrometheusQuery() { doAnswer(invocation -> { ResponseListener listener = invocation.getArgument(1); - listener.onResponse(new QueryResponse(schema, Collections.emptyList())); + listener.onResponse(new QueryResponse(schema, Collections.emptyList(), Cursor.None)); return null; }).when(queryService).execute(any(), any()); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstStatementBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstStatementBuilderTest.java index 4760024692..de74e4932f 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstStatementBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstStatementBuilderTest.java @@ -39,7 +39,8 @@ public 
void buildQueryStatement() { "search source=t a=1", new Query( project( - filter(relation("t"), compare("=", field("a"), intLiteral(1))), AllFields.of()))); + filter(relation("t"), compare("=", field("a"), + intLiteral(1))), AllFields.of()), 0)); } @Test @@ -50,7 +51,7 @@ public void buildExplainStatement() { new Query( project( filter(relation("t"), compare("=", field("a"), intLiteral(1))), - AllFields.of())))); + AllFields.of()), 0))); } private void assertEqual(String query, Statement expectedStatement) { diff --git a/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java b/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java index 915a61f361..ae66364419 100644 --- a/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java +++ b/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java @@ -16,6 +16,7 @@ import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.ExecutionEngine.Schema.Column; +import org.opensearch.sql.executor.pagination.Cursor; /** * Query response that encapsulates query results and isolate {@link ExprValue} @@ -32,6 +33,12 @@ public class QueryResult implements Iterable { */ private final Collection exprValues; + @Getter + private final Cursor cursor; + + public QueryResult(ExecutionEngine.Schema schema, Collection exprValues) { + this(schema, exprValues, Cursor.None); + } /** * size of results. 
diff --git a/protocol/src/main/java/org/opensearch/sql/protocol/response/format/CommandResponseFormatter.java b/protocol/src/main/java/org/opensearch/sql/protocol/response/format/CommandResponseFormatter.java new file mode 100644 index 0000000000..68d9be558b --- /dev/null +++ b/protocol/src/main/java/org/opensearch/sql/protocol/response/format/CommandResponseFormatter.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.protocol.response.format; + +import lombok.Getter; +import org.opensearch.sql.executor.execution.CommandPlan; +import org.opensearch.sql.opensearch.response.error.ErrorMessage; +import org.opensearch.sql.opensearch.response.error.ErrorMessageFactory; +import org.opensearch.sql.protocol.response.QueryResult; + +/** + * A simple response formatter which contains no data. + * Supposed to use with {@link CommandPlan} only. + */ +public class CommandResponseFormatter extends JsonResponseFormatter { + + public CommandResponseFormatter() { + super(Style.PRETTY); + } + + @Override + protected Object buildJsonObject(QueryResult response) { + return new NoQueryResponse(); + } + + @Override + public String format(Throwable t) { + return new JdbcResponseFormatter(Style.PRETTY).format(t); + } + + @Getter + public static class NoQueryResponse { + // in case of failure an exception is thrown + private final boolean succeeded = true; + } +} diff --git a/protocol/src/main/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatter.java b/protocol/src/main/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatter.java index 943287cb62..1ad3ffde34 100644 --- a/protocol/src/main/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatter.java +++ b/protocol/src/main/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatter.java @@ -15,6 +15,7 @@ import org.opensearch.sql.data.type.ExprType; import 
org.opensearch.sql.exception.QueryEngineException; import org.opensearch.sql.executor.ExecutionEngine.Schema; +import org.opensearch.sql.executor.pagination.Cursor; import org.opensearch.sql.opensearch.response.error.ErrorMessage; import org.opensearch.sql.opensearch.response.error.ErrorMessageFactory; import org.opensearch.sql.protocol.response.QueryResult; @@ -42,6 +43,9 @@ protected Object buildJsonObject(QueryResult response) { json.total(response.size()) .size(response.size()) .status(200); + if (!response.getCursor().equals(Cursor.None)) { + json.cursor(response.getCursor().toString()); + } return json.build(); } @@ -95,6 +99,8 @@ public static class JdbcResponse { private final long total; private final long size; private final int status; + + private final String cursor; } @RequiredArgsConstructor diff --git a/protocol/src/test/java/org/opensearch/sql/protocol/response/QueryResultTest.java b/protocol/src/test/java/org/opensearch/sql/protocol/response/QueryResultTest.java index 319965e2d0..4c58e189b8 100644 --- a/protocol/src/test/java/org/opensearch/sql/protocol/response/QueryResultTest.java +++ b/protocol/src/test/java/org/opensearch/sql/protocol/response/QueryResultTest.java @@ -19,6 +19,7 @@ import java.util.Collections; import org.junit.jupiter.api.Test; import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.executor.pagination.Cursor; class QueryResultTest { @@ -35,7 +36,7 @@ void size() { tupleValue(ImmutableMap.of("name", "John", "age", 20)), tupleValue(ImmutableMap.of("name", "Allen", "age", 30)), tupleValue(ImmutableMap.of("name", "Smith", "age", 40)) - )); + ), Cursor.None); assertEquals(3, response.size()); } @@ -45,7 +46,7 @@ void columnNameTypes() { schema, Collections.singletonList( tupleValue(ImmutableMap.of("name", "John", "age", 20)) - )); + ), Cursor.None); assertEquals( ImmutableMap.of("name", "string", "age", "integer"), @@ -59,7 +60,8 @@ void columnNameTypesWithAlias() { new ExecutionEngine.Schema.Column("name", 
"n", STRING))); QueryResult response = new QueryResult( schema, - Collections.singletonList(tupleValue(ImmutableMap.of("n", "John")))); + Collections.singletonList(tupleValue(ImmutableMap.of("n", "John"))), + Cursor.None); assertEquals( ImmutableMap.of("n", "string"), @@ -71,7 +73,7 @@ void columnNameTypesWithAlias() { void columnNameTypesFromEmptyExprValues() { QueryResult response = new QueryResult( schema, - Collections.emptyList()); + Collections.emptyList(), Cursor.None); assertEquals( ImmutableMap.of("name", "string", "age", "integer"), response.columnNameTypes() @@ -100,7 +102,7 @@ void iterate() { Arrays.asList( tupleValue(ImmutableMap.of("name", "John", "age", 20)), tupleValue(ImmutableMap.of("name", "Allen", "age", 30)) - )); + ), Cursor.None); int i = 0; for (Object[] objects : response) { diff --git a/protocol/src/test/java/org/opensearch/sql/protocol/response/format/CommandResponseFormatterTest.java b/protocol/src/test/java/org/opensearch/sql/protocol/response/format/CommandResponseFormatterTest.java new file mode 100644 index 0000000000..a3052324fe --- /dev/null +++ b/protocol/src/test/java/org/opensearch/sql/protocol/response/format/CommandResponseFormatterTest.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.protocol.response.format; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.protocol.response.format.JsonResponseFormatter.Style.PRETTY; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import 
org.junit.jupiter.api.Test; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.executor.pagination.Cursor; +import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; +import org.opensearch.sql.protocol.response.QueryResult; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class CommandResponseFormatterTest { + + @Test + public void produces_always_same_output_for_any_query_response() { + var formatter = new CommandResponseFormatter(); + assertEquals(formatter.format(mock(QueryResult.class)), + formatter.format(mock(QueryResult.class))); + + QueryResult response = new QueryResult( + new ExecutionEngine.Schema(ImmutableList.of( + new ExecutionEngine.Schema.Column("name", "name", STRING), + new ExecutionEngine.Schema.Column("address", "address", OpenSearchTextType.of()), + new ExecutionEngine.Schema.Column("age", "age", INTEGER))), + ImmutableList.of( + tupleValue(ImmutableMap.builder() + .put("name", "John") + .put("address", "Seattle") + .put("age", 20) + .build())), + new Cursor("test_cursor")); + + assertEquals("{\n" + + " \"succeeded\": true\n" + + "}", + formatter.format(response)); + } + + @Test + public void formats_error_as_default_formatter() { + var exception = new Exception("pewpew", new RuntimeException("meow meow")); + assertEquals(new JdbcResponseFormatter(PRETTY).format(exception), + new CommandResponseFormatter().format(exception)); + } +} diff --git a/protocol/src/test/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatterTest.java b/protocol/src/test/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatterTest.java index a6671c66f8..9c79b1bf89 100644 --- a/protocol/src/test/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatterTest.java +++ b/protocol/src/test/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatterTest.java @@ -31,6 +31,7 @@ import org.opensearch.sql.common.antlr.SyntaxCheckException; import 
org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.executor.pagination.Cursor; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.protocol.response.QueryResult; @@ -83,6 +84,37 @@ void format_response() { formatter.format(response)); } + @Test + void format_response_with_cursor() { + QueryResult response = new QueryResult( + new Schema(ImmutableList.of( + new Column("name", "name", STRING), + new Column("address", "address", OpenSearchTextType.of()), + new Column("age", "age", INTEGER))), + ImmutableList.of( + tupleValue(ImmutableMap.builder() + .put("name", "John") + .put("address", "Seattle") + .put("age", 20) + .build())), + new Cursor("test_cursor")); + + assertJsonEquals( + "{" + + "\"schema\":[" + + "{\"name\":\"name\",\"alias\":\"name\",\"type\":\"keyword\"}," + + "{\"name\":\"address\",\"alias\":\"address\",\"type\":\"text\"}," + + "{\"name\":\"age\",\"alias\":\"age\",\"type\":\"integer\"}" + + "]," + + "\"datarows\":[" + + "[\"John\",\"Seattle\",20]]," + + "\"total\":1," + + "\"size\":1," + + "\"cursor\":\"test_cursor\"," + + "\"status\":200}", + formatter.format(response)); + } + @Test void format_response_with_missing_and_null_value() { QueryResult response = diff --git a/sql/src/main/java/org/opensearch/sql/sql/SQLService.java b/sql/src/main/java/org/opensearch/sql/sql/SQLService.java index 082a3e9581..91ec00cdd5 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/SQLService.java +++ b/sql/src/main/java/org/opensearch/sql/sql/SQLService.java @@ -65,16 +65,33 @@ private AbstractPlan plan( SQLQueryRequest request, Optional> queryListener, Optional> explainListener) { - // 1.Parse query and convert parse tree (CST) to abstract syntax tree (AST) - ParseTree cst = parser.parse(request.getQuery()); - Statement statement = - cst.accept( - new AstStatementBuilder( - new 
AstBuilder(request.getQuery()), - AstStatementBuilder.StatementBuilderContext.builder() - .isExplain(request.isExplainRequest()) - .build())); + boolean isExplainRequest = request.isExplainRequest(); + if (request.getCursor().isPresent()) { + // Handle v2 cursor here -- legacy cursor was handled earlier. + if (isExplainRequest) { + throw new UnsupportedOperationException("Explain of a paged query continuation " + + "is not supported. Use `explain` for the initial query request."); + } + if (request.isCursorCloseRequest()) { + return queryExecutionFactory.createCloseCursor(request.getCursor().get(), + queryListener.orElse(null)); + } + return queryExecutionFactory.create(request.getCursor().get(), + isExplainRequest, queryListener.orElse(null), explainListener.orElse(null)); + } else { + // 1.Parse query and convert parse tree (CST) to abstract syntax tree (AST) + ParseTree cst = parser.parse(request.getQuery()); + Statement statement = + cst.accept( + new AstStatementBuilder( + new AstBuilder(request.getQuery()), + AstStatementBuilder.StatementBuilderContext.builder() + .isExplain(isExplainRequest) + .fetchSize(request.getFetchSize()) + .build())); - return queryExecutionFactory.create(statement, queryListener, explainListener); + return queryExecutionFactory.create( + statement, queryListener, explainListener); + } } } diff --git a/sql/src/main/java/org/opensearch/sql/sql/domain/SQLQueryRequest.java b/sql/src/main/java/org/opensearch/sql/sql/domain/SQLQueryRequest.java index 508f80cee4..c9321f5775 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/domain/SQLQueryRequest.java +++ b/sql/src/main/java/org/opensearch/sql/sql/domain/SQLQueryRequest.java @@ -6,13 +6,12 @@ package org.opensearch.sql.sql.domain; -import com.google.common.base.Strings; -import com.google.common.collect.ImmutableSet; import java.util.Collections; import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Stream; import 
lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -28,9 +27,9 @@ @EqualsAndHashCode @RequiredArgsConstructor public class SQLQueryRequest { - - private static final Set SUPPORTED_FIELDS = ImmutableSet.of( - "query", "fetch_size", "parameters"); + private static final String QUERY_FIELD_CURSOR = "cursor"; + private static final Set SUPPORTED_FIELDS = Set.of( + "query", "fetch_size", "parameters", QUERY_FIELD_CURSOR); private static final String QUERY_PARAMS_FORMAT = "format"; private static final String QUERY_PARAMS_SANITIZE = "sanitize"; @@ -64,41 +63,59 @@ public class SQLQueryRequest { @Accessors(fluent = true) private boolean sanitize = true; + private String cursor; + /** * Constructor of SQLQueryRequest that passes request params. */ - public SQLQueryRequest( - JSONObject jsonContent, String query, String path, Map params) { + public SQLQueryRequest(JSONObject jsonContent, String query, String path, + Map params, String cursor) { this.jsonContent = jsonContent; this.query = query; this.path = path; this.params = params; this.format = getFormat(params); this.sanitize = shouldSanitize(params); + this.cursor = cursor; } /** * Pre-check if the request can be supported by meeting ALL the following criteria: * 1.Only supported fields present in request body, ex. "filter" and "cursor" are not supported - * 2.No fetch_size or "fetch_size=0". In other word, it's not a cursor request - * 3.Response format is default or can be supported. + * 2.Response format is default or can be supported. * - * @return true if supported. + * @return true if supported. 
*/ public boolean isSupported() { - return isOnlySupportedFieldInPayload() - && isFetchSizeZeroIfPresent() - && isSupportedFormat(); + var noCursor = !isCursor(); + var noQuery = query == null; + var noUnsupportedParams = params.isEmpty() + || (params.size() == 1 && params.containsKey(QUERY_PARAMS_FORMAT)); + var noContent = jsonContent == null || jsonContent.isEmpty(); + + return ((!noCursor && noQuery + && noUnsupportedParams && noContent) // if cursor is given, but other things + || (noCursor && !noQuery)) // or if cursor is not given, but query + && isOnlySupportedFieldInPayload() // and request has supported fields only + && isSupportedFormat(); // and request is in supported format + } + + private boolean isCursor() { + return cursor != null && !cursor.isEmpty(); } /** * Check if request is to explain rather than execute the query. - * @return true if it is a explain request + * @return true if it is an explain request */ public boolean isExplainRequest() { return path.endsWith("/_explain"); } + public boolean isCursorCloseRequest() { + return path.endsWith("/close"); + } + /** * Decide on the formatter by the requested format. 
*/ @@ -113,23 +130,23 @@ public Format format() { } private boolean isOnlySupportedFieldInPayload() { - return SUPPORTED_FIELDS.containsAll(jsonContent.keySet()); + return jsonContent == null || SUPPORTED_FIELDS.containsAll(jsonContent.keySet()); + } + + public Optional getCursor() { + return Optional.ofNullable(cursor); } - private boolean isFetchSizeZeroIfPresent() { - return (jsonContent.optInt("fetch_size") == 0); + public int getFetchSize() { + return jsonContent.optInt("fetch_size"); } private boolean isSupportedFormat() { - return Strings.isNullOrEmpty(format) || "jdbc".equalsIgnoreCase(format) - || "csv".equalsIgnoreCase(format) || "raw".equalsIgnoreCase(format); + return Stream.of("csv", "jdbc", "raw").anyMatch(format::equalsIgnoreCase); } private String getFormat(Map params) { - if (params.containsKey(QUERY_PARAMS_FORMAT)) { - return params.get(QUERY_PARAMS_FORMAT); - } - return "jdbc"; + return params.getOrDefault(QUERY_PARAMS_FORMAT, "jdbc"); } private boolean shouldSanitize(Map params) { diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstStatementBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstStatementBuilder.java index 40d549764a..593e7b51ff 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstStatementBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstStatementBuilder.java @@ -26,7 +26,7 @@ public class AstStatementBuilder extends OpenSearchSQLParserBaseVisitor { - ResponseListener listener = invocation.getArgument(1); - listener.onResponse(new QueryResponse(schema, Collections.emptyList())); - return null; - }).when(queryService).execute(any(), any()); - + public void can_execute_sql_query() { sqlService.execute( new SQLQueryRequest(new JSONObject(), "SELECT 123", QUERY, "jdbc"), - new ResponseListener() { + new ResponseListener<>() { @Override public void onResponse(QueryResponse response) { assertNotNull(response); @@ -84,13 +78,42 @@ public void onFailure(Exception e) { } @Test - 
public void canExecuteCsvFormatRequest() { - doAnswer(invocation -> { - ResponseListener listener = invocation.getArgument(1); - listener.onResponse(new QueryResponse(schema, Collections.emptyList())); - return null; - }).when(queryService).execute(any(), any()); + public void can_execute_cursor_query() { + sqlService.execute( + new SQLQueryRequest(new JSONObject(), null, QUERY, Map.of("format", "jdbc"), "n:cursor"), + new ResponseListener<>() { + @Override + public void onResponse(QueryResponse response) { + assertNotNull(response); + } + + @Override + public void onFailure(Exception e) { + fail(e); + } + }); + } + + @Test + public void can_execute_close_cursor_query() { + sqlService.execute( + new SQLQueryRequest(new JSONObject(), null, QUERY + "/close", + Map.of("format", "jdbc"), "n:cursor"), + new ResponseListener<>() { + @Override + public void onResponse(QueryResponse response) { + assertNotNull(response); + } + + @Override + public void onFailure(Exception e) { + fail(e); + } + }); + } + @Test + public void can_execute_csv_format_request() { sqlService.execute( new SQLQueryRequest(new JSONObject(), "SELECT 123", QUERY, "csv"), new ResponseListener() { @@ -107,7 +130,7 @@ public void onFailure(Exception e) { } @Test - public void canExplainSqlQuery() { + public void can_explain_sql_query() { doAnswer(invocation -> { ResponseListener listener = invocation.getArgument(1); listener.onResponse(new ExplainResponse(new ExplainResponseNode("Test"))); @@ -129,7 +152,25 @@ public void onFailure(Exception e) { } @Test - public void canCaptureErrorDuringExecution() { + public void cannot_explain_cursor_query() { + sqlService.explain(new SQLQueryRequest(new JSONObject(), null, EXPLAIN, + Map.of("format", "jdbc"), "n:cursor"), + new ResponseListener() { + @Override + public void onResponse(ExplainResponse response) { + fail(response.toString()); + } + + @Override + public void onFailure(Exception e) { + assertEquals("Explain of a paged query continuation is not 
supported." + + " Use `explain` for the initial query request.", e.getMessage()); + } + }); + } + + @Test + public void can_capture_error_during_execution() { sqlService.execute( new SQLQueryRequest(new JSONObject(), "SELECT", QUERY, ""), new ResponseListener() { @@ -146,7 +187,7 @@ public void onFailure(Exception e) { } @Test - public void canCaptureErrorDuringExplain() { + public void can_capture_error_during_explain() { sqlService.explain( new SQLQueryRequest(new JSONObject(), "SELECT", EXPLAIN, ""), new ResponseListener() { @@ -161,5 +202,4 @@ public void onFailure(Exception e) { } }); } - } diff --git a/sql/src/test/java/org/opensearch/sql/sql/domain/SQLQueryRequestTest.java b/sql/src/test/java/org/opensearch/sql/sql/domain/SQLQueryRequestTest.java index 52a1f534e9..1ffa4f0fa8 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/domain/SQLQueryRequestTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/domain/SQLQueryRequestTest.java @@ -6,36 +6,43 @@ package org.opensearch.sql.sql.domain; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; import java.util.Map; import org.json.JSONObject; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.opensearch.sql.protocol.response.format.Format; +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) public class SQLQueryRequestTest { @Test - public void shouldSupportQuery() { + public void should_support_query() { SQLQueryRequest request = SQLQueryRequestBuilder.request("SELECT 1").build(); assertTrue(request.isSupported()); } @Test - public void shouldSupportQueryWithJDBCFormat() { + 
public void should_support_query_with_JDBC_format() { SQLQueryRequest request = SQLQueryRequestBuilder.request("SELECT 1") .format("jdbc") .build(); - assertTrue(request.isSupported()); - assertEquals(request.format(), Format.JDBC); + assertAll( + () -> assertTrue(request.isSupported()), + () -> assertEquals(request.format(), Format.JDBC) + ); } @Test - public void shouldSupportQueryWithQueryFieldOnly() { + public void should_support_query_with_query_field_only() { SQLQueryRequest request = SQLQueryRequestBuilder.request("SELECT 1") .jsonContent("{\"query\": \"SELECT 1\"}") @@ -44,16 +51,32 @@ public void shouldSupportQueryWithQueryFieldOnly() { } @Test - public void shouldSupportQueryWithParameters() { - SQLQueryRequest request = + public void should_support_query_with_parameters() { + SQLQueryRequest requestWithContent = SQLQueryRequestBuilder.request("SELECT 1") .jsonContent("{\"query\": \"SELECT 1\", \"parameters\":[]}") .build(); - assertTrue(request.isSupported()); + SQLQueryRequest requestWithParams = + SQLQueryRequestBuilder.request("SELECT 1") + .params(Map.of("one", "two")) + .build(); + assertAll( + () -> assertTrue(requestWithContent.isSupported()), + () -> assertTrue(requestWithParams.isSupported()) + ); } @Test - public void shouldSupportQueryWithZeroFetchSize() { + public void should_support_query_without_parameters() { + SQLQueryRequest requestWithNoParams = + SQLQueryRequestBuilder.request("SELECT 1") + .params(Map.of()) + .build(); + assertTrue(requestWithNoParams.isSupported()); + } + + @Test + public void should_support_query_with_zero_fetch_size() { SQLQueryRequest request = SQLQueryRequestBuilder.request("SELECT 1") .jsonContent("{\"query\": \"SELECT 1\", \"fetch_size\": 0}") @@ -62,7 +85,7 @@ public void shouldSupportQueryWithZeroFetchSize() { } @Test - public void shouldSupportQueryWithParametersAndZeroFetchSize() { + public void should_support_query_with_parameters_and_zero_fetch_size() { SQLQueryRequest request = 
SQLQueryRequestBuilder.request("SELECT 1") .jsonContent("{\"query\": \"SELECT 1\", \"fetch_size\": 0, \"parameters\":[]}") @@ -71,70 +94,184 @@ public void shouldSupportQueryWithParametersAndZeroFetchSize() { } @Test - public void shouldSupportExplain() { + public void should_support_explain() { SQLQueryRequest explainRequest = SQLQueryRequestBuilder.request("SELECT 1") .path("_plugins/_sql/_explain") .build(); - assertTrue(explainRequest.isExplainRequest()); - assertTrue(explainRequest.isSupported()); + + assertAll( + () -> assertTrue(explainRequest.isExplainRequest()), + () -> assertTrue(explainRequest.isSupported()) + ); } @Test - public void shouldNotSupportCursorRequest() { + public void should_support_cursor_request() { SQLQueryRequest fetchSizeRequest = SQLQueryRequestBuilder.request("SELECT 1") .jsonContent("{\"query\": \"SELECT 1\", \"fetch_size\": 5}") .build(); - assertFalse(fetchSizeRequest.isSupported()); SQLQueryRequest cursorRequest = + SQLQueryRequestBuilder.request(null) + .cursor("abcdefgh...") + .build(); + + assertAll( + () -> assertTrue(fetchSizeRequest.isSupported()), + () -> assertTrue(cursorRequest.isSupported()) + ); + } + + @Test + public void should_support_cursor_close_request() { + SQLQueryRequest closeRequest = + SQLQueryRequestBuilder.request(null) + .cursor("pewpew") + .path("_plugins/_sql/close") + .build(); + + SQLQueryRequest emptyCloseRequest = + SQLQueryRequestBuilder.request(null) + .cursor("") + .path("_plugins/_sql/close") + .build(); + + SQLQueryRequest pagingRequest = + SQLQueryRequestBuilder.request(null) + .cursor("pewpew") + .build(); + + assertAll( + () -> assertTrue(closeRequest.isSupported()), + () -> assertTrue(closeRequest.isCursorCloseRequest()), + () -> assertTrue(pagingRequest.isSupported()), + () -> assertFalse(pagingRequest.isCursorCloseRequest()), + () -> assertFalse(emptyCloseRequest.isSupported()), + () -> assertTrue(emptyCloseRequest.isCursorCloseRequest()) + ); + } + + @Test + public void 
should_not_support_request_with_empty_cursor() { + SQLQueryRequest requestWithEmptyCursor = + SQLQueryRequestBuilder.request(null) + .cursor("") + .build(); + SQLQueryRequest requestWithNullCursor = + SQLQueryRequestBuilder.request(null) + .cursor(null) + .build(); + assertAll( + () -> assertFalse(requestWithEmptyCursor.isSupported()), + () -> assertFalse(requestWithNullCursor.isSupported()) + ); + } + + @Test + public void should_not_support_request_with_unknown_field() { + SQLQueryRequest request = + SQLQueryRequestBuilder.request("SELECT 1") + .jsonContent("{\"pewpew\": 42}") + .build(); + assertFalse(request.isSupported()); + } + + @Test + public void should_not_support_request_with_cursor_and_something_else() { + SQLQueryRequest requestWithQuery = SQLQueryRequestBuilder.request("SELECT 1") - .jsonContent("{\"cursor\": \"abcdefgh...\"}") + .cursor("n:12356") .build(); - assertFalse(cursorRequest.isSupported()); + SQLQueryRequest requestWithParams = + SQLQueryRequestBuilder.request(null) + .cursor("n:12356") + .params(Map.of("one", "two")) + .build(); + SQLQueryRequest requestWithParamsWithFormat = + SQLQueryRequestBuilder.request(null) + .cursor("n:12356") + .params(Map.of("format", "jdbc")) + .build(); + SQLQueryRequest requestWithParamsWithFormatAnd = + SQLQueryRequestBuilder.request(null) + .cursor("n:12356") + .params(Map.of("format", "jdbc", "something", "else")) + .build(); + SQLQueryRequest requestWithFetchSize = + SQLQueryRequestBuilder.request(null) + .cursor("n:12356") + .jsonContent("{\"fetch_size\": 5}") + .build(); + SQLQueryRequest requestWithNoParams = + SQLQueryRequestBuilder.request(null) + .cursor("n:12356") + .params(Map.of()) + .build(); + SQLQueryRequest requestWithNoContent = + SQLQueryRequestBuilder.request(null) + .cursor("n:12356") + .jsonContent("{}") + .build(); + assertAll( + () -> assertFalse(requestWithQuery.isSupported()), + () -> assertFalse(requestWithParams.isSupported()), + () -> 
assertFalse(requestWithFetchSize.isSupported()), + () -> assertTrue(requestWithNoParams.isSupported()), + () -> assertTrue(requestWithParamsWithFormat.isSupported()), + () -> assertFalse(requestWithParamsWithFormatAnd.isSupported()), + () -> assertTrue(requestWithNoContent.isSupported()) + ); } @Test - public void shouldUseJDBCFormatByDefault() { + public void should_use_JDBC_format_by_default() { SQLQueryRequest request = SQLQueryRequestBuilder.request("SELECT 1").params(ImmutableMap.of()).build(); assertEquals(request.format(), Format.JDBC); } @Test - public void shouldSupportCSVFormatAndSanitize() { + public void should_support_CSV_format_and_sanitize() { SQLQueryRequest csvRequest = SQLQueryRequestBuilder.request("SELECT 1") .format("csv") .build(); - assertTrue(csvRequest.isSupported()); - assertEquals(csvRequest.format(), Format.CSV); - assertTrue(csvRequest.sanitize()); + assertAll( + () -> assertTrue(csvRequest.isSupported()), + () -> assertEquals(csvRequest.format(), Format.CSV), + () -> assertTrue(csvRequest.sanitize()) + ); } @Test - public void shouldSkipSanitizeIfSetFalse() { + public void should_skip_sanitize_if_set_false() { ImmutableMap.Builder builder = ImmutableMap.builder(); Map params = builder.put("format", "csv").put("sanitize", "false").build(); SQLQueryRequest csvRequest = SQLQueryRequestBuilder.request("SELECT 1").params(params).build(); - assertEquals(csvRequest.format(), Format.CSV); - assertFalse(csvRequest.sanitize()); + assertAll( + () -> assertEquals(csvRequest.format(), Format.CSV), + () -> assertFalse(csvRequest.sanitize()) + ); } @Test - public void shouldNotSupportOtherFormat() { + public void should_not_support_other_format() { SQLQueryRequest csvRequest = SQLQueryRequestBuilder.request("SELECT 1") .format("other") .build(); - assertFalse(csvRequest.isSupported()); - assertThrows(IllegalArgumentException.class, csvRequest::format, - "response in other format is not supported."); + + assertAll( + () -> 
assertFalse(csvRequest.isSupported()), + () -> assertEquals("response in other format is not supported.", + assertThrows(IllegalArgumentException.class, csvRequest::format).getMessage()) + ); } @Test - public void shouldSupportRawFormat() { + public void should_support_raw_format() { SQLQueryRequest csvRequest = SQLQueryRequestBuilder.request("SELECT 1") .format("raw") @@ -150,7 +287,8 @@ private static class SQLQueryRequestBuilder { private String query; private String path = "_plugins/_sql"; private String format; - private Map params; + private String cursor; + private Map params = new HashMap<>(); static SQLQueryRequestBuilder request(String query) { SQLQueryRequestBuilder builder = new SQLQueryRequestBuilder(); @@ -178,14 +316,17 @@ SQLQueryRequestBuilder params(Map params) { return this; } + SQLQueryRequestBuilder cursor(String cursor) { + this.cursor = cursor; + return this; + } + SQLQueryRequest build() { - if (jsonContent == null) { - jsonContent = "{\"query\": \"" + query + "\"}"; - } - if (params != null) { - return new SQLQueryRequest(new JSONObject(jsonContent), query, path, params); + if (format != null) { + params.put("format", format); } - return new SQLQueryRequest(new JSONObject(jsonContent), query, path, format); + return new SQLQueryRequest(jsonContent == null ? null : new JSONObject(jsonContent), + query, path, params, cursor); } }