Skip to content

Commit

Permalink
Add trendline PPL command (#3071)
Browse files Browse the repository at this point in the history
* Add trendline (With SWA) PPL command

---------

Signed-off-by: James Duong <[email protected]>
Signed-off-by: Andrew Carbonetto <[email protected]>
Co-authored-by: Andrew Carbonetto <[email protected]>
  • Loading branch information
jduo and acarbonetto authored Dec 12, 2024
1 parent 3e2cb1d commit ed0ca8d
Show file tree
Hide file tree
Showing 33 changed files with 1,601 additions and 23 deletions.
94 changes: 77 additions & 17 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_LAST;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
import static org.opensearch.sql.data.type.ExprCoreType.DATE;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
import static org.opensearch.sql.data.type.ExprCoreType.TIME;
import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
Expand All @@ -22,6 +25,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -62,6 +66,7 @@
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
Expand Down Expand Up @@ -100,6 +105,7 @@
import org.opensearch.sql.planner.logical.LogicalRemove;
import org.opensearch.sql.planner.logical.LogicalRename;
import org.opensearch.sql.planner.logical.LogicalSort;
import org.opensearch.sql.planner.logical.LogicalTrendline;
import org.opensearch.sql.planner.logical.LogicalValues;
import org.opensearch.sql.planner.physical.datasource.DataSourceTable;
import org.opensearch.sql.storage.Table;
Expand Down Expand Up @@ -469,23 +475,7 @@ public LogicalPlan visitParse(Parse node, AnalysisContext context) {
@Override
public LogicalPlan visitSort(Sort node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
ExpressionReferenceOptimizer optimizer =
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);

List<Pair<SortOption, Expression>> sortList =
node.getSortList().stream()
.map(
sortField -> {
var analyzed = expressionAnalyzer.analyze(sortField.getField(), context);
if (analyzed == null) {
throw new UnsupportedOperationException(
String.format("Invalid use of expression %s", sortField.getField()));
}
Expression expression = optimizer.optimize(analyzed, context);
return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression);
})
.collect(Collectors.toList());
return new LogicalSort(child, sortList);
return buildSort(child, context, node.getSortList());
}

/** Build {@link LogicalDedupe}. */
Expand Down Expand Up @@ -594,6 +584,55 @@ public LogicalPlan visitML(ML node, AnalysisContext context) {
return new LogicalML(child, node.getArguments());
}

/** Build {@link LogicalTrendline} for Trendline command. */
@Override
public LogicalPlan visitTrendline(Trendline node, AnalysisContext context) {
final LogicalPlan child = node.getChild().get(0).accept(this, context);

final TypeEnvironment currEnv = context.peek();
final List<Trendline.TrendlineComputation> computations = node.getComputations();
final ImmutableList.Builder<Pair<Trendline.TrendlineComputation, ExprCoreType>>
computationsAndTypes = ImmutableList.builder();
computations.forEach(
computation -> {
final Expression resolvedField =
expressionAnalyzer.analyze(computation.getDataField(), context);
final ExprCoreType averageType;
// Duplicate the semantics of AvgAggregator#create():
// - All numerical types have the DOUBLE type for the moving average.
// - All datetime types have the same datetime type for the moving average.
if (ExprCoreType.numberTypes().contains(resolvedField.type())) {
averageType = ExprCoreType.DOUBLE;
} else {
switch (resolvedField.type()) {
case DATE:
case TIME:
case TIMESTAMP:
averageType = (ExprCoreType) resolvedField.type();
break;
default:
throw new SemanticCheckException(
String.format(
"Invalid field used for trendline computation %s. Source field %s had type"
+ " %s but must be a numerical or datetime field.",
computation.getAlias(),
computation.getDataField().getChild().get(0),
resolvedField.type().typeName()));
}
}
currEnv.define(new Symbol(Namespace.FIELD_NAME, computation.getAlias()), averageType);
computationsAndTypes.add(Pair.of(computation, averageType));
});

if (node.getSortByField().isEmpty()) {
return new LogicalTrendline(child, computationsAndTypes.build());
}

return new LogicalTrendline(
buildSort(child, context, Collections.singletonList(node.getSortByField().get())),
computationsAndTypes.build());
}

@Override
public LogicalPlan visitPaginate(Paginate paginate, AnalysisContext context) {
LogicalPlan child = paginate.getChild().get(0).accept(this, context);
Expand All @@ -612,6 +651,27 @@ public LogicalPlan visitCloseCursor(CloseCursor closeCursor, AnalysisContext con
return new LogicalCloseCursor(closeCursor.getChild().get(0).accept(this, context));
}

private LogicalSort buildSort(
LogicalPlan child, AnalysisContext context, List<Field> sortFields) {
ExpressionReferenceOptimizer optimizer =
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);

List<Pair<SortOption, Expression>> sortList =
sortFields.stream()
.map(
sortField -> {
var analyzed = expressionAnalyzer.analyze(sortField.getField(), context);
if (analyzed == null) {
throw new UnsupportedOperationException(
String.format("Invalid use of expression %s", sortField.getField()));
}
Expression expression = optimizer.optimize(analyzed, context);
return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression);
})
.collect(Collectors.toList());
return new LogicalSort(child, sortList);
}

/**
* The first argument is always "asc", others are optional. Given nullFirst argument, use its
* value. Otherwise just use DEFAULT_ASC/DESC.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import org.opensearch.sql.ast.tree.Rename;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.Values;

/** AST nodes visitor Defines the traverse path. */
Expand Down Expand Up @@ -110,6 +111,14 @@ public T visitFilter(Filter node, C context) {
return visitChildren(node, context);
}

public T visitTrendline(Trendline node, C context) {
return visitChildren(node, context);
}

public T visitTrendlineComputation(Trendline.TrendlineComputation node, C context) {
return visitChildren(node, context);
}

public T visitProject(Project node, C context) {
return visitChildren(node, context);
}
Expand Down
14 changes: 14 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
import org.apache.commons.lang3.tuple.ImmutablePair;
Expand Down Expand Up @@ -62,6 +63,7 @@
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Values;

Expand Down Expand Up @@ -466,6 +468,18 @@ public static Limit limit(UnresolvedPlan input, Integer limit, Integer offset) {
return new Limit(limit, offset).attach(input);
}

public static Trendline trendline(
UnresolvedPlan input,
Optional<Field> sortField,
Trendline.TrendlineComputation... computations) {
return new Trendline(sortField, Arrays.asList(computations)).attach(input);
}

public static Trendline.TrendlineComputation computation(
Integer numDataPoints, Field dataField, String alias, Trendline.TrendlineType type) {
return new Trendline.TrendlineComputation(numDataPoints, dataField, alias, type);
}

public static Parse parse(
UnresolvedPlan input,
ParseMethod parseMethod,
Expand Down
71 changes: 71 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Optional;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

@ToString
@Getter
@RequiredArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Trendline extends UnresolvedPlan {

private UnresolvedPlan child;
private final Optional<Field> sortByField;
private final List<TrendlineComputation> computations;

@Override
public Trendline attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public List<? extends Node> getChild() {
return ImmutableList.of(child);
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> visitor, C context) {
return visitor.visitTrendline(this, context);
}

@Getter
public static class TrendlineComputation extends UnresolvedExpression {

private final Integer numberOfDataPoints;
private final Field dataField;
private final String alias;
private final TrendlineType computationType;

public TrendlineComputation(
Integer numberOfDataPoints, Field dataField, String alias, TrendlineType computationType) {
this.numberOfDataPoints = numberOfDataPoints;
this.dataField = dataField;
this.alias = alias;
this.computationType = computationType;
}

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitTrendlineComputation(this, context);
}
}

public enum TrendlineType {
SMA
}
}
32 changes: 32 additions & 0 deletions core/src/main/java/org/opensearch/sql/executor/Explain.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse;
import org.opensearch.sql.executor.ExecutionEngine.ExplainResponseNode;
import org.opensearch.sql.expression.Expression;
Expand All @@ -31,6 +33,7 @@
import org.opensearch.sql.planner.physical.RenameOperator;
import org.opensearch.sql.planner.physical.SortOperator;
import org.opensearch.sql.planner.physical.TakeOrderedOperator;
import org.opensearch.sql.planner.physical.TrendlineOperator;
import org.opensearch.sql.planner.physical.ValuesOperator;
import org.opensearch.sql.planner.physical.WindowOperator;
import org.opensearch.sql.storage.TableScanOperator;
Expand Down Expand Up @@ -211,6 +214,21 @@ public ExplainResponseNode visitNested(NestedOperator node, Object context) {
explanNode -> explanNode.setDescription(ImmutableMap.of("nested", node.getFields())));
}

@Override
public ExplainResponseNode visitTrendline(TrendlineOperator node, Object context) {
return explain(
node,
context,
explainNode ->
explainNode.setDescription(
ImmutableMap.of(
"computations",
describeTrendlineComputations(
node.getComputations().stream()
.map(Pair::getKey)
.collect(Collectors.toList())))));
}

protected ExplainResponseNode explain(
PhysicalPlan node, Object context, Consumer<ExplainResponseNode> doExplain) {
ExplainResponseNode explainNode = new ExplainResponseNode(getOperatorName(node));
Expand Down Expand Up @@ -245,4 +263,18 @@ private Map<String, Map<String, String>> describeSortList(
"sortOrder", p.getLeft().getSortOrder().toString(),
"nullOrder", p.getLeft().getNullOrder().toString())));
}

private List<Map<String, String>> describeTrendlineComputations(
List<Trendline.TrendlineComputation> computations) {
return computations.stream()
.map(
computation ->
ImmutableMap.of(
"computationType",
computation.getComputationType().name().toLowerCase(Locale.ROOT),
"numberOfDataPoints", computation.getNumberOfDataPoints().toString(),
"dataField", computation.getDataField().getChild().get(0).toString(),
"alias", computation.getAlias()))
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.sql.planner.logical.LogicalRemove;
import org.opensearch.sql.planner.logical.LogicalRename;
import org.opensearch.sql.planner.logical.LogicalSort;
import org.opensearch.sql.planner.logical.LogicalTrendline;
import org.opensearch.sql.planner.logical.LogicalValues;
import org.opensearch.sql.planner.logical.LogicalWindow;
import org.opensearch.sql.planner.physical.AggregationOperator;
Expand All @@ -39,6 +40,7 @@
import org.opensearch.sql.planner.physical.RenameOperator;
import org.opensearch.sql.planner.physical.SortOperator;
import org.opensearch.sql.planner.physical.TakeOrderedOperator;
import org.opensearch.sql.planner.physical.TrendlineOperator;
import org.opensearch.sql.planner.physical.ValuesOperator;
import org.opensearch.sql.planner.physical.WindowOperator;
import org.opensearch.sql.storage.read.TableScanBuilder;
Expand Down Expand Up @@ -166,6 +168,11 @@ public PhysicalPlan visitCloseCursor(LogicalCloseCursor node, C context) {
return new CursorCloseOperator(visitChild(node, context));
}

@Override
public PhysicalPlan visitTrendline(LogicalTrendline plan, C context) {
return new TrendlineOperator(visitChild(plan, context), plan.getComputations());
}

// Called when paging query requested without `FROM` clause only
@Override
public PhysicalPlan visitPaginate(LogicalPaginate plan, C context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.tree.RareTopN.CommandType;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.NamedExpression;
Expand Down Expand Up @@ -130,6 +132,11 @@ public static LogicalPlan rareTopN(
return new LogicalRareTopN(input, commandType, noOfResults, Arrays.asList(fields), groupByList);
}

public static LogicalTrendline trendline(
LogicalPlan input, Pair<Trendline.TrendlineComputation, ExprCoreType>... computations) {
return new LogicalTrendline(input, Arrays.asList(computations));
}

@SafeVarargs
public LogicalPlan values(List<LiteralExpression>... values) {
return new LogicalValues(Arrays.asList(values));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ public R visitAD(LogicalAD plan, C context) {
return visitNode(plan, context);
}

public R visitTrendline(LogicalTrendline plan, C context) {
return visitNode(plan, context);
}

public R visitPaginate(LogicalPaginate plan, C context) {
return visitNode(plan, context);
}
Expand Down
Loading

0 comments on commit ed0ca8d

Please sign in to comment.