Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#639: allow metadata fields and score opensearch function (#228) #1456

Merged
merged 9 commits into from
Apr 10, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.data.model.ExprMissingValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.exception.SemanticCheckException;
Expand Down Expand Up @@ -157,6 +158,9 @@ public LogicalPlan visitRelation(Relation node, AnalysisContext context) {
dataSourceSchemaIdentifierNameResolver.getIdentifierName());
}
table.getFieldTypes().forEach((k, v) -> curEnv.define(new Symbol(Namespace.FIELD_NAME, k), v));
table.getReservedFieldTypes().forEach(
(k, v) -> curEnv.addReservedWord(new Symbol(Namespace.FIELD_NAME, k), v)
);

// Put index name or its alias in index namespace on type environment so qualifier
// can be removed when analyzing qualified name. The value (expr type) here doesn't matter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

import static org.opensearch.sql.ast.dsl.AstDSL.and;
import static org.opensearch.sql.ast.dsl.AstDSL.compare;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.GTE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.LTE;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand All @@ -31,6 +29,7 @@
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Function;
Expand All @@ -42,6 +41,7 @@
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.RelevanceFieldList;
import org.opensearch.sql.ast.expression.ScoreFunction;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.UnresolvedArgument;
import org.opensearch.sql.ast.expression.UnresolvedAttribute;
Expand All @@ -51,6 +51,7 @@
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
Expand All @@ -67,6 +68,7 @@
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.OpenSearchFunctions;
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.expression.span.SpanExpression;
import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction;
Expand Down Expand Up @@ -207,6 +209,65 @@ public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext
return new HighlightExpression(expr);
}

/**
* visitScoreFunction removes the score function from the AST and replaces it with the child
* relevance function node. If the optional boost variable is provided, the boost argument
* of the relevance function is combined.
*
* @param node score function node
* @param context analysis context for the query
* @return resolved relevance function
*/
public Expression visitScoreFunction(ScoreFunction node, AnalysisContext context) {
penghuo marked this conversation as resolved.
Show resolved Hide resolved
Literal boostArg = node.getRelevanceFieldWeight();
if (!boostArg.getType().equals(DataType.DOUBLE)) {
throw new SemanticCheckException(String.format("Expected boost type '%s' but got '%s'",
DataType.DOUBLE.name(), boostArg.getType().name()));
}
Double thisBoostValue = ((Double) boostArg.getValue());

// update the existing unresolved expression to add a boost argument if it doesn't exist
// OR multiply the existing boost argument
Function relevanceQueryUnresolvedExpr = (Function) node.getRelevanceQuery();
List<UnresolvedExpression> relevanceFuncArgs = relevanceQueryUnresolvedExpr.getFuncArgs();

boolean doesFunctionContainBoostArgument = false;
List<UnresolvedExpression> updatedFuncArgs = new ArrayList<>();
for (UnresolvedExpression expr : relevanceFuncArgs) {
String argumentName = ((UnresolvedArgument) expr).getArgName();
if (argumentName.equalsIgnoreCase("boost")) {
doesFunctionContainBoostArgument = true;
Literal boostArgLiteral = (Literal) ((UnresolvedArgument) expr).getValue();
Double boostValue =
Double.parseDouble((String) boostArgLiteral.getValue()) * thisBoostValue;
UnresolvedArgument newBoostArg = new UnresolvedArgument(
argumentName,
new Literal(boostValue.toString(), DataType.STRING)
);
updatedFuncArgs.add(newBoostArg);
} else {
updatedFuncArgs.add(expr);
}
}

// since nothing was found, add an argument
if (!doesFunctionContainBoostArgument) {
UnresolvedArgument newBoostArg = new UnresolvedArgument(
"boost", new Literal(Double.toString(thisBoostValue), DataType.STRING));
updatedFuncArgs.add(newBoostArg);
}

// create a new function expression with boost argument and resolve it
Function updatedRelevanceQueryUnresolvedExpr = new Function(
relevanceQueryUnresolvedExpr.getFuncName(),
updatedFuncArgs);
OpenSearchFunctions.OpenSearchFunction relevanceQueryExpr =
(OpenSearchFunctions.OpenSearchFunction) updatedRelevanceQueryUnresolvedExpr
.accept(this, context);
relevanceQueryExpr.setScoreTracked(true);
return relevanceQueryExpr;
}

@Override
public Expression visitIn(In node, AnalysisContext context) {
return visitIn(node.getField(), node.getValueList(), context);
Expand Down Expand Up @@ -297,6 +358,20 @@ public Expression visitAllFields(AllFields node, AnalysisContext context) {
@Override
public Expression visitQualifiedName(QualifiedName node, AnalysisContext context) {
QualifierAnalyzer qualifierAnalyzer = new QualifierAnalyzer(context);

// check for reserved words in the identifier
TypeEnvironment typeEnv = context.peek();
penghuo marked this conversation as resolved.
Show resolved Hide resolved
for (String part : node.getParts()) {
Optional<ExprType> exprType = typeEnv.getReservedSymbolTable().lookup(
new Symbol(Namespace.FIELD_NAME, part));
if (exprType.isPresent()) {
return visitMetadata(
qualifierAnalyzer.unqualified(node),
(ExprCoreType) exprType.get(),
context
);
}
}
return visitIdentifier(qualifierAnalyzer.unqualified(node), context);
}

Expand All @@ -313,6 +388,19 @@ public Expression visitUnresolvedArgument(UnresolvedArgument node, AnalysisConte
return new NamedArgumentExpression(node.getArgName(), node.getValue().accept(this, context));
}

/**
* If QualifiedName is actually a reserved metadata field, return the expr type associated
* with the metadata field.
* @param ident metadata field name
* @param context analysis context
* @return DSL reference
*/
private Expression visitMetadata(String ident,
ExprCoreType exprCoreType,
AnalysisContext context) {
return DSL.ref(ident, exprCoreType);
}

private Expression visitIdentifier(String ident, AnalysisContext context) {
// ParseExpression will always override ReferenceExpression when ident conflicts
for (NamedExpression expr : context.getNamedParseExpressions()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.sql.expression.conditional.cases.CaseClause;
import org.opensearch.sql.expression.conditional.cases.WhenClause;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.OpenSearchFunctions;
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor;
Expand Down Expand Up @@ -70,8 +71,17 @@ public Expression visitFunction(FunctionExpression node, AnalysisContext context
final List<Expression> args =
node.getArguments().stream().map(expr -> expr.accept(this, context))
.collect(Collectors.toList());
return (Expression) repository.compile(context.getFunctionProperties(),
node.getFunctionName(), args);
Expression optimizedFunctionExpression = (Expression) repository.compile(
context.getFunctionProperties(),
node.getFunctionName(),
args
);
// Propagate scoreTracked for OpenSearch functions
if (optimizedFunctionExpression instanceof OpenSearchFunctions.OpenSearchFunction) {
((OpenSearchFunctions.OpenSearchFunction) optimizedFunctionExpression).setScoreTracked(
((OpenSearchFunctions.OpenSearchFunction)node).isScoreTracked());
}
dai-chen marked this conversation as resolved.
Show resolved Hide resolved
return optimizedFunctionExpression;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,30 @@ public class TypeEnvironment implements Environment<Symbol, ExprType> {
private final TypeEnvironment parent;
private final SymbolTable symbolTable;

@Getter
private final SymbolTable reservedSymbolTable;
penghuo marked this conversation as resolved.
Show resolved Hide resolved

/**
* Constructor with empty symbol tables.
*
* @param parent parent environment
*/
public TypeEnvironment(TypeEnvironment parent) {
this.parent = parent;
this.symbolTable = new SymbolTable();
this.reservedSymbolTable = new SymbolTable();
}

/**
* Constructor with empty reserved symbol table.
*
* @param parent parent environment
* @param symbolTable type table
*/
public TypeEnvironment(TypeEnvironment parent, SymbolTable symbolTable) {
this.parent = parent;
this.symbolTable = symbolTable;
this.reservedSymbolTable = new SymbolTable();
}

/**
Expand All @@ -59,6 +75,7 @@ public ExprType resolve(Symbol symbol) {

/**
* Resolve all fields in the current environment.
*
* @param namespace a namespace
* @return all symbols in the namespace
*/
Expand Down Expand Up @@ -102,7 +119,11 @@ public void remove(ReferenceExpression ref) {
* Clear all fields in the current environment.
*/
public void clearAllFields() {
lookupAllFields(FIELD_NAME).keySet().stream()
.forEach(v -> remove(new Symbol(Namespace.FIELD_NAME, v)));
lookupAllFields(FIELD_NAME).keySet().forEach(
v -> remove(new Symbol(Namespace.FIELD_NAME, v)));
}

public void addReservedWord(Symbol symbol, ExprType type) {
reservedSymbolTable.store(symbol, type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.RelevanceFieldList;
import org.opensearch.sql.ast.expression.ScoreFunction;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.UnresolvedArgument;
import org.opensearch.sql.ast.expression.UnresolvedAttribute;
Expand Down Expand Up @@ -278,6 +279,10 @@ public T visitHighlightFunction(HighlightFunction node, C context) {
return visitChildren(node, context);
}

public T visitScoreFunction(ScoreFunction node, C context) {
return visitChildren(node, context);
}

public T visitStatement(Statement node, C context) {
return visit(node, context);
}
Expand Down
7 changes: 6 additions & 1 deletion core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.ParseMethod;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.ScoreFunction;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedArgument;
Expand All @@ -60,7 +61,6 @@
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

/**
* Class of static methods to create specific node instances.
Expand Down Expand Up @@ -285,6 +285,11 @@ public UnresolvedExpression highlight(UnresolvedExpression fieldName,
return new HighlightFunction(fieldName, arguments);
}

public UnresolvedExpression score(UnresolvedExpression relevanceQuery,
Literal relevanceFieldWeight) {
return new ScoreFunction(relevanceQuery, relevanceFieldWeight);
}

public UnresolvedExpression window(UnresolvedExpression function,
List<UnresolvedExpression> partitionByList,
List<Pair<SortOption, UnresolvedExpression>> sortList) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.List;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
* Expression node of Score function.
* Score takes a relevance-search expression as an argument and returns it
*/
@AllArgsConstructor
@EqualsAndHashCode(callSuper = false)
@Getter
@ToString
public class ScoreFunction extends UnresolvedExpression {
private final UnresolvedExpression relevanceQuery;
private final Literal relevanceFieldWeight;

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitScoreFunction(this, context);
}

@Override
public List<UnresolvedExpression> getChild() {
return List.of(relevanceQuery);
}
}
14 changes: 13 additions & 1 deletion core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,19 @@ public static FunctionExpression match_bool_prefix(Expression... args) {
}

public static FunctionExpression wildcard_query(Expression... args) {
return compile(FunctionProperties.None,BuiltinFunctionName.WILDCARD_QUERY, args);
return compile(FunctionProperties.None, BuiltinFunctionName.WILDCARD_QUERY, args);
}

public static FunctionExpression score(Expression... args) {
return compile(FunctionProperties.None, BuiltinFunctionName.SCORE, args);
}

public static FunctionExpression scorequery(Expression... args) {
return compile(FunctionProperties.None, BuiltinFunctionName.SCOREQUERY, args);
}

public static FunctionExpression score_query(Expression... args) {
return compile(FunctionProperties.None, BuiltinFunctionName.SCORE_QUERY, args);
}

public static FunctionExpression now(FunctionProperties functionProperties,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ public enum BuiltinFunctionName {
WEEK_OF_YEAR(FunctionName.of("week_of_year")),
YEAR(FunctionName.of("year")),
YEARWEEK(FunctionName.of("yearweek")),

// `now`-like functions
NOW(FunctionName.of("now")),
CURDATE(FunctionName.of("curdate")),
Expand All @@ -132,6 +133,7 @@ public enum BuiltinFunctionName {
CURRENT_TIMESTAMP(FunctionName.of("current_timestamp")),
LOCALTIMESTAMP(FunctionName.of("localtimestamp")),
SYSDATE(FunctionName.of("sysdate")),

/**
* Text Functions.
*/
Expand Down Expand Up @@ -255,6 +257,10 @@ public enum BuiltinFunctionName {
MATCH_BOOL_PREFIX(FunctionName.of("match_bool_prefix")),
HIGHLIGHT(FunctionName.of("highlight")),
MATCH_PHRASE_PREFIX(FunctionName.of("match_phrase_prefix")),
SCORE(FunctionName.of("score")),
SCOREQUERY(FunctionName.of("scorequery")),
SCORE_QUERY(FunctionName.of("score_query")),

/**
* Legacy Relevance Function.
*/
Expand Down
Loading