Skip to content

Commit

Permalink
Physical plan tests
Browse files Browse the repository at this point in the history
Signed-off-by: James Duong <[email protected]>
  • Loading branch information
jduo committed Oct 25, 2024
1 parent f8c1b2e commit 8c8cb3e
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.sql.planner.physical;

import com.google.common.base.Preconditions;
import com.google.common.collect.EvictingQueue;
import com.google.common.collect.ImmutableMap.Builder;
import java.util.Collections;
Expand All @@ -17,7 +16,6 @@
import lombok.ToString;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.data.model.ExprIntegerValue;
import org.opensearch.sql.data.model.ExprNullValue;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
Expand All @@ -30,8 +28,8 @@
public class TrendlineOperator extends PhysicalPlan {
@Getter private final PhysicalPlan input;
@Getter private final List<Trendline.TrendlineComputation> computations;
private final List<TrendlineAccumulator> accumulators;
private final Map<String, Integer> fieldToIndexMap;
@EqualsAndHashCode.Exclude private final List<TrendlineAccumulator> accumulators;
@EqualsAndHashCode.Exclude private final Map<String, Integer> fieldToIndexMap;

public TrendlineOperator(PhysicalPlan input, List<Trendline.TrendlineComputation> computations) {
this.input = input;
Expand Down Expand Up @@ -61,7 +59,6 @@ public boolean hasNext() {

@Override
public ExprValue next() {
Preconditions.checkState(hasNext());
final ExprValue result;
final ExprValue next = input.next();
consumeInputTuple(next);
Expand All @@ -72,11 +69,13 @@ public ExprValue next() {
// Add calculated trendline values, which might overwrite existing fields from the input.
for (int i = 0; i < accumulators.size(); ++i) {
final ExprValue calculateResult = accumulators.get(i).calculate();
if (null != computations.get(i).getAlias()) {
mapBuilder.put(computations.get(i).getAlias(), calculateResult);
} else {
mapBuilder.put(
computations.get(i).getDataField().getChild().get(0).toString(), calculateResult);
if (null != calculateResult) {
if (null != computations.get(i).getAlias()) {
mapBuilder.put(computations.get(i).getAlias(), calculateResult);
} else {
mapBuilder.put(
computations.get(i).getDataField().getChild().get(0).toString(), calculateResult);
}
}
}
result = ExprTupleValue.fromExprValueMap(mapBuilder.buildKeepingLast());
Expand Down Expand Up @@ -172,7 +171,7 @@ public void accumulate(ExprValue value) {
@Override
public ExprValue calculate() {
if (receivedValues.size() < dataPointsNeeded.integerValue()) {
return ExprNullValue.of();
return null;
}
return runningAverage;
}
Expand Down
32 changes: 32 additions & 0 deletions core/src/test/java/org/opensearch/sql/executor/ExplainTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.values;
import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.window;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -39,6 +40,7 @@
import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse;
Expand All @@ -52,8 +54,11 @@
import org.opensearch.sql.expression.aggregation.NamedAggregator;
import org.opensearch.sql.expression.window.WindowDefinition;
import org.opensearch.sql.planner.physical.PhysicalPlan;
import org.opensearch.sql.planner.physical.TrendlineOperator;
import org.opensearch.sql.storage.TableScanOperator;

import com.google.common.collect.ImmutableMap;

@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
class ExplainTest extends ExpressionTestBase {

Expand Down Expand Up @@ -256,6 +261,33 @@ void can_explain_nested() {
explain.apply(plan));
}

@Test
void can_explain_trendline() {
PhysicalPlan plan = new TrendlineOperator(tableScan, Arrays.asList(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"),
AstDSL.computation(3, AstDSL.field("time"), "time_alias", "sma")));
assertEquals(
new ExplainResponse(
new ExplainResponseNode(
"TrendlineOperator",
ImmutableMap.of("computations", List.of(
ImmutableMap.of(
"computationType",
"sma",
"numberOfDataPoints", 2,
"dataField", "distance",
"alias", "distance_alias"),
ImmutableMap.of(
"computationType",
"sma",
"numberOfDataPoints", 3,
"dataField", "time",
"alias", "time_alias"))),
singletonList(tableScan.explainNode()))),
explain.apply(plan));
}


private static class FakeTableScan extends TableScanOperator {
@Override
public boolean hasNext() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.planner.physical;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.Collections;

import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.data.model.ExprValueUtils;

import com.google.common.collect.ImmutableMap;

@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
@ExtendWith(MockitoExtension.class)
public class TrendlineOperatorTest {
@Mock private PhysicalPlan inputPlan;

@Test
public void calculates_simple_moving_average_one_field_one_sample() {
when(inputPlan.hasNext()).thenReturn(true, false);
when(inputPlan.next())
.thenReturn(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)));

var plan = new TrendlineOperator(inputPlan, Collections.singletonList(AstDSL.computation(
1, AstDSL.field("distance"), "distance_alias", "sma")));

plan.open();
assertTrue(plan.hasNext());
assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), plan.next());
}

@Test
public void calculates_simple_moving_average_one_field_two_samples() {
when(inputPlan.hasNext()).thenReturn(true, true, false);
when(inputPlan.next())
.thenReturn(
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)));


var plan = new TrendlineOperator(inputPlan, Collections.singletonList(AstDSL.computation(
2, AstDSL.field("distance"), "distance_alias", "sma")));

plan.open();
assertTrue(plan.hasNext());
assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next());
assertTrue(plan.hasNext());
assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)));
assertFalse(plan.hasNext());
}

@Test
public void calculates_simple_moving_average_one_field_two_samples_three_rows() {
when(inputPlan.hasNext()).thenReturn(true, true, true, false);
when(inputPlan.next())
.thenReturn(
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)));

var plan = new TrendlineOperator(inputPlan, Collections.singletonList(AstDSL.computation(
2, AstDSL.field("distance"), "distance_alias", "sma")));

plan.open();
assertTrue(plan.hasNext());
assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next());
assertTrue(plan.hasNext());
assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)));
assertTrue(plan.hasNext());
assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)));
assertFalse(plan.hasNext());
}

@Test
public void calculates_simple_moving_average_multiple_computations() {
when(inputPlan.hasNext()).thenReturn(true, true, true, false);
when(inputPlan.next())
.thenReturn(
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20)));

var plan = new TrendlineOperator(inputPlan, Arrays.asList(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"),
AstDSL.computation(2, AstDSL.field("time"), "time_alias", "sma")));

plan.open();
assertTrue(plan.hasNext());
assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next());
assertTrue(plan.hasNext());
assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0)));
assertTrue(plan.hasNext());
assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20, "distance_alias", 200.0, "time_alias", 20.0)));
assertFalse(plan.hasNext());
}

public void alias_overwrites_input_field() {
when(inputPlan.hasNext()).thenReturn(true, true, true, false);
when(inputPlan.next())
.thenReturn(
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)));

var plan = new TrendlineOperator(inputPlan, Collections.singletonList(AstDSL.computation(
2, AstDSL.field("distance"), "time", "sma")));

plan.open();
assertTrue(plan.hasNext());
assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 100)), plan.next());
assertTrue(plan.hasNext());
assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 150.0)));
assertTrue(plan.hasNext());
assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 200.0)));
assertFalse(plan.hasNext());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ public String visitHead(Head node, String context) {
@Override
public String visitTrendline(Trendline node, String context) {
String child = node.getChild().get(0).accept(this, context);
String computations = visitExpressionList(node.getComputations());
String computations = visitExpressionList(node.getComputations(), " ");
return StringUtils.format("%s | trendline %s", child, computations);
}

Expand All @@ -234,9 +234,13 @@ private String visitFieldList(List<Field> fieldList) {
}

private String visitExpressionList(List<UnresolvedExpression> expressionList) {
return visitExpressionList(expressionList, ",");
}

private String visitExpressionList(List<UnresolvedExpression> expressionList, String delimiter) {
return expressionList.isEmpty()
? ""
: expressionList.stream().map(this::visitExpression).collect(Collectors.joining(","));
: expressionList.stream().map(this::visitExpression).collect(Collectors.joining(delimiter));
}

private String visitExpression(UnresolvedExpression expression) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ public void testDedupCommand() {
anonymize("source=t | dedup f1, f2"));
}

@Test
public void testTrendlineCommand() {
assertEquals(
"source=t | trendline sma(2, date) as date_alias sma(3, time) as time_alias",
anonymize("source=t | trendline sma(2, date) as date_alias sma(3, time) as time_alias"));
}

@Test
public void testHeadCommandWithNumber() {
assertEquals("source=t | head 3", anonymize("source=t | head 3"));
Expand Down

0 comments on commit 8c8cb3e

Please sign in to comment.