Skip to content

Commit

Permalink
Fixing spacing around headers in ExpressionReferenceOptimizer.java Se…
Browse files Browse the repository at this point in the history
…lectExpressionAnalyzer.java

Signed-off-by: Mitchell Gale <[email protected]>
  • Loading branch information
MitchellGale committed Jul 31, 2023
1 parent 7469139 commit a43d97d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.analysis;

import java.util.HashMap;
Expand All @@ -26,8 +25,8 @@
import org.opensearch.sql.planner.logical.LogicalWindow;

/**
* The optimizer used to replace the expression referred in the SelectClause</br>
* e.g. The query SELECT abs(name), sum(age)-avg(age) FROM test GROUP BY abs(name).<br>
* The optimizer used to replace the expression referred in the SelectClause</br> e.g. The query
* SELECT abs(name), sum(age)-avg(age) FROM test GROUP BY abs(name).<br>
* will be translated the AST<br>
* Project[abs(age), sub(sum(age), avg(age))<br>
* &ensp Agg(agg=[sum(age), avg(age)], group=[abs(age)]]<br>
Expand All @@ -43,8 +42,8 @@ public class ExpressionReferenceOptimizer
private final BuiltinFunctionRepository repository;

/**
* The map of expression and it's reference.
* For example, The NamedAggregator should produce the map of Aggregator to Ref(name)
* The map of expression and it's reference. For example, The NamedAggregator should produce the
* map of Aggregator to Ref(name)
*/
private final Map<Expression, Expression> expressionMap = new HashMap<>();

Expand All @@ -69,17 +68,16 @@ public Expression visitFunction(FunctionExpression node, AnalysisContext context
return expressionMap.get(node);
} else {
final List<Expression> args =
node.getArguments().stream().map(expr -> expr.accept(this, context))
node.getArguments().stream()
.map(expr -> expr.accept(this, context))
.collect(Collectors.toList());
Expression optimizedFunctionExpression = (Expression) repository.compile(
context.getFunctionProperties(),
node.getFunctionName(),
args
);
Expression optimizedFunctionExpression =
(Expression)
repository.compile(context.getFunctionProperties(), node.getFunctionName(), args);
// Propagate scoreTracked for OpenSearch functions
if (optimizedFunctionExpression instanceof OpenSearchFunctions.OpenSearchFunction) {
((OpenSearchFunctions.OpenSearchFunction) optimizedFunctionExpression).setScoreTracked(
((OpenSearchFunctions.OpenSearchFunction)node).isScoreTracked());
((OpenSearchFunctions.OpenSearchFunction) optimizedFunctionExpression)
.setScoreTracked(((OpenSearchFunctions.OpenSearchFunction) node).isScoreTracked());
}
return optimizedFunctionExpression;
}
Expand All @@ -98,19 +96,17 @@ public Expression visitNamed(NamedExpression node, AnalysisContext context) {
return node.getDelegated().accept(this, context);
}

/**
* Implement this because Case/When is not registered in function repository.
*/
/** Implement this because Case/When is not registered in function repository. */
@Override
public Expression visitCase(CaseClause node, AnalysisContext context) {
if (expressionMap.containsKey(node)) {
return expressionMap.get(node);
}

List<WhenClause> whenClauses = node.getWhenClauses()
.stream()
.map(expr -> (WhenClause) expr.accept(this, context))
.collect(Collectors.toList());
List<WhenClause> whenClauses =
node.getWhenClauses().stream()
.map(expr -> (WhenClause) expr.accept(this, context))
.collect(Collectors.toList());
Expression defaultResult = null;
if (node.getDefaultResult() != null) {
defaultResult = node.getDefaultResult().accept(this, context);
Expand All @@ -121,14 +117,10 @@ public Expression visitCase(CaseClause node, AnalysisContext context) {
@Override
public Expression visitWhen(WhenClause node, AnalysisContext context) {
return new WhenClause(
node.getCondition().accept(this, context),
node.getResult().accept(this, context));
node.getCondition().accept(this, context), node.getResult().accept(this, context));
}


/**
* Expression Map Builder.
*/
/** Expression Map Builder. */
class ExpressionMapBuilder extends LogicalPlanNodeVisitor<Void, Void> {

@Override
Expand All @@ -140,20 +132,27 @@ public Void visitNode(LogicalPlan plan, Void context) {
@Override
public Void visitAggregation(LogicalAggregation plan, Void context) {
// Create the mapping for all the aggregator.
plan.getAggregatorList().forEach(namedAggregator -> expressionMap
.put(namedAggregator.getDelegated(),
new ReferenceExpression(namedAggregator.getName(), namedAggregator.type())));
plan.getAggregatorList()
.forEach(
namedAggregator ->
expressionMap.put(
namedAggregator.getDelegated(),
new ReferenceExpression(namedAggregator.getName(), namedAggregator.type())));
// Create the mapping for all the group by.
plan.getGroupByList().forEach(groupBy -> expressionMap
.put(groupBy.getDelegated(),
new ReferenceExpression(groupBy.getNameOrAlias(), groupBy.type())));
plan.getGroupByList()
.forEach(
groupBy ->
expressionMap.put(
groupBy.getDelegated(),
new ReferenceExpression(groupBy.getNameOrAlias(), groupBy.type())));
return null;
}

@Override
public Void visitWindow(LogicalWindow plan, Void context) {
Expression windowFunc = plan.getWindowFunction();
expressionMap.put(windowFunc,
expressionMap.put(
windowFunc,
new ReferenceExpression(((NamedExpression) windowFunc).getName(), windowFunc.type()));
return visitNode(plan, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.analysis;

import com.google.common.collect.ImmutableList;
Expand All @@ -30,23 +29,21 @@
import org.opensearch.sql.expression.ReferenceExpression;

/**
* Analyze the select list in the {@link AnalysisContext} to construct the list of
* {@link NamedExpression}.
* Analyze the select list in the {@link AnalysisContext} to construct the list of {@link
* NamedExpression}.
*/
@RequiredArgsConstructor
public class SelectExpressionAnalyzer
extends
AbstractNodeVisitor<List<NamedExpression>, AnalysisContext> {
extends AbstractNodeVisitor<List<NamedExpression>, AnalysisContext> {
private final ExpressionAnalyzer expressionAnalyzer;

private ExpressionReferenceOptimizer optimizer;

/**
* Analyze Select fields.
*/
public List<NamedExpression> analyze(List<UnresolvedExpression> selectList,
AnalysisContext analysisContext,
ExpressionReferenceOptimizer optimizer) {
/** Analyze Select fields. */
public List<NamedExpression> analyze(
List<UnresolvedExpression> selectList,
AnalysisContext analysisContext,
ExpressionReferenceOptimizer optimizer) {
this.optimizer = optimizer;
ImmutableList.Builder<NamedExpression> builder = new ImmutableList.Builder<>();
for (UnresolvedExpression unresolvedExpression : selectList) {
Expand All @@ -68,10 +65,8 @@ public List<NamedExpression> visitAlias(Alias node, AnalysisContext context) {
}

Expression expr = referenceIfSymbolDefined(node, context);
return Collections.singletonList(DSL.named(
unqualifiedNameIfFieldOnly(node, context),
expr,
node.getAlias()));
return Collections.singletonList(
DSL.named(unqualifiedNameIfFieldOnly(node, context), expr, node.getAlias()));
}

/**
Expand All @@ -86,32 +81,32 @@ public List<NamedExpression> visitAlias(Alias node, AnalysisContext context) {
* Project(Alias("name", expr, l), Alias("AVG(age)", aggExpr))<br>
* Agg(Alias("AVG(age)", aggExpr), Alias("length(name)", groupExpr))<br>
*/
private Expression referenceIfSymbolDefined(Alias expr,
AnalysisContext context) {
private Expression referenceIfSymbolDefined(Alias expr, AnalysisContext context) {
UnresolvedExpression delegatedExpr = expr.getDelegated();

// Pass named expression because expression like window function loses full name
// (OVER clause) and thus depends on name in alias to be replaced correctly
return optimizer.optimize(
DSL.named(
expr.getName(),
delegatedExpr.accept(expressionAnalyzer, context),
expr.getAlias()),
expr.getName(), delegatedExpr.accept(expressionAnalyzer, context), expr.getAlias()),
context);
}

@Override
public List<NamedExpression> visitAllFields(AllFields node,
AnalysisContext context) {
public List<NamedExpression> visitAllFields(AllFields node, AnalysisContext context) {
TypeEnvironment environment = context.peek();
Map<String, ExprType> lookupAllFields = environment.lookupAllFields(Namespace.FIELD_NAME);
return lookupAllFields.entrySet().stream().map(entry -> DSL.named(entry.getKey(),
new ReferenceExpression(entry.getKey(), entry.getValue()))).collect(Collectors.toList());
return lookupAllFields.entrySet().stream()
.map(
entry ->
DSL.named(
entry.getKey(), new ReferenceExpression(entry.getKey(), entry.getValue())))
.collect(Collectors.toList());
}

@Override
public List<NamedExpression> visitNestedAllTupleFields(NestedAllTupleFields node,
AnalysisContext context) {
public List<NamedExpression> visitNestedAllTupleFields(
NestedAllTupleFields node, AnalysisContext context) {
TypeEnvironment environment = context.peek();
Map<String, ExprType> lookupAllTupleFields =
environment.lookupAllTupleFields(Namespace.FIELD_NAME);
Expand All @@ -121,14 +116,15 @@ public List<NamedExpression> visitNestedAllTupleFields(NestedAllTupleFields node
Pattern p = Pattern.compile(node.getPath() + "\\.[^\\.]+$");
return lookupAllTupleFields.entrySet().stream()
.filter(field -> p.matcher(field.getKey()).find())
.map(entry -> {
Expression nestedFunc = new Function(
"nested",
List.of(
new QualifiedName(List.of(entry.getKey().split("\\."))))
).accept(expressionAnalyzer, context);
return DSL.named("nested(" + entry.getKey() + ")", nestedFunc);
})
.map(
entry -> {
Expression nestedFunc =
new Function(
"nested",
List.of(new QualifiedName(List.of(entry.getKey().split("\\.")))))
.accept(expressionAnalyzer, context);
return DSL.named("nested(" + entry.getKey() + ")", nestedFunc);
})
.collect(Collectors.toList());
}

Expand All @@ -149,5 +145,4 @@ private String unqualifiedNameIfFieldOnly(Alias node, AnalysisContext context) {
}
return node.getName();
}

}

0 comments on commit a43d97d

Please sign in to comment.