diff --git a/build.gradle b/build.gradle index ff29eb7687..7eb7dcdeef 100644 --- a/build.gradle +++ b/build.gradle @@ -83,10 +83,10 @@ repositories { // Spotless checks will be added as PRs are applied to resolve each style issue is approved. spotless { java { -// target fileTree('.') { -// include '**/*.java', 'src/*/java/**/*.java' -// exclude '**/build/**', '**/build-*/**' -// } + target fileTree('.') { + include '**/*.java', 'src/*/java/**/*.java' + exclude '**/build/**', '**/build-*/**' + } // importOrder() // licenseHeader("/*\n" + // " * Copyright OpenSearch Contributors\n" + @@ -95,7 +95,8 @@ spotless { // removeUnusedImports() // trimTrailingWhitespace() // endWithNewline() -// googleJavaFormat('1.17.0').reflowLongStrings().groupArtifact('com.google.googlejavaformat:google-java-format') + //googleJavaFormat('1.17.0').reflowLongStrings().groupArtifact('com.google.googlejavaformat:google-java-format') + eclipse().configFile rootProject.file('formatterConfig.xml') } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java b/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java index 4704d0566b..91e537b58d 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java +++ b/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import java.util.ArrayList; @@ -17,56 +16,56 @@ * The context used for Analyzer. */ public class AnalysisContext { - /** - * Environment stack for symbol scope management. - */ - private TypeEnvironment environment; - @Getter - private final List namedParseExpressions; + /** + * Environment stack for symbol scope management. + */ + private TypeEnvironment environment; + @Getter + private final List namedParseExpressions; - @Getter - private final FunctionProperties functionProperties; + @Getter + private final FunctionProperties functionProperties; - public AnalysisContext() { - this(new TypeEnvironment(null)); - } + public AnalysisContext() { + this(new TypeEnvironment(null)); + } - /** - * Class CTOR. - * @param environment Env to set to a new instance. - */ - public AnalysisContext(TypeEnvironment environment) { - this.environment = environment; - this.namedParseExpressions = new ArrayList<>(); - this.functionProperties = new FunctionProperties(); - } + /** + * Class CTOR. + * @param environment Env to set to a new instance. + */ + public AnalysisContext(TypeEnvironment environment) { + this.environment = environment; + this.namedParseExpressions = new ArrayList<>(); + this.functionProperties = new FunctionProperties(); + } - /** - * Push a new environment. - */ - public void push() { - environment = new TypeEnvironment(environment); - } + /** + * Push a new environment. + */ + public void push() { + environment = new TypeEnvironment(environment); + } - /** - * Return current environment. - * - * @return current environment - */ - public TypeEnvironment peek() { - return environment; - } + /** + * Return current environment. + * + * @return current environment + */ + public TypeEnvironment peek() { + return environment; + } - /** - * Pop up current environment from environment chain. - * - * @return current environment (before pop) - */ - public TypeEnvironment pop() { - Objects.requireNonNull(environment, "Fail to pop context due to no environment present"); + /** + * Pop up current environment from environment chain. + * + * @return current environment (before pop) + */ + public TypeEnvironment pop() { + Objects.requireNonNull(environment, "Fail to pop context due to no environment present"); - TypeEnvironment curEnv = environment; - environment = curEnv.getParent(); - return curEnv; - } + TypeEnvironment curEnv = environment; + environment = curEnv.getParent(); + return curEnv; + } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 2c4647004c..d79f12413e 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; @@ -112,497 +111,470 @@ */ public class Analyzer extends AbstractNodeVisitor { - private final ExpressionAnalyzer expressionAnalyzer; - - private final SelectExpressionAnalyzer selectExpressionAnalyzer; - - private final NamedExpressionAnalyzer namedExpressionAnalyzer; - - private final DataSourceService dataSourceService; - - private final BuiltinFunctionRepository repository; - - /** - * Constructor. - */ - public Analyzer( - ExpressionAnalyzer expressionAnalyzer, - DataSourceService dataSourceService, - BuiltinFunctionRepository repository) { - this.expressionAnalyzer = expressionAnalyzer; - this.dataSourceService = dataSourceService; - this.selectExpressionAnalyzer = new SelectExpressionAnalyzer(expressionAnalyzer); - this.namedExpressionAnalyzer = new NamedExpressionAnalyzer(expressionAnalyzer); - this.repository = repository; - } - - public LogicalPlan analyze(UnresolvedPlan unresolved, AnalysisContext context) { - return unresolved.accept(this, context); - } - - @Override - public LogicalPlan visitRelation(Relation node, AnalysisContext context) { - QualifiedName qualifiedName = node.getTableQualifiedName(); - DataSourceSchemaIdentifierNameResolver dataSourceSchemaIdentifierNameResolver - = new DataSourceSchemaIdentifierNameResolver(dataSourceService, qualifiedName.getParts()); - String tableName = dataSourceSchemaIdentifierNameResolver.getIdentifierName(); - context.push(); - TypeEnvironment curEnv = context.peek(); - Table table; - if (DATASOURCES_TABLE_NAME.equals(tableName)) { - table = new DataSourceTable(dataSourceService); - } else { - table = dataSourceService - .getDataSource(dataSourceSchemaIdentifierNameResolver.getDataSourceName()) - .getStorageEngine() - .getTable(new DataSourceSchemaName( - dataSourceSchemaIdentifierNameResolver.getDataSourceName(), - dataSourceSchemaIdentifierNameResolver.getSchemaName()), - dataSourceSchemaIdentifierNameResolver.getIdentifierName()); + private final ExpressionAnalyzer expressionAnalyzer; + + private final SelectExpressionAnalyzer selectExpressionAnalyzer; + + private final NamedExpressionAnalyzer namedExpressionAnalyzer; + + private final DataSourceService dataSourceService; + + private final BuiltinFunctionRepository repository; + + /** + * Constructor. + */ + public Analyzer(ExpressionAnalyzer expressionAnalyzer, DataSourceService dataSourceService, BuiltinFunctionRepository repository) { + this.expressionAnalyzer = expressionAnalyzer; + this.dataSourceService = dataSourceService; + this.selectExpressionAnalyzer = new SelectExpressionAnalyzer(expressionAnalyzer); + this.namedExpressionAnalyzer = new NamedExpressionAnalyzer(expressionAnalyzer); + this.repository = repository; } - table.getFieldTypes().forEach((k, v) -> curEnv.define(new Symbol(Namespace.FIELD_NAME, k), v)); - table.getReservedFieldTypes().forEach( - (k, v) -> curEnv.addReservedWord(new Symbol(Namespace.FIELD_NAME, k), v) - ); - - // Put index name or its alias in index namespace on type environment so qualifier - // can be removed when analyzing qualified name. The value (expr type) here doesn't matter. - curEnv.define(new Symbol(Namespace.INDEX_NAME, - (node.getAlias() == null) ? tableName : node.getAlias()), STRUCT); - - return new LogicalRelation(tableName, table); - } - - - @Override - public LogicalPlan visitRelationSubquery(RelationSubquery node, AnalysisContext context) { - LogicalPlan subquery = analyze(node.getChild().get(0), context); - // inherit the parent environment to keep the subquery fields in current environment - TypeEnvironment curEnv = context.peek(); - - // Put subquery alias in index namespace so the qualifier can be removed - // when analyzing qualified name in the subquery layer - curEnv.define(new Symbol(Namespace.INDEX_NAME, node.getAliasAsTableName()), STRUCT); - return subquery; - } - - @Override - public LogicalPlan visitTableFunction(TableFunction node, AnalysisContext context) { - QualifiedName qualifiedName = node.getFunctionName(); - DataSourceSchemaIdentifierNameResolver dataSourceSchemaIdentifierNameResolver - = new DataSourceSchemaIdentifierNameResolver(this.dataSourceService, - qualifiedName.getParts()); - - FunctionName functionName - = FunctionName.of(dataSourceSchemaIdentifierNameResolver.getIdentifierName()); - List arguments = node.getArguments().stream() - .map(unresolvedExpression -> this.expressionAnalyzer.analyze(unresolvedExpression, context)) - .collect(Collectors.toList()); - TableFunctionImplementation tableFunctionImplementation - = (TableFunctionImplementation) repository.compile(context.getFunctionProperties(), - dataSourceService - .getDataSource(dataSourceSchemaIdentifierNameResolver.getDataSourceName()) - .getStorageEngine().getFunctions(), functionName, arguments); - context.push(); - TypeEnvironment curEnv = context.peek(); - Table table = tableFunctionImplementation.applyArguments(); - table.getFieldTypes().forEach((k, v) -> curEnv.define(new Symbol(Namespace.FIELD_NAME, k), v)); - table.getReservedFieldTypes().forEach( - (k, v) -> curEnv.addReservedWord(new Symbol(Namespace.FIELD_NAME, k), v) - ); - curEnv.define(new Symbol(Namespace.INDEX_NAME, - dataSourceSchemaIdentifierNameResolver.getIdentifierName()), STRUCT); - return new LogicalRelation(dataSourceSchemaIdentifierNameResolver.getIdentifierName(), - tableFunctionImplementation.applyArguments()); - } - - @Override - public LogicalPlan visitLimit(Limit node, AnalysisContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - return new LogicalLimit(child, node.getLimit(), node.getOffset()); - } - - @Override - public LogicalPlan visitFilter(Filter node, AnalysisContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - Expression condition = expressionAnalyzer.analyze(node.getCondition(), context); - - ExpressionReferenceOptimizer optimizer = - new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child); - Expression optimized = optimizer.optimize(condition, context); - return new LogicalFilter(child, optimized); - } - - /** - * Ensure NESTED function is not used in GROUP BY, and HAVING clauses. - * Fallback to legacy engine. Can remove when support is added for NESTED function in WHERE, - * GROUP BY, ORDER BY, and HAVING clauses. - * @param condition : Filter condition - */ - private void verifySupportsCondition(Expression condition) { - if (condition instanceof FunctionExpression) { - if (((FunctionExpression) condition).getFunctionName().getFunctionName().equalsIgnoreCase( - BuiltinFunctionName.NESTED.name() - )) { - throw new SyntaxCheckException( - "Falling back to legacy engine. Nested function is not supported in WHERE," - + " GROUP BY, and HAVING clauses." + + public LogicalPlan analyze(UnresolvedPlan unresolved, AnalysisContext context) { + return unresolved.accept(this, context); + } + + @Override + public LogicalPlan visitRelation(Relation node, AnalysisContext context) { + QualifiedName qualifiedName = node.getTableQualifiedName(); + DataSourceSchemaIdentifierNameResolver dataSourceSchemaIdentifierNameResolver = new DataSourceSchemaIdentifierNameResolver( + dataSourceService, + qualifiedName.getParts() ); - } - ((FunctionExpression)condition).getArguments().stream() - .forEach(e -> verifySupportsCondition(e) - ); + String tableName = dataSourceSchemaIdentifierNameResolver.getIdentifierName(); + context.push(); + TypeEnvironment curEnv = context.peek(); + Table table; + if (DATASOURCES_TABLE_NAME.equals(tableName)) { + table = new DataSourceTable(dataSourceService); + } else { + table = dataSourceService.getDataSource(dataSourceSchemaIdentifierNameResolver.getDataSourceName()) + .getStorageEngine() + .getTable( + new DataSourceSchemaName( + dataSourceSchemaIdentifierNameResolver.getDataSourceName(), + dataSourceSchemaIdentifierNameResolver.getSchemaName() + ), + dataSourceSchemaIdentifierNameResolver.getIdentifierName() + ); + } + table.getFieldTypes().forEach((k, v) -> curEnv.define(new Symbol(Namespace.FIELD_NAME, k), v)); + table.getReservedFieldTypes().forEach((k, v) -> curEnv.addReservedWord(new Symbol(Namespace.FIELD_NAME, k), v)); + + // Put index name or its alias in index namespace on type environment so qualifier + // can be removed when analyzing qualified name. The value (expr type) here doesn't matter. + curEnv.define(new Symbol(Namespace.INDEX_NAME, (node.getAlias() == null) ? tableName : node.getAlias()), STRUCT); + + return new LogicalRelation(tableName, table); } - } - - /** - * Build {@link LogicalRename}. - */ - @Override - public LogicalPlan visitRename(Rename node, AnalysisContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - ImmutableMap.Builder renameMapBuilder = - new ImmutableMap.Builder<>(); - for (Map renameMap : node.getRenameList()) { - Expression origin = expressionAnalyzer.analyze(renameMap.getOrigin(), context); - // We should define the new target field in the context instead of analyze it. - if (renameMap.getTarget() instanceof Field) { - ReferenceExpression target = - new ReferenceExpression(((Field) renameMap.getTarget()).getField().toString(), - origin.type()); - ReferenceExpression originExpr = DSL.ref(origin.toString(), origin.type()); + + @Override + public LogicalPlan visitRelationSubquery(RelationSubquery node, AnalysisContext context) { + LogicalPlan subquery = analyze(node.getChild().get(0), context); + // inherit the parent environment to keep the subquery fields in current environment TypeEnvironment curEnv = context.peek(); - curEnv.remove(originExpr); - curEnv.define(target); - renameMapBuilder.put(originExpr, target); - } else { - throw new SemanticCheckException( - String.format("the target expected to be field, but is %s", renameMap.getTarget())); - } + + // Put subquery alias in index namespace so the qualifier can be removed + // when analyzing qualified name in the subquery layer + curEnv.define(new Symbol(Namespace.INDEX_NAME, node.getAliasAsTableName()), STRUCT); + return subquery; + } + + @Override + public LogicalPlan visitTableFunction(TableFunction node, AnalysisContext context) { + QualifiedName qualifiedName = node.getFunctionName(); + DataSourceSchemaIdentifierNameResolver dataSourceSchemaIdentifierNameResolver = new DataSourceSchemaIdentifierNameResolver( + this.dataSourceService, + qualifiedName.getParts() + ); + + FunctionName functionName = FunctionName.of(dataSourceSchemaIdentifierNameResolver.getIdentifierName()); + List arguments = node.getArguments() + .stream() + .map(unresolvedExpression -> this.expressionAnalyzer.analyze(unresolvedExpression, context)) + .collect(Collectors.toList()); + TableFunctionImplementation tableFunctionImplementation = (TableFunctionImplementation) repository.compile( + context.getFunctionProperties(), + dataSourceService.getDataSource(dataSourceSchemaIdentifierNameResolver.getDataSourceName()).getStorageEngine().getFunctions(), + functionName, + arguments + ); + context.push(); + TypeEnvironment curEnv = context.peek(); + Table table = tableFunctionImplementation.applyArguments(); + table.getFieldTypes().forEach((k, v) -> curEnv.define(new Symbol(Namespace.FIELD_NAME, k), v)); + table.getReservedFieldTypes().forEach((k, v) -> curEnv.addReservedWord(new Symbol(Namespace.FIELD_NAME, k), v)); + curEnv.define(new Symbol(Namespace.INDEX_NAME, dataSourceSchemaIdentifierNameResolver.getIdentifierName()), STRUCT); + return new LogicalRelation( + dataSourceSchemaIdentifierNameResolver.getIdentifierName(), + tableFunctionImplementation.applyArguments() + ); + } + + @Override + public LogicalPlan visitLimit(Limit node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + return new LogicalLimit(child, node.getLimit(), node.getOffset()); + } + + @Override + public LogicalPlan visitFilter(Filter node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + Expression condition = expressionAnalyzer.analyze(node.getCondition(), context); + + ExpressionReferenceOptimizer optimizer = new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child); + Expression optimized = optimizer.optimize(condition, context); + return new LogicalFilter(child, optimized); + } + + /** + * Ensure NESTED function is not used in GROUP BY, and HAVING clauses. + * Fallback to legacy engine. Can remove when support is added for NESTED function in WHERE, + * GROUP BY, ORDER BY, and HAVING clauses. + * @param condition : Filter condition + */ + private void verifySupportsCondition(Expression condition) { + if (condition instanceof FunctionExpression) { + if (((FunctionExpression) condition).getFunctionName().getFunctionName().equalsIgnoreCase(BuiltinFunctionName.NESTED.name())) { + throw new SyntaxCheckException( + "Falling back to legacy engine. Nested function is not supported in WHERE," + " GROUP BY, and HAVING clauses." + ); + } + ((FunctionExpression) condition).getArguments().stream().forEach(e -> verifySupportsCondition(e)); + } } - return new LogicalRename(child, renameMapBuilder.build()); - } - - /** - * Build {@link LogicalAggregation}. - */ - @Override - public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) { - final LogicalPlan child = node.getChild().get(0).accept(this, context); - ImmutableList.Builder aggregatorBuilder = new ImmutableList.Builder<>(); - for (UnresolvedExpression expr : node.getAggExprList()) { - NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context); - aggregatorBuilder - .add(new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated())); + /** + * Build {@link LogicalRename}. + */ + @Override + public LogicalPlan visitRename(Rename node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + ImmutableMap.Builder renameMapBuilder = new ImmutableMap.Builder<>(); + for (Map renameMap : node.getRenameList()) { + Expression origin = expressionAnalyzer.analyze(renameMap.getOrigin(), context); + // We should define the new target field in the context instead of analyze it. + if (renameMap.getTarget() instanceof Field) { + ReferenceExpression target = new ReferenceExpression(((Field) renameMap.getTarget()).getField().toString(), origin.type()); + ReferenceExpression originExpr = DSL.ref(origin.toString(), origin.type()); + TypeEnvironment curEnv = context.peek(); + curEnv.remove(originExpr); + curEnv.define(target); + renameMapBuilder.put(originExpr, target); + } else { + throw new SemanticCheckException(String.format("the target expected to be field, but is %s", renameMap.getTarget())); + } + } + + return new LogicalRename(child, renameMapBuilder.build()); } - ImmutableList.Builder groupbyBuilder = new ImmutableList.Builder<>(); - // Span should be first expression if exist. - if (node.getSpan() != null) { - groupbyBuilder.add(namedExpressionAnalyzer.analyze(node.getSpan(), context)); + /** + * Build {@link LogicalAggregation}. + */ + @Override + public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) { + final LogicalPlan child = node.getChild().get(0).accept(this, context); + ImmutableList.Builder aggregatorBuilder = new ImmutableList.Builder<>(); + for (UnresolvedExpression expr : node.getAggExprList()) { + NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context); + aggregatorBuilder.add(new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated())); + } + + ImmutableList.Builder groupbyBuilder = new ImmutableList.Builder<>(); + // Span should be first expression if exist. + if (node.getSpan() != null) { + groupbyBuilder.add(namedExpressionAnalyzer.analyze(node.getSpan(), context)); + } + + for (UnresolvedExpression expr : node.getGroupExprList()) { + NamedExpression resolvedExpr = namedExpressionAnalyzer.analyze(expr, context); + verifySupportsCondition(resolvedExpr.getDelegated()); + groupbyBuilder.add(resolvedExpr); + } + ImmutableList groupBys = groupbyBuilder.build(); + + ImmutableList aggregators = aggregatorBuilder.build(); + // new context + context.push(); + TypeEnvironment newEnv = context.peek(); + aggregators.forEach(aggregator -> newEnv.define(new Symbol(Namespace.FIELD_NAME, aggregator.getName()), aggregator.type())); + groupBys.forEach(group -> newEnv.define(new Symbol(Namespace.FIELD_NAME, group.getNameOrAlias()), group.type())); + return new LogicalAggregation(child, aggregators, groupBys); } - for (UnresolvedExpression expr : node.getGroupExprList()) { - NamedExpression resolvedExpr = namedExpressionAnalyzer.analyze(expr, context); - verifySupportsCondition(resolvedExpr.getDelegated()); - groupbyBuilder.add(resolvedExpr); + /** + * Build {@link LogicalRareTopN}. + */ + @Override + public LogicalPlan visitRareTopN(RareTopN node, AnalysisContext context) { + final LogicalPlan child = node.getChild().get(0).accept(this, context); + + ImmutableList.Builder groupbyBuilder = new ImmutableList.Builder<>(); + for (UnresolvedExpression expr : node.getGroupExprList()) { + groupbyBuilder.add(expressionAnalyzer.analyze(expr, context)); + } + ImmutableList groupBys = groupbyBuilder.build(); + + ImmutableList.Builder fieldsBuilder = new ImmutableList.Builder<>(); + for (Field f : node.getFields()) { + fieldsBuilder.add(expressionAnalyzer.analyze(f, context)); + } + ImmutableList fields = fieldsBuilder.build(); + + // new context + context.push(); + TypeEnvironment newEnv = context.peek(); + groupBys.forEach(group -> newEnv.define(new Symbol(Namespace.FIELD_NAME, group.toString()), group.type())); + fields.forEach(field -> newEnv.define(new Symbol(Namespace.FIELD_NAME, field.toString()), field.type())); + + List options = node.getNoOfResults(); + Integer noOfResults = (Integer) options.get(0).getValue().getValue(); + + return new LogicalRareTopN(child, node.getCommandType(), noOfResults, fields, groupBys); } - ImmutableList groupBys = groupbyBuilder.build(); - - ImmutableList aggregators = aggregatorBuilder.build(); - // new context - context.push(); - TypeEnvironment newEnv = context.peek(); - aggregators.forEach(aggregator -> newEnv.define(new Symbol(Namespace.FIELD_NAME, - aggregator.getName()), aggregator.type())); - groupBys.forEach(group -> newEnv.define(new Symbol(Namespace.FIELD_NAME, - group.getNameOrAlias()), group.type())); - return new LogicalAggregation(child, aggregators, groupBys); - } - - /** - * Build {@link LogicalRareTopN}. - */ - @Override - public LogicalPlan visitRareTopN(RareTopN node, AnalysisContext context) { - final LogicalPlan child = node.getChild().get(0).accept(this, context); - - ImmutableList.Builder groupbyBuilder = new ImmutableList.Builder<>(); - for (UnresolvedExpression expr : node.getGroupExprList()) { - groupbyBuilder.add(expressionAnalyzer.analyze(expr, context)); + + /** + * Build {@link LogicalProject} or {@link LogicalRemove} from {@link Field}. + * + *

Todo, the include/exclude fields should change the env definition. The cons of current + * implementation is even the query contain the field reference which has been excluded from + * fields command. There is no {@link SemanticCheckException} will be thrown. Instead, the during + * runtime evaluation, the not exist field will be resolve to {@link ExprMissingValue} which will + * not impact the correctness. + * + *

Postpone the implementation when finding more use case. + */ + @Override + public LogicalPlan visitProject(Project node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + + if (node.hasArgument()) { + Argument argument = node.getArgExprList().get(0); + Boolean exclude = (Boolean) argument.getValue().getValue(); + if (exclude) { + TypeEnvironment curEnv = context.peek(); + List referenceExpressions = node.getProjectList() + .stream() + .map(expr -> (ReferenceExpression) expressionAnalyzer.analyze(expr, context)) + .collect(Collectors.toList()); + referenceExpressions.forEach(ref -> curEnv.remove(ref)); + return new LogicalRemove(child, ImmutableSet.copyOf(referenceExpressions)); + } + } + + // For each unresolved window function, analyze it by "insert" a window and sort operator + // between project and its child. + for (UnresolvedExpression expr : node.getProjectList()) { + WindowExpressionAnalyzer windowAnalyzer = new WindowExpressionAnalyzer(expressionAnalyzer, child); + child = windowAnalyzer.analyze(expr, context); + } + + for (UnresolvedExpression expr : node.getProjectList()) { + HighlightAnalyzer highlightAnalyzer = new HighlightAnalyzer(expressionAnalyzer, child); + child = highlightAnalyzer.analyze(expr, context); + } + + List namedExpressions = selectExpressionAnalyzer.analyze( + node.getProjectList(), + context, + new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child) + ); + + for (UnresolvedExpression expr : node.getProjectList()) { + NestedAnalyzer nestedAnalyzer = new NestedAnalyzer(namedExpressions, expressionAnalyzer, child); + child = nestedAnalyzer.analyze(expr, context); + } + + // new context + context.push(); + TypeEnvironment newEnv = context.peek(); + namedExpressions.forEach(expr -> newEnv.define(new Symbol(Namespace.FIELD_NAME, expr.getNameOrAlias()), expr.type())); + List namedParseExpressions = context.getNamedParseExpressions(); + return new LogicalProject(child, namedExpressions, namedParseExpressions); } - ImmutableList groupBys = groupbyBuilder.build(); - ImmutableList.Builder fieldsBuilder = new ImmutableList.Builder<>(); - for (Field f : node.getFields()) { - fieldsBuilder.add(expressionAnalyzer.analyze(f, context)); + /** + * Build {@link LogicalEval}. + */ + @Override + public LogicalPlan visitEval(Eval node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + ImmutableList.Builder> expressionsBuilder = new Builder<>(); + for (Let let : node.getExpressionList()) { + Expression expression = expressionAnalyzer.analyze(let.getExpression(), context); + ReferenceExpression ref = DSL.ref(let.getVar().getField().toString(), expression.type()); + expressionsBuilder.add(ImmutablePair.of(ref, expression)); + TypeEnvironment typeEnvironment = context.peek(); + // define the new reference in type env. + typeEnvironment.define(ref); + } + return new LogicalEval(child, expressionsBuilder.build()); } - ImmutableList fields = fieldsBuilder.build(); - - // new context - context.push(); - TypeEnvironment newEnv = context.peek(); - groupBys.forEach(group -> newEnv.define(new Symbol(Namespace.FIELD_NAME, - group.toString()), group.type())); - fields.forEach(field -> newEnv.define(new Symbol(Namespace.FIELD_NAME, - field.toString()), field.type())); - - List options = node.getNoOfResults(); - Integer noOfResults = (Integer) options.get(0).getValue().getValue(); - - return new LogicalRareTopN(child, node.getCommandType(), noOfResults, fields, groupBys); - } - - /** - * Build {@link LogicalProject} or {@link LogicalRemove} from {@link Field}. - * - *

Todo, the include/exclude fields should change the env definition. The cons of current - * implementation is even the query contain the field reference which has been excluded from - * fields command. There is no {@link SemanticCheckException} will be thrown. Instead, the during - * runtime evaluation, the not exist field will be resolve to {@link ExprMissingValue} which will - * not impact the correctness. - * - *

Postpone the implementation when finding more use case. - */ - @Override - public LogicalPlan visitProject(Project node, AnalysisContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - - if (node.hasArgument()) { - Argument argument = node.getArgExprList().get(0); - Boolean exclude = (Boolean) argument.getValue().getValue(); - if (exclude) { + + /** + * Build {@link ParseExpression} to context and skip to child nodes. + */ + @Override + public LogicalPlan visitParse(Parse node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + Expression sourceField = expressionAnalyzer.analyze(node.getSourceField(), context); + ParseMethod parseMethod = node.getParseMethod(); + java.util.Map arguments = node.getArguments(); + String pattern = (String) node.getPattern().getValue(); + Expression patternExpression = DSL.literal(pattern); + TypeEnvironment curEnv = context.peek(); - List referenceExpressions = - node.getProjectList().stream() - .map(expr -> (ReferenceExpression) expressionAnalyzer.analyze(expr, context)) - .collect(Collectors.toList()); - referenceExpressions.forEach(ref -> curEnv.remove(ref)); - return new LogicalRemove(child, ImmutableSet.copyOf(referenceExpressions)); - } + ParseUtils.getNamedGroupCandidates(parseMethod, pattern, arguments).forEach(group -> { + ParseExpression expr = ParseUtils.createParseExpression(parseMethod, sourceField, patternExpression, DSL.literal(group)); + curEnv.define(new Symbol(Namespace.FIELD_NAME, group), expr.type()); + context.getNamedParseExpressions().add(new NamedExpression(group, expr)); + }); + return child; } - // For each unresolved window function, analyze it by "insert" a window and sort operator - // between project and its child. - for (UnresolvedExpression expr : node.getProjectList()) { - WindowExpressionAnalyzer windowAnalyzer = - new WindowExpressionAnalyzer(expressionAnalyzer, child); - child = windowAnalyzer.analyze(expr, context); + /** + * Build {@link LogicalSort}. + */ + @Override + public LogicalPlan visitSort(Sort node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + ExpressionReferenceOptimizer optimizer = new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child); + + List> sortList = node.getSortList().stream().map(sortField -> { + var analyzed = expressionAnalyzer.analyze(sortField.getField(), context); + if (analyzed == null) { + throw new UnsupportedOperationException(String.format("Invalid use of expression %s", sortField.getField())); + } + Expression expression = optimizer.optimize(analyzed, context); + return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression); + }).collect(Collectors.toList()); + return new LogicalSort(child, sortList); } - for (UnresolvedExpression expr : node.getProjectList()) { - HighlightAnalyzer highlightAnalyzer = new HighlightAnalyzer(expressionAnalyzer, child); - child = highlightAnalyzer.analyze(expr, context); + /** + * Build {@link LogicalDedupe}. + */ + @Override + public LogicalPlan visitDedupe(Dedupe node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + List options = node.getOptions(); + // Todo, refactor the option. + Integer allowedDuplication = (Integer) options.get(0).getValue().getValue(); + Boolean keepEmpty = (Boolean) options.get(1).getValue().getValue(); + Boolean consecutive = (Boolean) options.get(2).getValue().getValue(); + + return new LogicalDedupe( + child, + node.getFields().stream().map(f -> expressionAnalyzer.analyze(f, context)).collect(Collectors.toList()), + allowedDuplication, + keepEmpty, + consecutive + ); } - List namedExpressions = - selectExpressionAnalyzer.analyze(node.getProjectList(), context, - new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child)); + /** + * Logical head is identical to {@link LogicalLimit}. + */ + public LogicalPlan visitHead(Head node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + return new LogicalLimit(child, node.getSize(), node.getFrom()); + } - for (UnresolvedExpression expr : node.getProjectList()) { - NestedAnalyzer nestedAnalyzer = new NestedAnalyzer( - namedExpressions, expressionAnalyzer, child - ); - child = nestedAnalyzer.analyze(expr, context); + @Override + public LogicalPlan visitValues(Values node, AnalysisContext context) { + List> values = node.getValues(); + List> valueExprs = new ArrayList<>(); + for (List value : values) { + valueExprs.add( + value.stream().map(val -> (LiteralExpression) expressionAnalyzer.analyze(val, context)).collect(Collectors.toList()) + ); + } + return new LogicalValues(valueExprs); } - // new context - context.push(); - TypeEnvironment newEnv = context.peek(); - namedExpressions.forEach(expr -> newEnv.define(new Symbol(Namespace.FIELD_NAME, - expr.getNameOrAlias()), expr.type())); - List namedParseExpressions = context.getNamedParseExpressions(); - return new LogicalProject(child, namedExpressions, namedParseExpressions); - } - - /** - * Build {@link LogicalEval}. - */ - @Override - public LogicalPlan visitEval(Eval node, AnalysisContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - ImmutableList.Builder> expressionsBuilder = - new Builder<>(); - for (Let let : node.getExpressionList()) { - Expression expression = expressionAnalyzer.analyze(let.getExpression(), context); - ReferenceExpression ref = DSL.ref(let.getVar().getField().toString(), expression.type()); - expressionsBuilder.add(ImmutablePair.of(ref, expression)); - TypeEnvironment typeEnvironment = context.peek(); - // define the new reference in type env. - typeEnvironment.define(ref); + /** + * Build {@link LogicalMLCommons} for Kmeans command. + */ + @Override + public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + java.util.Map options = node.getArguments(); + + TypeEnvironment currentEnv = context.peek(); + currentEnv.define(new Symbol(Namespace.FIELD_NAME, "ClusterID"), ExprCoreType.INTEGER); + + return new LogicalMLCommons(child, "kmeans", options); } - return new LogicalEval(child, expressionsBuilder.build()); - } - - /** - * Build {@link ParseExpression} to context and skip to child nodes. - */ - @Override - public LogicalPlan visitParse(Parse node, AnalysisContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - Expression sourceField = expressionAnalyzer.analyze(node.getSourceField(), context); - ParseMethod parseMethod = node.getParseMethod(); - java.util.Map arguments = node.getArguments(); - String pattern = (String) node.getPattern().getValue(); - Expression patternExpression = DSL.literal(pattern); - - TypeEnvironment curEnv = context.peek(); - ParseUtils.getNamedGroupCandidates(parseMethod, pattern, arguments).forEach(group -> { - ParseExpression expr = ParseUtils.createParseExpression(parseMethod, sourceField, - patternExpression, DSL.literal(group)); - curEnv.define(new Symbol(Namespace.FIELD_NAME, group), expr.type()); - context.getNamedParseExpressions().add(new NamedExpression(group, expr)); - }); - return child; - } - - /** - * Build {@link LogicalSort}. - */ - @Override - public LogicalPlan visitSort(Sort node, AnalysisContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - ExpressionReferenceOptimizer optimizer = - new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child); - - List> sortList = - node.getSortList().stream() - .map( - sortField -> { - var analyzed = expressionAnalyzer.analyze(sortField.getField(), context); - if (analyzed == null) { - throw new UnsupportedOperationException( - String.format("Invalid use of expression %s", sortField.getField()) - ); - } - Expression expression = optimizer.optimize(analyzed, context); - return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression); - }) - .collect(Collectors.toList()); - return new LogicalSort(child, sortList); - } - - /** - * Build {@link LogicalDedupe}. - */ - @Override - public LogicalPlan visitDedupe(Dedupe node, AnalysisContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - List options = node.getOptions(); - // Todo, refactor the option. - Integer allowedDuplication = (Integer) options.get(0).getValue().getValue(); - Boolean keepEmpty = (Boolean) options.get(1).getValue().getValue(); - Boolean consecutive = (Boolean) options.get(2).getValue().getValue(); - - return new LogicalDedupe( - child, - node.getFields().stream() - .map(f -> expressionAnalyzer.analyze(f, context)) - .collect(Collectors.toList()), - allowedDuplication, - keepEmpty, - consecutive); - } - - /** - * Logical head is identical to {@link LogicalLimit}. - */ - public LogicalPlan visitHead(Head node, AnalysisContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - return new LogicalLimit(child, node.getSize(), node.getFrom()); - } - - @Override - public LogicalPlan visitValues(Values node, AnalysisContext context) { - List> values = node.getValues(); - List> valueExprs = new ArrayList<>(); - for (List value : values) { - valueExprs.add(value.stream() - .map(val -> (LiteralExpression) expressionAnalyzer.analyze(val, context)) - .collect(Collectors.toList())); + + /** + * Build {@link LogicalAD} for AD command. + */ + @Override + public LogicalPlan visitAD(AD node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + java.util.Map options = node.getArguments(); + + TypeEnvironment currentEnv = context.peek(); + + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_SCORE), ExprCoreType.DOUBLE); + if (Objects.isNull(node.getArguments().get(TIME_FIELD))) { + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALOUS), ExprCoreType.BOOLEAN); + } else { + currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALY_GRADE), ExprCoreType.DOUBLE); + currentEnv.define( + new Symbol(Namespace.FIELD_NAME, (String) node.getArguments().get(TIME_FIELD).getValue()), + ExprCoreType.TIMESTAMP + ); + } + return new LogicalAD(child, options); } - return new LogicalValues(valueExprs); - } - - /** - * Build {@link LogicalMLCommons} for Kmeans command. - */ - @Override - public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - java.util.Map options = node.getArguments(); - - TypeEnvironment currentEnv = context.peek(); - currentEnv.define(new Symbol(Namespace.FIELD_NAME, "ClusterID"), ExprCoreType.INTEGER); - - return new LogicalMLCommons(child, "kmeans", options); - } - - /** - * Build {@link LogicalAD} for AD command. - */ - @Override - public LogicalPlan visitAD(AD node, AnalysisContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - java.util.Map options = node.getArguments(); - - TypeEnvironment currentEnv = context.peek(); - - currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_SCORE), ExprCoreType.DOUBLE); - if (Objects.isNull(node.getArguments().get(TIME_FIELD))) { - currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALOUS), ExprCoreType.BOOLEAN); - } else { - currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALY_GRADE), ExprCoreType.DOUBLE); - currentEnv.define(new Symbol(Namespace.FIELD_NAME, - (String) node.getArguments().get(TIME_FIELD).getValue()), ExprCoreType.TIMESTAMP); + + /** + * Build {@link LogicalML} for ml command. + */ + @Override + public LogicalPlan visitML(ML node, AnalysisContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + TypeEnvironment currentEnv = context.peek(); + node.getOutputSchema(currentEnv) + .entrySet() + .stream() + .forEach(v -> currentEnv.define(new Symbol(Namespace.FIELD_NAME, v.getKey()), v.getValue())); + + return new LogicalML(child, node.getArguments()); + } + + @Override + public LogicalPlan visitPaginate(Paginate paginate, AnalysisContext context) { + LogicalPlan child = paginate.getChild().get(0).accept(this, context); + return new LogicalPaginate(paginate.getPageSize(), List.of(child)); } - return new LogicalAD(child, options); - } - - /** - * Build {@link LogicalML} for ml command. - */ - @Override - public LogicalPlan visitML(ML node, AnalysisContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - TypeEnvironment currentEnv = context.peek(); - node.getOutputSchema(currentEnv).entrySet().stream() - .forEach(v -> currentEnv.define(new Symbol(Namespace.FIELD_NAME, v.getKey()), v.getValue())); - - return new LogicalML(child, node.getArguments()); - } - - @Override - public LogicalPlan visitPaginate(Paginate paginate, AnalysisContext context) { - LogicalPlan child = paginate.getChild().get(0).accept(this, context); - return new LogicalPaginate(paginate.getPageSize(), List.of(child)); - } - - @Override - public LogicalPlan visitFetchCursor(FetchCursor cursor, AnalysisContext context) { - return new LogicalFetchCursor(cursor.getCursor(), - dataSourceService.getDataSource(DEFAULT_DATASOURCE_NAME).getStorageEngine()); - } - - @Override - public LogicalPlan visitCloseCursor(CloseCursor closeCursor, AnalysisContext context) { - return new LogicalCloseCursor(closeCursor.getChild().get(0).accept(this, context)); - } - - /** - * The first argument is always "asc", others are optional. - * Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC. - */ - private SortOption analyzeSortOption(List fieldArgs) { - Boolean asc = (Boolean) fieldArgs.get(0).getValue().getValue(); - Optional nullFirst = fieldArgs.stream() - .filter(option -> "nullFirst".equals(option.getArgName())).findFirst(); - - if (nullFirst.isPresent()) { - Boolean isNullFirst = (Boolean) nullFirst.get().getValue().getValue(); - return new SortOption((asc ? ASC : DESC), (isNullFirst ? NULL_FIRST : NULL_LAST)); + + @Override + public LogicalPlan visitFetchCursor(FetchCursor cursor, AnalysisContext context) { + return new LogicalFetchCursor(cursor.getCursor(), dataSourceService.getDataSource(DEFAULT_DATASOURCE_NAME).getStorageEngine()); + } + + @Override + public LogicalPlan visitCloseCursor(CloseCursor closeCursor, AnalysisContext context) { + return new LogicalCloseCursor(closeCursor.getChild().get(0).accept(this, context)); + } + + /** + * The first argument is always "asc", others are optional. + * Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC. + */ + private SortOption analyzeSortOption(List fieldArgs) { + Boolean asc = (Boolean) fieldArgs.get(0).getValue().getValue(); + Optional nullFirst = fieldArgs.stream().filter(option -> "nullFirst".equals(option.getArgName())).findFirst(); + + if (nullFirst.isPresent()) { + Boolean isNullFirst = (Boolean) nullFirst.get().getValue().getValue(); + return new SortOption((asc ? ASC : DESC), (isNullFirst ? NULL_FIRST : NULL_LAST)); + } + return asc ? SortOption.DEFAULT_ASC : SortOption.DEFAULT_DESC; } - return asc ? SortOption.DEFAULT_ASC : SortOption.DEFAULT_DESC; - } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/DataSourceSchemaIdentifierNameResolver.java b/core/src/main/java/org/opensearch/sql/analysis/DataSourceSchemaIdentifierNameResolver.java index 1bb8316907..c8bacdbe0a 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/DataSourceSchemaIdentifierNameResolver.java +++ b/core/src/main/java/org/opensearch/sql/analysis/DataSourceSchemaIdentifierNameResolver.java @@ -17,69 +17,64 @@ public class DataSourceSchemaIdentifierNameResolver { - public static final String DEFAULT_DATASOURCE_NAME = "@opensearch"; - public static final String DEFAULT_SCHEMA_NAME = "default"; - public static final String INFORMATION_SCHEMA_NAME = "information_schema"; + public static final String DEFAULT_DATASOURCE_NAME = "@opensearch"; + public static final String DEFAULT_SCHEMA_NAME = "default"; + public static final String INFORMATION_SCHEMA_NAME = "information_schema"; - private String dataSourceName = DEFAULT_DATASOURCE_NAME; - private String schemaName = DEFAULT_SCHEMA_NAME; - private String identifierName; - private DataSourceService dataSourceService; + private String dataSourceName = DEFAULT_DATASOURCE_NAME; + private String schemaName = DEFAULT_SCHEMA_NAME; + private String identifierName; + private DataSourceService dataSourceService; - private static final String DOT = "."; + private static final String DOT = "."; - /** - * Data model for capturing dataSourceName, schema and identifier from - * fully qualifiedName. In the current state, it is used to capture - * DataSourceSchemaTable name and DataSourceSchemaFunction in case of table - * functions. - * - * @param dataSourceService {@link DataSourceService}. - * @param parts parts of qualifiedName. - */ - public DataSourceSchemaIdentifierNameResolver(DataSourceService dataSourceService, - List parts) { - this.dataSourceService = dataSourceService; - List remainingParts - = captureSchemaName(captureDataSourceName(parts)); - identifierName = String.join(DOT, remainingParts); - } - - public String getIdentifierName() { - return identifierName; - } + /** + * Data model for capturing dataSourceName, schema and identifier from + * fully qualifiedName. In the current state, it is used to capture + * DataSourceSchemaTable name and DataSourceSchemaFunction in case of table + * functions. + * + * @param dataSourceService {@link DataSourceService}. + * @param parts parts of qualifiedName. + */ + public DataSourceSchemaIdentifierNameResolver(DataSourceService dataSourceService, List parts) { + this.dataSourceService = dataSourceService; + List remainingParts = captureSchemaName(captureDataSourceName(parts)); + identifierName = String.join(DOT, remainingParts); + } - public String getDataSourceName() { - return dataSourceName; - } + public String getIdentifierName() { + return identifierName; + } - public String getSchemaName() { - return schemaName; - } + public String getDataSourceName() { + return dataSourceName; + } + public String getSchemaName() { + return schemaName; + } - // Capture datasource name and return remaining parts(schema name and table name) - // from the fully qualified name. - private List captureDataSourceName(List parts) { - if (parts.size() > 1 && dataSourceService.dataSourceExists(parts.get(0))) { - dataSourceName = parts.get(0); - return parts.subList(1, parts.size()); - } else { - return parts; + // Capture datasource name and return remaining parts(schema name and table name) + // from the fully qualified name. + private List captureDataSourceName(List parts) { + if (parts.size() > 1 && dataSourceService.dataSourceExists(parts.get(0))) { + dataSourceName = parts.get(0); + return parts.subList(1, parts.size()); + } else { + return parts; + } } - } - // Capture schema name and return the remaining parts(table name ) - // in the fully qualified name. - private List captureSchemaName(List parts) { - if (parts.size() > 1 - && (DEFAULT_SCHEMA_NAME.equals(parts.get(0)) - || INFORMATION_SCHEMA_NAME.contains(parts.get(0)))) { - schemaName = parts.get(0); - return parts.subList(1, parts.size()); - } else { - return parts; + // Capture schema name and return the remaining parts(table name ) + // in the fully qualified name. + private List captureSchemaName(List parts) { + if (parts.size() > 1 && (DEFAULT_SCHEMA_NAME.equals(parts.get(0)) || INFORMATION_SCHEMA_NAME.contains(parts.get(0)))) { + schemaName = parts.get(0); + return parts.subList(1, parts.size()); + } else { + return parts; + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 601e3e00cc..ffacac23a5 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static org.opensearch.sql.ast.dsl.AstDSL.and; @@ -78,353 +77,323 @@ * Expression}. */ public class ExpressionAnalyzer extends AbstractNodeVisitor { - @Getter - private final BuiltinFunctionRepository repository; - - @Override - public Expression visitCast(Cast node, AnalysisContext context) { - final Expression expression = node.getExpression().accept(this, context); - return (Expression) repository - .compile(context.getFunctionProperties(), node.convertFunctionName(), - Collections.singletonList(expression)); - } - - public ExpressionAnalyzer( - BuiltinFunctionRepository repository) { - this.repository = repository; - } - - public Expression analyze(UnresolvedExpression unresolved, AnalysisContext context) { - return unresolved.accept(this, context); - } - - @Override - public Expression visitUnresolvedAttribute(UnresolvedAttribute node, AnalysisContext context) { - return visitIdentifier(node.getAttr(), context); - } - - @Override - public Expression visitEqualTo(EqualTo node, AnalysisContext context) { - Expression left = node.getLeft().accept(this, context); - Expression right = node.getRight().accept(this, context); - - return DSL.equal(left, right); - } - - @Override - public Expression visitLiteral(Literal node, AnalysisContext context) { - return DSL - .literal(ExprValueUtils.fromObjectValue(node.getValue(), node.getType().getCoreType())); - } - - @Override - public Expression visitInterval(Interval node, AnalysisContext context) { - Expression value = node.getValue().accept(this, context); - Expression unit = DSL.literal(node.getUnit().name()); - return DSL.interval(value, unit); - } - - @Override - public Expression visitAnd(And node, AnalysisContext context) { - Expression left = node.getLeft().accept(this, context); - Expression right = node.getRight().accept(this, context); - - return DSL.and(left, right); - } - - @Override - public Expression visitOr(Or node, AnalysisContext context) { - Expression left = node.getLeft().accept(this, context); - Expression right = node.getRight().accept(this, context); - - return DSL.or(left, right); - } - - @Override - public Expression visitXor(Xor node, AnalysisContext context) { - Expression left = node.getLeft().accept(this, context); - Expression right = node.getRight().accept(this, context); - - return DSL.xor(left, right); - } - - @Override - public Expression visitNot(Not node, AnalysisContext context) { - return DSL.not(node.getExpression().accept(this, context)); - } - - @Override - public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext context) { - Optional builtinFunctionName = - BuiltinFunctionName.ofAggregation(node.getFuncName()); - if (builtinFunctionName.isPresent()) { - ImmutableList.Builder builder = ImmutableList.builder(); - builder.add(node.getField().accept(this, context)); - for (UnresolvedExpression arg : node.getArgList()) { - builder.add(arg.accept(this, context)); - } - Aggregator aggregator = (Aggregator) repository.compile( - context.getFunctionProperties(), builtinFunctionName.get().getName(), builder.build()); - aggregator.distinct(node.getDistinct()); - if (node.condition() != null) { - aggregator.condition(analyze(node.condition(), context)); - } - return aggregator; - } else { - throw new SemanticCheckException("Unsupported aggregation function " + node.getFuncName()); + @Getter + private final BuiltinFunctionRepository repository; + + @Override + public Expression visitCast(Cast node, AnalysisContext context) { + final Expression expression = node.getExpression().accept(this, context); + return (Expression) repository.compile( + context.getFunctionProperties(), + node.convertFunctionName(), + Collections.singletonList(expression) + ); } - } - - @Override - public Expression visitRelevanceFieldList(RelevanceFieldList node, AnalysisContext context) { - return new LiteralExpression(ExprValueUtils.tupleValue( - ImmutableMap.copyOf(node.getFieldList()))); - } - - @Override - public Expression visitFunction(Function node, AnalysisContext context) { - FunctionName functionName = FunctionName.of(node.getFuncName()); - List arguments = - node.getFuncArgs().stream() - .map(unresolvedExpression -> { - var ret = analyze(unresolvedExpression, context); - if (ret == null) { - throw new UnsupportedOperationException( - String.format("Invalid use of expression %s", unresolvedExpression) - ); - } else { + + public ExpressionAnalyzer(BuiltinFunctionRepository repository) { + this.repository = repository; + } + + public Expression analyze(UnresolvedExpression unresolved, AnalysisContext context) { + return unresolved.accept(this, context); + } + + @Override + public Expression visitUnresolvedAttribute(UnresolvedAttribute node, AnalysisContext context) { + return visitIdentifier(node.getAttr(), context); + } + + @Override + public Expression visitEqualTo(EqualTo node, AnalysisContext context) { + Expression left = node.getLeft().accept(this, context); + Expression right = node.getRight().accept(this, context); + + return DSL.equal(left, right); + } + + @Override + public Expression visitLiteral(Literal node, AnalysisContext context) { + return DSL.literal(ExprValueUtils.fromObjectValue(node.getValue(), node.getType().getCoreType())); + } + + @Override + public Expression visitInterval(Interval node, AnalysisContext context) { + Expression value = node.getValue().accept(this, context); + Expression unit = DSL.literal(node.getUnit().name()); + return DSL.interval(value, unit); + } + + @Override + public Expression visitAnd(And node, AnalysisContext context) { + Expression left = node.getLeft().accept(this, context); + Expression right = node.getRight().accept(this, context); + + return DSL.and(left, right); + } + + @Override + public Expression visitOr(Or node, AnalysisContext context) { + Expression left = node.getLeft().accept(this, context); + Expression right = node.getRight().accept(this, context); + + return DSL.or(left, right); + } + + @Override + public Expression visitXor(Xor node, AnalysisContext context) { + Expression left = node.getLeft().accept(this, context); + Expression right = node.getRight().accept(this, context); + + return DSL.xor(left, right); + } + + @Override + public Expression visitNot(Not node, AnalysisContext context) { + return DSL.not(node.getExpression().accept(this, context)); + } + + @Override + public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext context) { + Optional builtinFunctionName = BuiltinFunctionName.ofAggregation(node.getFuncName()); + if (builtinFunctionName.isPresent()) { + ImmutableList.Builder builder = ImmutableList.builder(); + builder.add(node.getField().accept(this, context)); + for (UnresolvedExpression arg : node.getArgList()) { + builder.add(arg.accept(this, context)); + } + Aggregator aggregator = (Aggregator) repository.compile( + context.getFunctionProperties(), + builtinFunctionName.get().getName(), + builder.build() + ); + aggregator.distinct(node.getDistinct()); + if (node.condition() != null) { + aggregator.condition(analyze(node.condition(), context)); + } + return aggregator; + } else { + throw new SemanticCheckException("Unsupported aggregation function " + node.getFuncName()); + } + } + + @Override + public Expression visitRelevanceFieldList(RelevanceFieldList node, AnalysisContext context) { + return new LiteralExpression(ExprValueUtils.tupleValue(ImmutableMap.copyOf(node.getFieldList()))); + } + + @Override + public Expression visitFunction(Function node, AnalysisContext context) { + FunctionName functionName = FunctionName.of(node.getFuncName()); + List arguments = node.getFuncArgs().stream().map(unresolvedExpression -> { + var ret = analyze(unresolvedExpression, context); + if (ret == null) { + throw new UnsupportedOperationException(String.format("Invalid use of expression %s", unresolvedExpression)); + } else { return ret; - } - }) - .collect(Collectors.toList()); - return (Expression) repository.compile(context.getFunctionProperties(), - functionName, arguments); - } - - @SuppressWarnings("unchecked") - @Override - public Expression visitWindowFunction(WindowFunction node, AnalysisContext context) { - Expression expr = node.getFunction().accept(this, context); - // Wrap regular aggregator by aggregate window function to adapt window operator use - if (expr instanceof Aggregator) { - return new AggregateWindowFunction((Aggregator) expr); + } + }).collect(Collectors.toList()); + return (Expression) repository.compile(context.getFunctionProperties(), functionName, arguments); + } + + @SuppressWarnings("unchecked") + @Override + public Expression visitWindowFunction(WindowFunction node, AnalysisContext context) { + Expression expr = node.getFunction().accept(this, context); + // Wrap regular aggregator by aggregate window function to adapt window operator use + if (expr instanceof Aggregator) { + return new AggregateWindowFunction((Aggregator) expr); + } + return expr; } - return expr; - } - - @Override - public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext context) { - Expression expr = node.getHighlightField().accept(this, context); - return new HighlightExpression(expr); - } - - /** - * visitScoreFunction removes the score function from the AST and replaces it with the child - * relevance function node. If the optional boost variable is provided, the boost argument - * of the relevance function is combined. - * - * @param node score function node - * @param context analysis context for the query - * @return resolved relevance function - */ - public Expression visitScoreFunction(ScoreFunction node, AnalysisContext context) { - Literal boostArg = node.getRelevanceFieldWeight(); - if (!boostArg.getType().equals(DataType.DOUBLE)) { - throw new SemanticCheckException(String.format("Expected boost type '%s' but got '%s'", - DataType.DOUBLE.name(), boostArg.getType().name())); + + @Override + public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext context) { + Expression expr = node.getHighlightField().accept(this, context); + return new HighlightExpression(expr); } - Double thisBoostValue = ((Double) boostArg.getValue()); - - // update the existing unresolved expression to add a boost argument if it doesn't exist - // OR multiply the existing boost argument - Function relevanceQueryUnresolvedExpr = (Function) node.getRelevanceQuery(); - List relevanceFuncArgs = relevanceQueryUnresolvedExpr.getFuncArgs(); - - boolean doesFunctionContainBoostArgument = false; - List updatedFuncArgs = new ArrayList<>(); - for (UnresolvedExpression expr : relevanceFuncArgs) { - String argumentName = ((UnresolvedArgument) expr).getArgName(); - if (argumentName.equalsIgnoreCase("boost")) { - doesFunctionContainBoostArgument = true; - Literal boostArgLiteral = (Literal) ((UnresolvedArgument) expr).getValue(); - Double boostValue = - Double.parseDouble((String) boostArgLiteral.getValue()) * thisBoostValue; - UnresolvedArgument newBoostArg = new UnresolvedArgument( - argumentName, - new Literal(boostValue.toString(), DataType.STRING) - ); - updatedFuncArgs.add(newBoostArg); - } else { - updatedFuncArgs.add(expr); - } + + /** + * visitScoreFunction removes the score function from the AST and replaces it with the child + * relevance function node. If the optional boost variable is provided, the boost argument + * of the relevance function is combined. + * + * @param node score function node + * @param context analysis context for the query + * @return resolved relevance function + */ + public Expression visitScoreFunction(ScoreFunction node, AnalysisContext context) { + Literal boostArg = node.getRelevanceFieldWeight(); + if (!boostArg.getType().equals(DataType.DOUBLE)) { + throw new SemanticCheckException( + String.format("Expected boost type '%s' but got '%s'", DataType.DOUBLE.name(), boostArg.getType().name()) + ); + } + Double thisBoostValue = ((Double) boostArg.getValue()); + + // update the existing unresolved expression to add a boost argument if it doesn't exist + // OR multiply the existing boost argument + Function relevanceQueryUnresolvedExpr = (Function) node.getRelevanceQuery(); + List relevanceFuncArgs = relevanceQueryUnresolvedExpr.getFuncArgs(); + + boolean doesFunctionContainBoostArgument = false; + List updatedFuncArgs = new ArrayList<>(); + for (UnresolvedExpression expr : relevanceFuncArgs) { + String argumentName = ((UnresolvedArgument) expr).getArgName(); + if (argumentName.equalsIgnoreCase("boost")) { + doesFunctionContainBoostArgument = true; + Literal boostArgLiteral = (Literal) ((UnresolvedArgument) expr).getValue(); + Double boostValue = Double.parseDouble((String) boostArgLiteral.getValue()) * thisBoostValue; + UnresolvedArgument newBoostArg = new UnresolvedArgument(argumentName, new Literal(boostValue.toString(), DataType.STRING)); + updatedFuncArgs.add(newBoostArg); + } else { + updatedFuncArgs.add(expr); + } + } + + // since nothing was found, add an argument + if (!doesFunctionContainBoostArgument) { + UnresolvedArgument newBoostArg = new UnresolvedArgument("boost", new Literal(Double.toString(thisBoostValue), DataType.STRING)); + updatedFuncArgs.add(newBoostArg); + } + + // create a new function expression with boost argument and resolve it + Function updatedRelevanceQueryUnresolvedExpr = new Function(relevanceQueryUnresolvedExpr.getFuncName(), updatedFuncArgs); + OpenSearchFunctions.OpenSearchFunction relevanceQueryExpr = + (OpenSearchFunctions.OpenSearchFunction) updatedRelevanceQueryUnresolvedExpr.accept(this, context); + relevanceQueryExpr.setScoreTracked(true); + return relevanceQueryExpr; + } + + @Override + public Expression visitIn(In node, AnalysisContext context) { + return visitIn(node.getField(), node.getValueList(), context); + } + + private Expression visitIn(UnresolvedExpression field, List valueList, AnalysisContext context) { + if (valueList.size() == 1) { + return visitCompare(new Compare("=", field, valueList.get(0)), context); + } else if (valueList.size() > 1) { + return DSL.or( + visitCompare(new Compare("=", field, valueList.get(0)), context), + visitIn(field, valueList.subList(1, valueList.size()), context) + ); + } else { + throw new SemanticCheckException("Values in In clause should not be empty"); + } + } + + @Override + public Expression visitCompare(Compare node, AnalysisContext context) { + FunctionName functionName = FunctionName.of(node.getOperator()); + Expression left = analyze(node.getLeft(), context); + Expression right = analyze(node.getRight(), context); + return (Expression) repository.compile(context.getFunctionProperties(), functionName, Arrays.asList(left, right)); } - // since nothing was found, add an argument - if (!doesFunctionContainBoostArgument) { - UnresolvedArgument newBoostArg = new UnresolvedArgument( - "boost", new Literal(Double.toString(thisBoostValue), DataType.STRING)); - updatedFuncArgs.add(newBoostArg); + @Override + public Expression visitBetween(Between node, AnalysisContext context) { + return and(compare(">=", node.getValue(), node.getLowerBound()), compare("<=", node.getValue(), node.getUpperBound())).accept( + this, + context + ); } - // create a new function expression with boost argument and resolve it - Function updatedRelevanceQueryUnresolvedExpr = new Function( - relevanceQueryUnresolvedExpr.getFuncName(), - updatedFuncArgs); - OpenSearchFunctions.OpenSearchFunction relevanceQueryExpr = - (OpenSearchFunctions.OpenSearchFunction) updatedRelevanceQueryUnresolvedExpr - .accept(this, context); - relevanceQueryExpr.setScoreTracked(true); - return relevanceQueryExpr; - } - - @Override - public Expression visitIn(In node, AnalysisContext context) { - return visitIn(node.getField(), node.getValueList(), context); - } - - private Expression visitIn( - UnresolvedExpression field, List valueList, AnalysisContext context) { - if (valueList.size() == 1) { - return visitCompare(new Compare("=", field, valueList.get(0)), context); - } else if (valueList.size() > 1) { - return DSL.or( - visitCompare(new Compare("=", field, valueList.get(0)), context), - visitIn(field, valueList.subList(1, valueList.size()), context)); - } else { - throw new SemanticCheckException("Values in In clause should not be empty"); + @Override + public Expression visitCase(Case node, AnalysisContext context) { + List whens = new ArrayList<>(); + for (When when : node.getWhenClauses()) { + if (node.getCaseValue() == null) { + whens.add((WhenClause) analyze(when, context)); + } else { + // Merge case value and condition (compare value) into a single equal condition + whens.add( + (WhenClause) analyze( + new When(new Function("=", Arrays.asList(node.getCaseValue(), when.getCondition())), when.getResult()), + context + ) + ); + } + } + + Expression defaultResult = (node.getElseClause() == null) ? null : analyze(node.getElseClause(), context); + CaseClause caseClause = new CaseClause(whens, defaultResult); + + // To make this simple, require all result type same regardless of implicit convert + // Make CaseClause return list so it can be used in error message in determined order + List resultTypes = caseClause.allResultTypes(); + if (ImmutableSet.copyOf(resultTypes).size() > 1) { + throw new SemanticCheckException("All result types of CASE clause must be the same, but found " + resultTypes); + } + return caseClause; } - } - - @Override - public Expression visitCompare(Compare node, AnalysisContext context) { - FunctionName functionName = FunctionName.of(node.getOperator()); - Expression left = analyze(node.getLeft(), context); - Expression right = analyze(node.getRight(), context); - return (Expression) - repository.compile(context.getFunctionProperties(), - functionName, Arrays.asList(left, right)); - } - - @Override - public Expression visitBetween(Between node, AnalysisContext context) { - return and( - compare(">=", node.getValue(), node.getLowerBound()), - compare("<=", node.getValue(), node.getUpperBound()) - ).accept(this, context); - } - - @Override - public Expression visitCase(Case node, AnalysisContext context) { - List whens = new ArrayList<>(); - for (When when : node.getWhenClauses()) { - if (node.getCaseValue() == null) { - whens.add((WhenClause) analyze(when, context)); - } else { - // Merge case value and condition (compare value) into a single equal condition - whens.add((WhenClause) analyze( - new When( - new Function("=", Arrays.asList(node.getCaseValue(), when.getCondition())), - when.getResult() - ), context)); - } + + @Override + public Expression visitWhen(When node, AnalysisContext context) { + return new WhenClause(analyze(node.getCondition(), context), analyze(node.getResult(), context)); } - Expression defaultResult = (node.getElseClause() == null) - ? null : analyze(node.getElseClause(), context); - CaseClause caseClause = new CaseClause(whens, defaultResult); + @Override + public Expression visitField(Field node, AnalysisContext context) { + String attr = node.getField().toString(); + return visitIdentifier(attr, context); + } - // To make this simple, require all result type same regardless of implicit convert - // Make CaseClause return list so it can be used in error message in determined order - List resultTypes = caseClause.allResultTypes(); - if (ImmutableSet.copyOf(resultTypes).size() > 1) { - throw new SemanticCheckException( - "All result types of CASE clause must be the same, but found " + resultTypes); + @Override + public Expression visitAllFields(AllFields node, AnalysisContext context) { + // Convert to string literal for argument in COUNT(*), because there is no difference between + // COUNT(*) and COUNT(literal). For SELECT *, its select expression analyzer will expand * to + // the right field name list by itself. + return DSL.literal("*"); } - return caseClause; - } - - @Override - public Expression visitWhen(When node, AnalysisContext context) { - return new WhenClause( - analyze(node.getCondition(), context), - analyze(node.getResult(), context)); - } - - @Override - public Expression visitField(Field node, AnalysisContext context) { - String attr = node.getField().toString(); - return visitIdentifier(attr, context); - } - - @Override - public Expression visitAllFields(AllFields node, AnalysisContext context) { - // Convert to string literal for argument in COUNT(*), because there is no difference between - // COUNT(*) and COUNT(literal). For SELECT *, its select expression analyzer will expand * to - // the right field name list by itself. - return DSL.literal("*"); - } - - @Override - public Expression visitQualifiedName(QualifiedName node, AnalysisContext context) { - QualifierAnalyzer qualifierAnalyzer = new QualifierAnalyzer(context); - - // check for reserved words in the identifier - for (String part : node.getParts()) { - for (TypeEnvironment typeEnv = context.peek(); - typeEnv != null; - typeEnv = typeEnv.getParent()) { - Optional exprType = typeEnv.getReservedSymbolTable().lookup( - new Symbol(Namespace.FIELD_NAME, part)); - if (exprType.isPresent()) { - return visitMetadata( - qualifierAnalyzer.unqualified(node), - (ExprCoreType) exprType.get(), - context - ); + + @Override + public Expression visitQualifiedName(QualifiedName node, AnalysisContext context) { + QualifierAnalyzer qualifierAnalyzer = new QualifierAnalyzer(context); + + // check for reserved words in the identifier + for (String part : node.getParts()) { + for (TypeEnvironment typeEnv = context.peek(); typeEnv != null; typeEnv = typeEnv.getParent()) { + Optional exprType = typeEnv.getReservedSymbolTable().lookup(new Symbol(Namespace.FIELD_NAME, part)); + if (exprType.isPresent()) { + return visitMetadata(qualifierAnalyzer.unqualified(node), (ExprCoreType) exprType.get(), context); + } + } } - } + return visitIdentifier(qualifierAnalyzer.unqualified(node), context); } - return visitIdentifier(qualifierAnalyzer.unqualified(node), context); - } - - @Override - public Expression visitSpan(Span node, AnalysisContext context) { - return new SpanExpression( - node.getField().accept(this, context), - node.getValue().accept(this, context), - node.getUnit()); - } - - @Override - public Expression visitUnresolvedArgument(UnresolvedArgument node, AnalysisContext context) { - return new NamedArgumentExpression(node.getArgName(), node.getValue().accept(this, context)); - } - - /** - * If QualifiedName is actually a reserved metadata field, return the expr type associated - * with the metadata field. - * @param ident metadata field name - * @param context analysis context - * @return DSL reference - */ - private Expression visitMetadata(String ident, - ExprCoreType exprCoreType, - AnalysisContext context) { - return DSL.ref(ident, exprCoreType); - } - - private Expression visitIdentifier(String ident, AnalysisContext context) { - // ParseExpression will always override ReferenceExpression when ident conflicts - for (NamedExpression expr : context.getNamedParseExpressions()) { - if (expr.getNameOrAlias().equals(ident) && expr.getDelegated() instanceof ParseExpression) { - return expr.getDelegated(); - } + + @Override + public Expression visitSpan(Span node, AnalysisContext context) { + return new SpanExpression(node.getField().accept(this, context), node.getValue().accept(this, context), node.getUnit()); + } + + @Override + public Expression visitUnresolvedArgument(UnresolvedArgument node, AnalysisContext context) { + return new NamedArgumentExpression(node.getArgName(), node.getValue().accept(this, context)); + } + + /** + * If QualifiedName is actually a reserved metadata field, return the expr type associated + * with the metadata field. + * @param ident metadata field name + * @param context analysis context + * @return DSL reference + */ + private Expression visitMetadata(String ident, ExprCoreType exprCoreType, AnalysisContext context) { + return DSL.ref(ident, exprCoreType); } - TypeEnvironment typeEnv = context.peek(); - ReferenceExpression ref = DSL.ref(ident, - typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, ident))); + private Expression visitIdentifier(String ident, AnalysisContext context) { + // ParseExpression will always override ReferenceExpression when ident conflicts + for (NamedExpression expr : context.getNamedParseExpressions()) { + if (expr.getNameOrAlias().equals(ident) && expr.getDelegated() instanceof ParseExpression) { + return expr.getDelegated(); + } + } + + TypeEnvironment typeEnv = context.peek(); + ReferenceExpression ref = DSL.ref(ident, typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, ident))); - return ref; - } + return ref; + } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java index eaf5c4abca..05f4245ea1 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import java.util.HashMap; @@ -38,124 +37,122 @@ * LogicalAgg(agg=[sum(age), avg(age)], group=[abs(age)]] * LogicalRelation */ -public class ExpressionReferenceOptimizer - extends ExpressionNodeVisitor { - private final BuiltinFunctionRepository repository; - - /** - * The map of expression and it's reference. - * For example, The NamedAggregator should produce the map of Aggregator to Ref(name) - */ - private final Map expressionMap = new HashMap<>(); - - public ExpressionReferenceOptimizer( - BuiltinFunctionRepository repository, LogicalPlan logicalPlan) { - this.repository = repository; - logicalPlan.accept(new ExpressionMapBuilder(), null); - } - - public Expression optimize(Expression analyzed, AnalysisContext context) { - return analyzed.accept(this, context); - } - - @Override - public Expression visitNode(Expression node, AnalysisContext context) { - return node; - } - - @Override - public Expression visitFunction(FunctionExpression node, AnalysisContext context) { - if (expressionMap.containsKey(node)) { - return expressionMap.get(node); - } else { - final List args = - node.getArguments().stream().map(expr -> expr.accept(this, context)) - .collect(Collectors.toList()); - Expression optimizedFunctionExpression = (Expression) repository.compile( - context.getFunctionProperties(), - node.getFunctionName(), - args - ); - // Propagate scoreTracked for OpenSearch functions - if (optimizedFunctionExpression instanceof OpenSearchFunctions.OpenSearchFunction) { - ((OpenSearchFunctions.OpenSearchFunction) optimizedFunctionExpression).setScoreTracked( - ((OpenSearchFunctions.OpenSearchFunction)node).isScoreTracked()); - } - return optimizedFunctionExpression; +public class ExpressionReferenceOptimizer extends ExpressionNodeVisitor { + private final BuiltinFunctionRepository repository; + + /** + * The map of expression and it's reference. + * For example, The NamedAggregator should produce the map of Aggregator to Ref(name) + */ + private final Map expressionMap = new HashMap<>(); + + public ExpressionReferenceOptimizer(BuiltinFunctionRepository repository, LogicalPlan logicalPlan) { + this.repository = repository; + logicalPlan.accept(new ExpressionMapBuilder(), null); } - } - - @Override - public Expression visitAggregator(Aggregator node, AnalysisContext context) { - return expressionMap.getOrDefault(node, node); - } - @Override - public Expression visitNamed(NamedExpression node, AnalysisContext context) { - if (expressionMap.containsKey(node)) { - return expressionMap.get(node); - } - return node.getDelegated().accept(this, context); - } - - /** - * Implement this because Case/When is not registered in function repository. - */ - @Override - public Expression visitCase(CaseClause node, AnalysisContext context) { - if (expressionMap.containsKey(node)) { - return expressionMap.get(node); + public Expression optimize(Expression analyzed, AnalysisContext context) { + return analyzed.accept(this, context); } - List whenClauses = node.getWhenClauses() - .stream() - .map(expr -> (WhenClause) expr.accept(this, context)) - .collect(Collectors.toList()); - Expression defaultResult = null; - if (node.getDefaultResult() != null) { - defaultResult = node.getDefaultResult().accept(this, context); + @Override + public Expression visitNode(Expression node, AnalysisContext context) { + return node; } - return new CaseClause(whenClauses, defaultResult); - } - - @Override - public Expression visitWhen(WhenClause node, AnalysisContext context) { - return new WhenClause( - node.getCondition().accept(this, context), - node.getResult().accept(this, context)); - } + @Override + public Expression visitFunction(FunctionExpression node, AnalysisContext context) { + if (expressionMap.containsKey(node)) { + return expressionMap.get(node); + } else { + final List args = node.getArguments().stream().map(expr -> expr.accept(this, context)).collect(Collectors.toList()); + Expression optimizedFunctionExpression = (Expression) repository.compile( + context.getFunctionProperties(), + node.getFunctionName(), + args + ); + // Propagate scoreTracked for OpenSearch functions + if (optimizedFunctionExpression instanceof OpenSearchFunctions.OpenSearchFunction) { + ((OpenSearchFunctions.OpenSearchFunction) optimizedFunctionExpression).setScoreTracked( + ((OpenSearchFunctions.OpenSearchFunction) node).isScoreTracked() + ); + } + return optimizedFunctionExpression; + } + } - /** - * Expression Map Builder. - */ - class ExpressionMapBuilder extends LogicalPlanNodeVisitor { + @Override + public Expression visitAggregator(Aggregator node, AnalysisContext context) { + return expressionMap.getOrDefault(node, node); + } @Override - public Void visitNode(LogicalPlan plan, Void context) { - plan.getChild().forEach(child -> child.accept(this, context)); - return null; + public Expression visitNamed(NamedExpression node, AnalysisContext context) { + if (expressionMap.containsKey(node)) { + return expressionMap.get(node); + } + return node.getDelegated().accept(this, context); } + /** + * Implement this because Case/When is not registered in function repository. + */ @Override - public Void visitAggregation(LogicalAggregation plan, Void context) { - // Create the mapping for all the aggregator. - plan.getAggregatorList().forEach(namedAggregator -> expressionMap - .put(namedAggregator.getDelegated(), - new ReferenceExpression(namedAggregator.getName(), namedAggregator.type()))); - // Create the mapping for all the group by. - plan.getGroupByList().forEach(groupBy -> expressionMap - .put(groupBy.getDelegated(), - new ReferenceExpression(groupBy.getNameOrAlias(), groupBy.type()))); - return null; + public Expression visitCase(CaseClause node, AnalysisContext context) { + if (expressionMap.containsKey(node)) { + return expressionMap.get(node); + } + + List whenClauses = node.getWhenClauses() + .stream() + .map(expr -> (WhenClause) expr.accept(this, context)) + .collect(Collectors.toList()); + Expression defaultResult = null; + if (node.getDefaultResult() != null) { + defaultResult = node.getDefaultResult().accept(this, context); + } + return new CaseClause(whenClauses, defaultResult); } @Override - public Void visitWindow(LogicalWindow plan, Void context) { - Expression windowFunc = plan.getWindowFunction(); - expressionMap.put(windowFunc, - new ReferenceExpression(((NamedExpression) windowFunc).getName(), windowFunc.type())); - return visitNode(plan, context); + public Expression visitWhen(WhenClause node, AnalysisContext context) { + return new WhenClause(node.getCondition().accept(this, context), node.getResult().accept(this, context)); + } + + /** + * Expression Map Builder. + */ + class ExpressionMapBuilder extends LogicalPlanNodeVisitor { + + @Override + public Void visitNode(LogicalPlan plan, Void context) { + plan.getChild().forEach(child -> child.accept(this, context)); + return null; + } + + @Override + public Void visitAggregation(LogicalAggregation plan, Void context) { + // Create the mapping for all the aggregator. + plan.getAggregatorList() + .forEach( + namedAggregator -> expressionMap.put( + namedAggregator.getDelegated(), + new ReferenceExpression(namedAggregator.getName(), namedAggregator.type()) + ) + ); + // Create the mapping for all the group by. + plan.getGroupByList() + .forEach( + groupBy -> expressionMap.put(groupBy.getDelegated(), new ReferenceExpression(groupBy.getNameOrAlias(), groupBy.type())) + ); + return null; + } + + @Override + public Void visitWindow(LogicalWindow plan, Void context) { + Expression windowFunc = plan.getWindowFunction(); + expressionMap.put(windowFunc, new ReferenceExpression(((NamedExpression) windowFunc).getName(), windowFunc.type())); + return visitNode(plan, context); + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java index 0a15c6bac8..800a901fa8 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java @@ -20,23 +20,23 @@ */ @RequiredArgsConstructor public class HighlightAnalyzer extends AbstractNodeVisitor { - private final ExpressionAnalyzer expressionAnalyzer; - private final LogicalPlan child; + private final ExpressionAnalyzer expressionAnalyzer; + private final LogicalPlan child; - public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext context) { - LogicalPlan highlight = projectItem.accept(this, context); - return (highlight == null) ? child : highlight; - } - - @Override - public LogicalPlan visitAlias(Alias node, AnalysisContext context) { - UnresolvedExpression delegated = node.getDelegated(); - if (!(delegated instanceof HighlightFunction)) { - return null; + public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext context) { + LogicalPlan highlight = projectItem.accept(this, context); + return (highlight == null) ? child : highlight; } - HighlightFunction unresolved = (HighlightFunction) delegated; - Expression field = expressionAnalyzer.analyze(unresolved.getHighlightField(), context); - return new LogicalHighlight(child, field, unresolved.getArguments()); - } + @Override + public LogicalPlan visitAlias(Alias node, AnalysisContext context) { + UnresolvedExpression delegated = node.getDelegated(); + if (!(delegated instanceof HighlightFunction)) { + return null; + } + + HighlightFunction unresolved = (HighlightFunction) delegated; + Expression field = expressionAnalyzer.analyze(unresolved.getHighlightField(), context); + return new LogicalHighlight(child, field, unresolved.getArguments()); + } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/NamedExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/NamedExpressionAnalyzer.java index 1d318c5588..9fa4226cd8 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/NamedExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/NamedExpressionAnalyzer.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import lombok.RequiredArgsConstructor; @@ -21,32 +20,31 @@ * {@link NamedExpression}. */ @RequiredArgsConstructor -public class NamedExpressionAnalyzer extends - AbstractNodeVisitor { - private final ExpressionAnalyzer expressionAnalyzer; +public class NamedExpressionAnalyzer extends AbstractNodeVisitor { + private final ExpressionAnalyzer expressionAnalyzer; - /** - * Analyze Select fields. - */ - public NamedExpression analyze(UnresolvedExpression expression, - AnalysisContext analysisContext) { - return expression.accept(this, analysisContext); - } + /** + * Analyze Select fields. + */ + public NamedExpression analyze(UnresolvedExpression expression, AnalysisContext analysisContext) { + return expression.accept(this, analysisContext); + } - @Override - public NamedExpression visitAlias(Alias node, AnalysisContext context) { - return DSL.named( - unqualifiedNameIfFieldOnly(node, context), - node.getDelegated().accept(expressionAnalyzer, context), - node.getAlias()); - } + @Override + public NamedExpression visitAlias(Alias node, AnalysisContext context) { + return DSL.named( + unqualifiedNameIfFieldOnly(node, context), + node.getDelegated().accept(expressionAnalyzer, context), + node.getAlias() + ); + } - private String unqualifiedNameIfFieldOnly(Alias node, AnalysisContext context) { - UnresolvedExpression selectItem = node.getDelegated(); - if (selectItem instanceof QualifiedName) { - QualifierAnalyzer qualifierAnalyzer = new QualifierAnalyzer(context); - return qualifierAnalyzer.unqualified((QualifiedName) selectItem); + private String unqualifiedNameIfFieldOnly(Alias node, AnalysisContext context) { + UnresolvedExpression selectItem = node.getDelegated(); + if (selectItem instanceof QualifiedName) { + QualifierAnalyzer qualifierAnalyzer = new QualifierAnalyzer(context); + return qualifierAnalyzer.unqualified((QualifiedName) selectItem); + } + return node.getName(); } - return node.getName(); - } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java index f050824557..aab94eb955 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java @@ -32,130 +32,116 @@ */ @RequiredArgsConstructor public class NestedAnalyzer extends AbstractNodeVisitor { - private final List namedExpressions; - private final ExpressionAnalyzer expressionAnalyzer; - private final LogicalPlan child; - - public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext context) { - LogicalPlan nested = projectItem.accept(this, context); - return (nested == null) ? child : nested; - } - - @Override - public LogicalPlan visitAlias(Alias node, AnalysisContext context) { - return node.getDelegated().accept(this, context); - } - - @Override - public LogicalPlan visitNestedAllTupleFields(NestedAllTupleFields node, AnalysisContext context) { - List> args = new ArrayList<>(); - for (NamedExpression namedExpr : namedExpressions) { - if (isNestedFunction(namedExpr.getDelegated())) { - ReferenceExpression field = - (ReferenceExpression) ((FunctionExpression) namedExpr.getDelegated()) - .getArguments().get(0); - - // If path is same as NestedAllTupleFields path - if (field.getAttr().substring(0, field.getAttr().lastIndexOf(".")) - .equalsIgnoreCase(node.getPath())) { - args.add(Map.of( - "field", field, - "path", new ReferenceExpression(node.getPath(), STRING))); + private final List namedExpressions; + private final ExpressionAnalyzer expressionAnalyzer; + private final LogicalPlan child; + + public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext context) { + LogicalPlan nested = projectItem.accept(this, context); + return (nested == null) ? child : nested; + } + + @Override + public LogicalPlan visitAlias(Alias node, AnalysisContext context) { + return node.getDelegated().accept(this, context); + } + + @Override + public LogicalPlan visitNestedAllTupleFields(NestedAllTupleFields node, AnalysisContext context) { + List> args = new ArrayList<>(); + for (NamedExpression namedExpr : namedExpressions) { + if (isNestedFunction(namedExpr.getDelegated())) { + ReferenceExpression field = (ReferenceExpression) ((FunctionExpression) namedExpr.getDelegated()).getArguments().get(0); + + // If path is same as NestedAllTupleFields path + if (field.getAttr().substring(0, field.getAttr().lastIndexOf(".")).equalsIgnoreCase(node.getPath())) { + args.add(Map.of("field", field, "path", new ReferenceExpression(node.getPath(), STRING))); + } + } } - } + + return mergeChildIfLogicalNested(args); } - return mergeChildIfLogicalNested(args); - } - - @Override - public LogicalPlan visitFunction(Function node, AnalysisContext context) { - if (node.getFuncName().equalsIgnoreCase(BuiltinFunctionName.NESTED.name())) { - - List expressions = node.getFuncArgs(); - validateArgs(expressions); - ReferenceExpression nestedField = - (ReferenceExpression)expressionAnalyzer.analyze(expressions.get(0), context); - Map args; - - // Path parameter is supplied - if (expressions.size() == 2) { - args = Map.of( - "field", nestedField, - "path", (ReferenceExpression)expressionAnalyzer.analyze(expressions.get(1), context) - ); - } else { - args = Map.of( - "field", (ReferenceExpression)expressionAnalyzer.analyze(expressions.get(0), context), - "path", generatePath(nestedField.toString()) - ); - } - - return mergeChildIfLogicalNested(new ArrayList<>(Arrays.asList(args))); + @Override + public LogicalPlan visitFunction(Function node, AnalysisContext context) { + if (node.getFuncName().equalsIgnoreCase(BuiltinFunctionName.NESTED.name())) { + + List expressions = node.getFuncArgs(); + validateArgs(expressions); + ReferenceExpression nestedField = (ReferenceExpression) expressionAnalyzer.analyze(expressions.get(0), context); + Map args; + + // Path parameter is supplied + if (expressions.size() == 2) { + args = Map.of("field", nestedField, "path", (ReferenceExpression) expressionAnalyzer.analyze(expressions.get(1), context)); + } else { + args = Map.of( + "field", + (ReferenceExpression) expressionAnalyzer.analyze(expressions.get(0), context), + "path", + generatePath(nestedField.toString()) + ); + } + + return mergeChildIfLogicalNested(new ArrayList<>(Arrays.asList(args))); + } + return null; } - return null; - } - - /** - * NestedAnalyzer visits all functions in SELECT clause, creates logical plans for each and - * merges them. This is to avoid another merge rule in LogicalPlanOptimizer:create(). - * @param args field and path params to add to logical plan. - * @return child of logical nested with added args, or new LogicalNested. - */ - private LogicalPlan mergeChildIfLogicalNested(List> args) { - if (child instanceof LogicalNested) { - for (var arg : args) { - ((LogicalNested) child).addFields(arg); - } - return child; + + /** + * NestedAnalyzer visits all functions in SELECT clause, creates logical plans for each and + * merges them. This is to avoid another merge rule in LogicalPlanOptimizer:create(). + * @param args field and path params to add to logical plan. + * @return child of logical nested with added args, or new LogicalNested. + */ + private LogicalPlan mergeChildIfLogicalNested(List> args) { + if (child instanceof LogicalNested) { + for (var arg : args) { + ((LogicalNested) child).addFields(arg); + } + return child; + } + return new LogicalNested(child, args, namedExpressions); } - return new LogicalNested(child, args, namedExpressions); - } - - /** - * Validate each parameter used in nested function in SELECT clause. Any supplied parameter - * for a nested function in a SELECT statement must be a valid qualified name, and the field - * parameter must be nested at least one level. - * @param args : Arguments in nested function. - */ - private void validateArgs(List args) { - if (args.size() < 1 || args.size() > 2) { - throw new IllegalArgumentException( - "on nested object only allowed 2 parameters (field,path) or 1 parameter (field)" - ); + + /** + * Validate each parameter used in nested function in SELECT clause. Any supplied parameter + * for a nested function in a SELECT statement must be a valid qualified name, and the field + * parameter must be nested at least one level. + * @param args : Arguments in nested function. + */ + private void validateArgs(List args) { + if (args.size() < 1 || args.size() > 2) { + throw new IllegalArgumentException("on nested object only allowed 2 parameters (field,path) or 1 parameter (field)"); + } + + for (int i = 0; i < args.size(); i++) { + if (!(args.get(i) instanceof QualifiedName)) { + throw new IllegalArgumentException(String.format("Illegal nested field name: %s", args.get(i).toString())); + } + if (i == 0 && ((QualifiedName) args.get(i)).getParts().size() < 2) { + throw new IllegalArgumentException(String.format("Illegal nested field name: %s", args.get(i).toString())); + } + } + } + + /** + * Generate nested path dynamically. Assumes at least one level of nesting in supplied string. + * @param field : Nested field to generate path of. + * @return : Path of field derived from last level of nesting. + */ + public static ReferenceExpression generatePath(String field) { + return new ReferenceExpression(field.substring(0, field.lastIndexOf(".")), STRING); } - for (int i = 0; i < args.size(); i++) { - if (!(args.get(i) instanceof QualifiedName)) { - throw new IllegalArgumentException( - String.format("Illegal nested field name: %s", args.get(i).toString()) - ); - } - if (i == 0 && ((QualifiedName)args.get(i)).getParts().size() < 2) { - throw new IllegalArgumentException( - String.format("Illegal nested field name: %s", args.get(i).toString()) - ); - } + /** + * Check if supplied expression is a nested function. + * @param expr Expression checking if is nested function. + * @return True if expression is a nested function. + */ + public static Boolean isNestedFunction(Expression expr) { + return (expr instanceof FunctionExpression + && ((FunctionExpression) expr).getFunctionName().getFunctionName().equalsIgnoreCase(BuiltinFunctionName.NESTED.name())); } - } - - /** - * Generate nested path dynamically. Assumes at least one level of nesting in supplied string. - * @param field : Nested field to generate path of. - * @return : Path of field derived from last level of nesting. - */ - public static ReferenceExpression generatePath(String field) { - return new ReferenceExpression(field.substring(0, field.lastIndexOf(".")), STRING); - } - - /** - * Check if supplied expression is a nested function. - * @param expr Expression checking if is nested function. - * @return True if expression is a nested function. - */ - public static Boolean isNestedFunction(Expression expr) { - return (expr instanceof FunctionExpression - && ((FunctionExpression) expr).getFunctionName().getFunctionName() - .equalsIgnoreCase(BuiltinFunctionName.NESTED.name())); - } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/QualifierAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/QualifierAnalyzer.java index d1e31d0079..344db5bf18 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/QualifierAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/QualifierAnalyzer.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import java.util.Arrays; @@ -21,55 +20,59 @@ @RequiredArgsConstructor public class QualifierAnalyzer { - private final AnalysisContext context; + private final AnalysisContext context; - public String unqualified(String... parts) { - return unqualified(QualifiedName.of(Arrays.asList(parts))); - } + public String unqualified(String... parts) { + return unqualified(QualifiedName.of(Arrays.asList(parts))); + } - /** - * Get unqualified name if its qualifier symbol found is in index namespace - * on type environment. Unqualified name means name with qualifier removed. - * For example, unqualified name of "accounts.age" or "acc.age" is "age". - * - * @return unqualified name if criteria met above, otherwise original name - */ - public String unqualified(QualifiedName fullName) { - return isQualifierIndexOrAlias(fullName) ? fullName.rest().toString() : fullName.toString(); - } + /** + * Get unqualified name if its qualifier symbol found is in index namespace + * on type environment. Unqualified name means name with qualifier removed. + * For example, unqualified name of "accounts.age" or "acc.age" is "age". + * + * @return unqualified name if criteria met above, otherwise original name + */ + public String unqualified(QualifiedName fullName) { + return isQualifierIndexOrAlias(fullName) ? fullName.rest().toString() : fullName.toString(); + } - private boolean isQualifierIndexOrAlias(QualifiedName fullName) { - Optional qualifier = fullName.first(); - if (qualifier.isPresent()) { - if (isFieldName(qualifier.get())) { + private boolean isQualifierIndexOrAlias(QualifiedName fullName) { + Optional qualifier = fullName.first(); + if (qualifier.isPresent()) { + if (isFieldName(qualifier.get())) { + return false; + } + resolveQualifierSymbol(fullName, qualifier.get()); + return true; + } return false; - } - resolveQualifierSymbol(fullName, qualifier.get()); - return true; } - return false; - } - private boolean isFieldName(String qualifier) { - try { - // Resolve the qualifier in Namespace.FIELD_NAME - context.peek().resolve(new Symbol(Namespace.FIELD_NAME, qualifier)); - return true; - } catch (SemanticCheckException e2) { - return false; + private boolean isFieldName(String qualifier) { + try { + // Resolve the qualifier in Namespace.FIELD_NAME + context.peek().resolve(new Symbol(Namespace.FIELD_NAME, qualifier)); + return true; + } catch (SemanticCheckException e2) { + return false; + } } - } - private void resolveQualifierSymbol(QualifiedName fullName, String qualifier) { - try { - context.peek().resolve(new Symbol(Namespace.INDEX_NAME, qualifier)); - } catch (SemanticCheckException e) { - // Throw syntax check intentionally to indicate fall back to old engine. - // Need change to semantic check exception in future. - throw new SyntaxCheckException(String.format( - "The qualifier [%s] of qualified name [%s] must be an field name, index name or its " - + "alias", qualifier, fullName)); + private void resolveQualifierSymbol(QualifiedName fullName, String qualifier) { + try { + context.peek().resolve(new Symbol(Namespace.INDEX_NAME, qualifier)); + } catch (SemanticCheckException e) { + // Throw syntax check intentionally to indicate fall back to old engine. + // Need change to semantic check exception in future. + throw new SyntaxCheckException( + String.format( + "The qualifier [%s] of qualified name [%s] must be an field name, index name or its " + "alias", + qualifier, + fullName + ) + ); + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java index 734f37378b..ff4bcd621b 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import com.google.common.collect.ImmutableList; @@ -34,120 +33,106 @@ * {@link NamedExpression}. */ @RequiredArgsConstructor -public class SelectExpressionAnalyzer - extends - AbstractNodeVisitor, AnalysisContext> { - private final ExpressionAnalyzer expressionAnalyzer; - - private ExpressionReferenceOptimizer optimizer; - - /** - * Analyze Select fields. - */ - public List analyze(List selectList, - AnalysisContext analysisContext, - ExpressionReferenceOptimizer optimizer) { - this.optimizer = optimizer; - ImmutableList.Builder builder = new ImmutableList.Builder<>(); - for (UnresolvedExpression unresolvedExpression : selectList) { - builder.addAll(unresolvedExpression.accept(this, analysisContext)); +public class SelectExpressionAnalyzer extends AbstractNodeVisitor, AnalysisContext> { + private final ExpressionAnalyzer expressionAnalyzer; + + private ExpressionReferenceOptimizer optimizer; + + /** + * Analyze Select fields. + */ + public List analyze( + List selectList, + AnalysisContext analysisContext, + ExpressionReferenceOptimizer optimizer + ) { + this.optimizer = optimizer; + ImmutableList.Builder builder = new ImmutableList.Builder<>(); + for (UnresolvedExpression unresolvedExpression : selectList) { + builder.addAll(unresolvedExpression.accept(this, analysisContext)); + } + return builder.build(); } - return builder.build(); - } - @Override - public List visitField(Field node, AnalysisContext context) { - return Collections.singletonList(DSL.named(node.accept(expressionAnalyzer, context))); - } - - @Override - public List visitAlias(Alias node, AnalysisContext context) { - // Expand all nested fields if used in SELECT clause - if (node.getDelegated() instanceof NestedAllTupleFields) { - return node.getDelegated().accept(this, context); + @Override + public List visitField(Field node, AnalysisContext context) { + return Collections.singletonList(DSL.named(node.accept(expressionAnalyzer, context))); } - Expression expr = referenceIfSymbolDefined(node, context); - return Collections.singletonList(DSL.named( - unqualifiedNameIfFieldOnly(node, context), - expr, - node.getAlias())); - } - - /** - * The Alias could be - * 1. SELECT name, AVG(age) FROM s BY name -> - * Project(Alias("name", expr), Alias("AVG(age)", aggExpr)) - * Agg(Alias("AVG(age)", aggExpr)) - * 2. SELECT length(name), AVG(age) FROM s BY length(name) - * Project(Alias("name", expr), Alias("AVG(age)", aggExpr)) - * Agg(Alias("AVG(age)", aggExpr)) - * 3. SELECT length(name) as l, AVG(age) FROM s BY l - * Project(Alias("name", expr, l), Alias("AVG(age)", aggExpr)) - * Agg(Alias("AVG(age)", aggExpr), Alias("length(name)", groupExpr)) - */ - private Expression referenceIfSymbolDefined(Alias expr, - AnalysisContext context) { - UnresolvedExpression delegatedExpr = expr.getDelegated(); + @Override + public List visitAlias(Alias node, AnalysisContext context) { + // Expand all nested fields if used in SELECT clause + if (node.getDelegated() instanceof NestedAllTupleFields) { + return node.getDelegated().accept(this, context); + } - // Pass named expression because expression like window function loses full name - // (OVER clause) and thus depends on name in alias to be replaced correctly - return optimizer.optimize( - DSL.named( - expr.getName(), - delegatedExpr.accept(expressionAnalyzer, context), - expr.getAlias()), - context); - } + Expression expr = referenceIfSymbolDefined(node, context); + return Collections.singletonList(DSL.named(unqualifiedNameIfFieldOnly(node, context), expr, node.getAlias())); + } - @Override - public List visitAllFields(AllFields node, - AnalysisContext context) { - TypeEnvironment environment = context.peek(); - Map lookupAllFields = environment.lookupAllFields(Namespace.FIELD_NAME); - return lookupAllFields.entrySet().stream().map(entry -> DSL.named(entry.getKey(), - new ReferenceExpression(entry.getKey(), entry.getValue()))).collect(Collectors.toList()); - } + /** + * The Alias could be + * 1. SELECT name, AVG(age) FROM s BY name -> + * Project(Alias("name", expr), Alias("AVG(age)", aggExpr)) + * Agg(Alias("AVG(age)", aggExpr)) + * 2. SELECT length(name), AVG(age) FROM s BY length(name) + * Project(Alias("name", expr), Alias("AVG(age)", aggExpr)) + * Agg(Alias("AVG(age)", aggExpr)) + * 3. SELECT length(name) as l, AVG(age) FROM s BY l + * Project(Alias("name", expr, l), Alias("AVG(age)", aggExpr)) + * Agg(Alias("AVG(age)", aggExpr), Alias("length(name)", groupExpr)) + */ + private Expression referenceIfSymbolDefined(Alias expr, AnalysisContext context) { + UnresolvedExpression delegatedExpr = expr.getDelegated(); + + // Pass named expression because expression like window function loses full name + // (OVER clause) and thus depends on name in alias to be replaced correctly + return optimizer.optimize(DSL.named(expr.getName(), delegatedExpr.accept(expressionAnalyzer, context), expr.getAlias()), context); + } - @Override - public List visitNestedAllTupleFields(NestedAllTupleFields node, - AnalysisContext context) { - TypeEnvironment environment = context.peek(); - Map lookupAllTupleFields = - environment.lookupAllTupleFields(Namespace.FIELD_NAME); - environment.resolve(new Symbol(Namespace.FIELD_NAME, node.getPath())); + @Override + public List visitAllFields(AllFields node, AnalysisContext context) { + TypeEnvironment environment = context.peek(); + Map lookupAllFields = environment.lookupAllFields(Namespace.FIELD_NAME); + return lookupAllFields.entrySet() + .stream() + .map(entry -> DSL.named(entry.getKey(), new ReferenceExpression(entry.getKey(), entry.getValue()))) + .collect(Collectors.toList()); + } - // Match all fields with same path as used in nested function. - Pattern p = Pattern.compile(node.getPath() + "\\.[^\\.]+$"); - return lookupAllTupleFields.entrySet().stream() - .filter(field -> p.matcher(field.getKey()).find()) - .map(entry -> { - Expression nestedFunc = new Function( - "nested", - List.of( - new QualifiedName(List.of(entry.getKey().split("\\.")))) - ).accept(expressionAnalyzer, context); - return DSL.named("nested(" + entry.getKey() + ")", nestedFunc); - }) - .collect(Collectors.toList()); - } + @Override + public List visitNestedAllTupleFields(NestedAllTupleFields node, AnalysisContext context) { + TypeEnvironment environment = context.peek(); + Map lookupAllTupleFields = environment.lookupAllTupleFields(Namespace.FIELD_NAME); + environment.resolve(new Symbol(Namespace.FIELD_NAME, node.getPath())); + + // Match all fields with same path as used in nested function. + Pattern p = Pattern.compile(node.getPath() + "\\.[^\\.]+$"); + return lookupAllTupleFields.entrySet().stream().filter(field -> p.matcher(field.getKey()).find()).map(entry -> { + Expression nestedFunc = new Function("nested", List.of(new QualifiedName(List.of(entry.getKey().split("\\."))))).accept( + expressionAnalyzer, + context + ); + return DSL.named("nested(" + entry.getKey() + ")", nestedFunc); + }).collect(Collectors.toList()); + } - /** - * Get unqualified name if select item is just a field. For example, suppose an index - * named "accounts", return "age" for "SELECT accounts.age". But do nothing for expression - * in "SELECT ABS(accounts.age)". - * Note that an assumption is made implicitly that original name field in Alias must be - * the same as the values in QualifiedName. This is true because AST builder does this. - * Otherwise, what unqualified() returns will override Alias's name as NamedExpression's name - * even though the QualifiedName doesn't have qualifier. - */ - private String unqualifiedNameIfFieldOnly(Alias node, AnalysisContext context) { - UnresolvedExpression selectItem = node.getDelegated(); - if (selectItem instanceof QualifiedName) { - QualifierAnalyzer qualifierAnalyzer = new QualifierAnalyzer(context); - return qualifierAnalyzer.unqualified((QualifiedName) selectItem); + /** + * Get unqualified name if select item is just a field. For example, suppose an index + * named "accounts", return "age" for "SELECT accounts.age". But do nothing for expression + * in "SELECT ABS(accounts.age)". + * Note that an assumption is made implicitly that original name field in Alias must be + * the same as the values in QualifiedName. This is true because AST builder does this. + * Otherwise, what unqualified() returns will override Alias's name as NamedExpression's name + * even though the QualifiedName doesn't have qualifier. + */ + private String unqualifiedNameIfFieldOnly(Alias node, AnalysisContext context) { + UnresolvedExpression selectItem = node.getDelegated(); + if (selectItem instanceof QualifiedName) { + QualifierAnalyzer qualifierAnalyzer = new QualifierAnalyzer(context); + return qualifierAnalyzer.unqualified((QualifiedName) selectItem); + } + return node.getName(); } - return node.getName(); - } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java b/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java index 17d203f66f..556b6040bb 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java +++ b/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static org.opensearch.sql.analysis.symbol.Namespace.FIELD_NAME; @@ -25,116 +24,114 @@ * The definition of Type Environment. */ public class TypeEnvironment implements Environment { - @Getter - private final TypeEnvironment parent; - private final SymbolTable symbolTable; - - @Getter - private final SymbolTable reservedSymbolTable; - - /** - * Constructor with empty symbol tables. - * - * @param parent parent environment - */ - public TypeEnvironment(TypeEnvironment parent) { - this.parent = parent; - this.symbolTable = new SymbolTable(); - this.reservedSymbolTable = new SymbolTable(); - } - - /** - * Constructor with empty reserved symbol table. - * - * @param parent parent environment - * @param symbolTable type table - */ - public TypeEnvironment(TypeEnvironment parent, SymbolTable symbolTable) { - this.parent = parent; - this.symbolTable = symbolTable; - this.reservedSymbolTable = new SymbolTable(); - } - - /** - * Resolve the {@link Expression} from environment. - * - * @param symbol Symbol - * @return resolved {@link ExprType} - */ - @Override - public ExprType resolve(Symbol symbol) { - for (TypeEnvironment cur = this; cur != null; cur = cur.parent) { - Optional typeOptional = cur.symbolTable.lookup(symbol); - if (typeOptional.isPresent()) { - return typeOptional.get(); - } + @Getter + private final TypeEnvironment parent; + private final SymbolTable symbolTable; + + @Getter + private final SymbolTable reservedSymbolTable; + + /** + * Constructor with empty symbol tables. + * + * @param parent parent environment + */ + public TypeEnvironment(TypeEnvironment parent) { + this.parent = parent; + this.symbolTable = new SymbolTable(); + this.reservedSymbolTable = new SymbolTable(); + } + + /** + * Constructor with empty reserved symbol table. + * + * @param parent parent environment + * @param symbolTable type table + */ + public TypeEnvironment(TypeEnvironment parent, SymbolTable symbolTable) { + this.parent = parent; + this.symbolTable = symbolTable; + this.reservedSymbolTable = new SymbolTable(); + } + + /** + * Resolve the {@link Expression} from environment. + * + * @param symbol Symbol + * @return resolved {@link ExprType} + */ + @Override + public ExprType resolve(Symbol symbol) { + for (TypeEnvironment cur = this; cur != null; cur = cur.parent) { + Optional typeOptional = cur.symbolTable.lookup(symbol); + if (typeOptional.isPresent()) { + return typeOptional.get(); + } + } + throw new SemanticCheckException(String.format("can't resolve %s in type env", symbol)); + } + + /** + * Resolve all fields in the current environment. + * + * @param namespace a namespace + * @return all symbols in the namespace + */ + public Map lookupAllFields(Namespace namespace) { + Map result = new LinkedHashMap<>(); + symbolTable.lookupAllFields(namespace).forEach(result::putIfAbsent); + return result; + } + + /** + * Resolve all fields in the current environment. + * @param namespace a namespace + * @return all symbols in the namespace + */ + public Map lookupAllTupleFields(Namespace namespace) { + Map result = new LinkedHashMap<>(); + symbolTable.lookupAllTupleFields(namespace).forEach(result::putIfAbsent); + return result; + } + + /** + * Define symbol with the type. + * + * @param symbol symbol to define + * @param type type + */ + public void define(Symbol symbol, ExprType type) { + symbolTable.store(symbol, type); + } + + /** + * Define expression with the type. + * + * @param ref {@link ReferenceExpression} + */ + public void define(ReferenceExpression ref) { + define(new Symbol(FIELD_NAME, ref.getAttr()), ref.type()); + } + + public void remove(Symbol symbol) { + symbolTable.remove(symbol); + } + + /** + * Remove ref. + */ + public void remove(ReferenceExpression ref) { + remove(new Symbol(FIELD_NAME, ref.getAttr())); + } + + /** + * Clear all fields in the current environment. + */ + public void clearAllFields() { + lookupAllFields(FIELD_NAME).keySet().forEach(v -> remove(new Symbol(Namespace.FIELD_NAME, v))); + } + + public void addReservedWord(Symbol symbol, ExprType type) { + reservedSymbolTable.store(symbol, type); } - throw new SemanticCheckException( - String.format("can't resolve %s in type env", symbol)); - } - - /** - * Resolve all fields in the current environment. - * - * @param namespace a namespace - * @return all symbols in the namespace - */ - public Map lookupAllFields(Namespace namespace) { - Map result = new LinkedHashMap<>(); - symbolTable.lookupAllFields(namespace).forEach(result::putIfAbsent); - return result; - } - - /** - * Resolve all fields in the current environment. - * @param namespace a namespace - * @return all symbols in the namespace - */ - public Map lookupAllTupleFields(Namespace namespace) { - Map result = new LinkedHashMap<>(); - symbolTable.lookupAllTupleFields(namespace).forEach(result::putIfAbsent); - return result; - } - - /** - * Define symbol with the type. - * - * @param symbol symbol to define - * @param type type - */ - public void define(Symbol symbol, ExprType type) { - symbolTable.store(symbol, type); - } - - /** - * Define expression with the type. - * - * @param ref {@link ReferenceExpression} - */ - public void define(ReferenceExpression ref) { - define(new Symbol(FIELD_NAME, ref.getAttr()), ref.type()); - } - - public void remove(Symbol symbol) { - symbolTable.remove(symbol); - } - - /** - * Remove ref. - */ - public void remove(ReferenceExpression ref) { - remove(new Symbol(FIELD_NAME, ref.getAttr())); - } - - /** - * Clear all fields in the current environment. - */ - public void clearAllFields() { - lookupAllFields(FIELD_NAME).keySet().forEach( - v -> remove(new Symbol(Namespace.FIELD_NAME, v))); - } - - public void addReservedWord(Symbol symbol, ExprType type) { - reservedSymbolTable.store(symbol, type); - } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/WindowExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/WindowExpressionAnalyzer.java index 3abcf9e140..c917de69ca 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/WindowExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/WindowExpressionAnalyzer.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; @@ -35,81 +34,69 @@ @RequiredArgsConstructor public class WindowExpressionAnalyzer extends AbstractNodeVisitor { - /** - * Expression analyzer. - */ - private final ExpressionAnalyzer expressionAnalyzer; - - /** - * Child node to be wrapped by a new window operator. - */ - private final LogicalPlan child; - - /** - * Analyze the given project item and return window operator (with child node inside) - * if the given project item is a window function. - * @param projectItem project item - * @param context analysis context - * @return window operator or original child if not windowed - */ - public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext context) { - LogicalPlan window = projectItem.accept(this, context); - return (window == null) ? child : window; - } - - @Override - public LogicalPlan visitAlias(Alias node, AnalysisContext context) { - if (!(node.getDelegated() instanceof WindowFunction)) { - return null; + /** + * Expression analyzer. + */ + private final ExpressionAnalyzer expressionAnalyzer; + + /** + * Child node to be wrapped by a new window operator. + */ + private final LogicalPlan child; + + /** + * Analyze the given project item and return window operator (with child node inside) + * if the given project item is a window function. + * @param projectItem project item + * @param context analysis context + * @return window operator or original child if not windowed + */ + public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext context) { + LogicalPlan window = projectItem.accept(this, context); + return (window == null) ? child : window; } - WindowFunction unresolved = (WindowFunction) node.getDelegated(); - Expression windowFunction = expressionAnalyzer.analyze(unresolved, context); - List partitionByList = analyzePartitionList(unresolved, context); - List> sortList = analyzeSortList(unresolved, context); - - WindowDefinition windowDefinition = new WindowDefinition(partitionByList, sortList); - NamedExpression namedWindowFunction = - new NamedExpression(node.getName(), windowFunction, node.getAlias()); - List> allSortItems = windowDefinition.getAllSortItems(); - - if (allSortItems.isEmpty()) { - return new LogicalWindow(child, namedWindowFunction, windowDefinition); + @Override + public LogicalPlan visitAlias(Alias node, AnalysisContext context) { + if (!(node.getDelegated() instanceof WindowFunction)) { + return null; + } + + WindowFunction unresolved = (WindowFunction) node.getDelegated(); + Expression windowFunction = expressionAnalyzer.analyze(unresolved, context); + List partitionByList = analyzePartitionList(unresolved, context); + List> sortList = analyzeSortList(unresolved, context); + + WindowDefinition windowDefinition = new WindowDefinition(partitionByList, sortList); + NamedExpression namedWindowFunction = new NamedExpression(node.getName(), windowFunction, node.getAlias()); + List> allSortItems = windowDefinition.getAllSortItems(); + + if (allSortItems.isEmpty()) { + return new LogicalWindow(child, namedWindowFunction, windowDefinition); + } + return new LogicalWindow(new LogicalSort(child, allSortItems), namedWindowFunction, windowDefinition); } - return new LogicalWindow( - new LogicalSort(child, allSortItems), - namedWindowFunction, - windowDefinition); - } - private List analyzePartitionList(WindowFunction node, AnalysisContext context) { - return node.getPartitionByList() - .stream() - .map(expr -> expressionAnalyzer.analyze(expr, context)) - .collect(Collectors.toList()); - } + private List analyzePartitionList(WindowFunction node, AnalysisContext context) { + return node.getPartitionByList().stream().map(expr -> expressionAnalyzer.analyze(expr, context)).collect(Collectors.toList()); + } - private List> analyzeSortList(WindowFunction node, - AnalysisContext context) { - return node.getSortList() - .stream() - .map(pair -> ImmutablePair - .of(analyzeSortOption(pair.getLeft()), - expressionAnalyzer.analyze(pair.getRight(), context))) - .collect(Collectors.toList()); - } + private List> analyzeSortList(WindowFunction node, AnalysisContext context) { + return node.getSortList() + .stream() + .map(pair -> ImmutablePair.of(analyzeSortOption(pair.getLeft()), expressionAnalyzer.analyze(pair.getRight(), context))) + .collect(Collectors.toList()); + } - /** - * Frontend creates sort option from query directly which means sort or null order may be null. - * The final and default value for each is determined here during expression analysis. - */ - private SortOption analyzeSortOption(SortOption option) { - if (option.getNullOrder() == null) { - return (option.getSortOrder() == DESC) ? DEFAULT_DESC : DEFAULT_ASC; + /** + * Frontend creates sort option from query directly which means sort or null order may be null. + * The final and default value for each is determined here during expression analysis. + */ + private SortOption analyzeSortOption(SortOption option) { + if (option.getNullOrder() == null) { + return (option.getSortOrder() == DESC) ? DEFAULT_DESC : DEFAULT_ASC; + } + return new SortOption((option.getSortOrder() == DESC) ? DESC : ASC, option.getNullOrder()); } - return new SortOption( - (option.getSortOrder() == DESC) ? DESC : ASC, - option.getNullOrder()); - } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/symbol/Namespace.java b/core/src/main/java/org/opensearch/sql/analysis/symbol/Namespace.java index b5203033a8..bf005924fc 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/symbol/Namespace.java +++ b/core/src/main/java/org/opensearch/sql/analysis/symbol/Namespace.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis.symbol; /** @@ -11,14 +10,14 @@ */ public enum Namespace { - INDEX_NAME("Index"), - FIELD_NAME("Field"), - FUNCTION_NAME("Function"); + INDEX_NAME("Index"), + FIELD_NAME("Field"), + FUNCTION_NAME("Function"); - private final String name; + private final String name; - Namespace(String name) { - this.name = name; - } + Namespace(String name) { + this.name = name; + } } diff --git a/core/src/main/java/org/opensearch/sql/analysis/symbol/Symbol.java b/core/src/main/java/org/opensearch/sql/analysis/symbol/Symbol.java index 8cc9505710..748c60b6d2 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/symbol/Symbol.java +++ b/core/src/main/java/org/opensearch/sql/analysis/symbol/Symbol.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis.symbol; import lombok.Getter; @@ -17,6 +16,6 @@ @Getter @RequiredArgsConstructor public class Symbol { - private final Namespace namespace; - private final String name; + private final Namespace namespace; + private final String name; } diff --git a/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java b/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java index be7435c288..28296e2ac2 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java +++ b/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis.symbol; import static java.util.Collections.emptyMap; @@ -22,134 +21,117 @@ */ public class SymbolTable { - /** - * Two-dimension hash table to manage symbols with type in different namespace. - */ - private Map> tableByNamespace = - new EnumMap<>(Namespace.class); + /** + * Two-dimension hash table to manage symbols with type in different namespace. + */ + private Map> tableByNamespace = new EnumMap<>(Namespace.class); - /** - * Two-dimension hash table to manage symbols with type in different namespace. - * Comparing with tableByNamespace, orderedTable use the LinkedHashMap to keep the order of - * symbol. - */ - private Map> orderedTable = - new EnumMap<>(Namespace.class); + /** + * Two-dimension hash table to manage symbols with type in different namespace. + * Comparing with tableByNamespace, orderedTable use the LinkedHashMap to keep the order of + * symbol. + */ + private Map> orderedTable = new EnumMap<>(Namespace.class); - /** - * Store symbol with the type. Create new map for namespace for the first time. - * - * @param symbol symbol to define - * @param type symbol type - */ - public void store(Symbol symbol, ExprType type) { - tableByNamespace.computeIfAbsent( - symbol.getNamespace(), - ns -> new TreeMap<>() - ).put(symbol.getName(), type); + /** + * Store symbol with the type. Create new map for namespace for the first time. + * + * @param symbol symbol to define + * @param type symbol type + */ + public void store(Symbol symbol, ExprType type) { + tableByNamespace.computeIfAbsent(symbol.getNamespace(), ns -> new TreeMap<>()).put(symbol.getName(), type); - orderedTable.computeIfAbsent( - symbol.getNamespace(), - ns -> new LinkedHashMap<>() - ).put(symbol.getName(), type); - } + orderedTable.computeIfAbsent(symbol.getNamespace(), ns -> new LinkedHashMap<>()).put(symbol.getName(), type); + } - /** - * Remove a symbol from SymbolTable. - */ - public void remove(Symbol symbol) { - tableByNamespace.computeIfPresent( - symbol.getNamespace(), - (k, v) -> { - v.remove(symbol.getName()); - return v; - } - ); - orderedTable.computeIfPresent( - symbol.getNamespace(), - (k, v) -> { - v.remove(symbol.getName()); - return v; - } - ); - } + /** + * Remove a symbol from SymbolTable. + */ + public void remove(Symbol symbol) { + tableByNamespace.computeIfPresent(symbol.getNamespace(), (k, v) -> { + v.remove(symbol.getName()); + return v; + }); + orderedTable.computeIfPresent(symbol.getNamespace(), (k, v) -> { + v.remove(symbol.getName()); + return v; + }); + } - /** - * Look up symbol in the namespace map. - * - * @param symbol symbol to look up - * @return symbol type which is optional - */ - public Optional lookup(Symbol symbol) { - Map table = tableByNamespace.get(symbol.getNamespace()); - ExprType type = null; - if (table != null) { - type = table.get(symbol.getName()); + /** + * Look up symbol in the namespace map. + * + * @param symbol symbol to look up + * @return symbol type which is optional + */ + public Optional lookup(Symbol symbol) { + Map table = tableByNamespace.get(symbol.getNamespace()); + ExprType type = null; + if (table != null) { + type = table.get(symbol.getName()); + } + return Optional.ofNullable(type); } - return Optional.ofNullable(type); - } - /** - * Look up symbols by a prefix. - * - * @param prefix a symbol prefix - * @return symbols starting with the prefix - */ - public Map lookupByPrefix(Symbol prefix) { - NavigableMap table = tableByNamespace.get(prefix.getNamespace()); - if (table != null) { - return table.subMap(prefix.getName(), prefix.getName() + Character.MAX_VALUE); + /** + * Look up symbols by a prefix. + * + * @param prefix a symbol prefix + * @return symbols starting with the prefix + */ + public Map lookupByPrefix(Symbol prefix) { + NavigableMap table = tableByNamespace.get(prefix.getNamespace()); + if (table != null) { + return table.subMap(prefix.getName(), prefix.getName() + Character.MAX_VALUE); + } + return emptyMap(); } - return emptyMap(); - } - /** - * Look up all top level symbols in the namespace. - * this function is mainly used by SELECT * use case to get the top level fields - * Todo. currently, the top level fields is the field which doesn't include "." in the name or - * the prefix doesn't exist in the symbol table. - * e.g. The symbol table includes person, person.name, person/2.0. - * person, is the top level field - * person.name, isn't the top level field, because the prefix (person) in symbol table - * person/2.0, is the top level field, because the prefix (person/2) isn't in symbol table - * - * @param namespace a namespace - * @return all symbols in the namespace map - */ - public Map lookupAllFields(Namespace namespace) { - final LinkedHashMap allSymbols = - orderedTable.getOrDefault(namespace, new LinkedHashMap<>()); - final LinkedHashMap results = new LinkedHashMap<>(); - allSymbols.entrySet().stream().filter(entry -> { - String symbolName = entry.getKey(); - int lastDot = symbolName.lastIndexOf("."); - return -1 == lastDot || !allSymbols.containsKey(symbolName.substring(0, lastDot)); - }).forEach(entry -> results.put(entry.getKey(), entry.getValue())); - return results; - } + /** + * Look up all top level symbols in the namespace. + * this function is mainly used by SELECT * use case to get the top level fields + * Todo. currently, the top level fields is the field which doesn't include "." in the name or + * the prefix doesn't exist in the symbol table. + * e.g. The symbol table includes person, person.name, person/2.0. + * person, is the top level field + * person.name, isn't the top level field, because the prefix (person) in symbol table + * person/2.0, is the top level field, because the prefix (person/2) isn't in symbol table + * + * @param namespace a namespace + * @return all symbols in the namespace map + */ + public Map lookupAllFields(Namespace namespace) { + final LinkedHashMap allSymbols = orderedTable.getOrDefault(namespace, new LinkedHashMap<>()); + final LinkedHashMap results = new LinkedHashMap<>(); + allSymbols.entrySet().stream().filter(entry -> { + String symbolName = entry.getKey(); + int lastDot = symbolName.lastIndexOf("."); + return -1 == lastDot || !allSymbols.containsKey(symbolName.substring(0, lastDot)); + }).forEach(entry -> results.put(entry.getKey(), entry.getValue())); + return results; + } - /** - * Look up all top level symbols in the namespace. - * - * @param namespace a namespace - * @return all symbols in the namespace map - */ - public Map lookupAllTupleFields(Namespace namespace) { - final LinkedHashMap allSymbols = - orderedTable.getOrDefault(namespace, new LinkedHashMap<>()); - final LinkedHashMap result = new LinkedHashMap<>(); - allSymbols.entrySet().stream() - .forEach(entry -> result.put(entry.getKey(), entry.getValue())); - return result; - } + /** + * Look up all top level symbols in the namespace. + * + * @param namespace a namespace + * @return all symbols in the namespace map + */ + public Map lookupAllTupleFields(Namespace namespace) { + final LinkedHashMap allSymbols = orderedTable.getOrDefault(namespace, new LinkedHashMap<>()); + final LinkedHashMap result = new LinkedHashMap<>(); + allSymbols.entrySet().stream().forEach(entry -> result.put(entry.getKey(), entry.getValue())); + return result; + } - /** - * Check if namespace map in empty (none definition). - * - * @param namespace a namespace - * @return true for empty - */ - public boolean isEmpty(Namespace namespace) { - return tableByNamespace.getOrDefault(namespace, emptyNavigableMap()).isEmpty(); - } + /** + * Check if namespace map in empty (none definition). + * + * @param namespace a namespace + * @return true for empty + */ + public boolean isEmpty(Namespace namespace) { + return tableByNamespace.getOrDefault(namespace, emptyNavigableMap()).isEmpty(); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/AbstractExprNumberValue.java b/core/src/main/java/org/opensearch/sql/data/model/AbstractExprNumberValue.java index 1f6363c068..c420e7299d 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/AbstractExprNumberValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/AbstractExprNumberValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import com.google.common.base.Objects; @@ -14,45 +13,45 @@ */ @RequiredArgsConstructor public abstract class AbstractExprNumberValue extends AbstractExprValue { - private final Number value; - - @Override - public boolean isNumber() { - return true; - } - - @Override - public Byte byteValue() { - return value.byteValue(); - } - - @Override - public Short shortValue() { - return value.shortValue(); - } - - @Override - public Integer integerValue() { - return value.intValue(); - } - - @Override - public Long longValue() { - return value.longValue(); - } - - @Override - public Float floatValue() { - return value.floatValue(); - } - - @Override - public Double doubleValue() { - return value.doubleValue(); - } - - @Override - public int hashCode() { - return Objects.hashCode(value); - } + private final Number value; + + @Override + public boolean isNumber() { + return true; + } + + @Override + public Byte byteValue() { + return value.byteValue(); + } + + @Override + public Short shortValue() { + return value.shortValue(); + } + + @Override + public Integer integerValue() { + return value.intValue(); + } + + @Override + public Long longValue() { + return value.longValue(); + } + + @Override + public Float floatValue() { + return value.floatValue(); + } + + @Override + public Double doubleValue() { + return value.doubleValue(); + } + + @Override + public int hashCode() { + return Objects.hashCode(value); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/AbstractExprValue.java b/core/src/main/java/org/opensearch/sql/data/model/AbstractExprValue.java index ad2c2ddb49..506fe94d7a 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/AbstractExprValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/AbstractExprValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import org.opensearch.sql.exception.ExpressionEvaluationException; @@ -12,60 +11,56 @@ * Abstract ExprValue. */ public abstract class AbstractExprValue implements ExprValue { - /** - * The customize compareTo logic. - */ - @Override - public int compareTo(ExprValue other) { - if (this.isNull() || this.isMissing() || other.isNull() || other.isMissing()) { - throw new IllegalStateException( - "[BUG] Unreachable, Comparing with NULL or MISSING is undefined"); - } - if ((this.isNumber() && other.isNumber()) - || (this.isDateTime() && other.isDateTime()) - || this.type().equals(other.type())) { - return compare(other); - } else { - throw new ExpressionEvaluationException( - String.format( - "compare expected value have same type, but with [%s, %s]", - this.type(), other.type())); + /** + * The customize compareTo logic. + */ + @Override + public int compareTo(ExprValue other) { + if (this.isNull() || this.isMissing() || other.isNull() || other.isMissing()) { + throw new IllegalStateException("[BUG] Unreachable, Comparing with NULL or MISSING is undefined"); + } + if ((this.isNumber() && other.isNumber()) || (this.isDateTime() && other.isDateTime()) || this.type().equals(other.type())) { + return compare(other); + } else { + throw new ExpressionEvaluationException( + String.format("compare expected value have same type, but with [%s, %s]", this.type(), other.type()) + ); + } } - } - /** - * The customize equals logic. - * The table below list the NULL and MISSING handling logic. - * A B A == B - * NULL NULL TRUE - * NULL MISSING FALSE - * MISSING NULL FALSE - * MISSING MISSING TRUE - */ - @Override - public boolean equals(Object o) { - if (o == this) { - return true; - } else if (!(o instanceof ExprValue)) { - return false; - } - ExprValue other = (ExprValue) o; - if (this.isNull() || this.isMissing()) { - return equal(other); - } else if (other.isNull() || other.isMissing()) { - return other.equals(this); - } else { - return equal(other); + /** + * The customize equals logic. + * The table below list the NULL and MISSING handling logic. + * A B A == B + * NULL NULL TRUE + * NULL MISSING FALSE + * MISSING NULL FALSE + * MISSING MISSING TRUE + */ + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } else if (!(o instanceof ExprValue)) { + return false; + } + ExprValue other = (ExprValue) o; + if (this.isNull() || this.isMissing()) { + return equal(other); + } else if (other.isNull() || other.isMissing()) { + return other.equals(this); + } else { + return equal(other); + } } - } - /** - * The expression value compare. - */ - public abstract int compare(ExprValue other); + /** + * The expression value compare. + */ + public abstract int compare(ExprValue other); - /** - * The expression value equal. - */ - public abstract boolean equal(ExprValue other); + /** + * The expression value equal. + */ + public abstract boolean equal(ExprValue other); } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprBooleanValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprBooleanValue.java index d655c0dabb..7f047ebb10 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprBooleanValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprBooleanValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import com.google.common.base.Objects; @@ -14,51 +13,51 @@ * Expression Boolean Value. */ public class ExprBooleanValue extends AbstractExprValue { - private static final ExprBooleanValue TRUE = new ExprBooleanValue(true); - private static final ExprBooleanValue FALSE = new ExprBooleanValue(false); - - private final Boolean value; - - private ExprBooleanValue(Boolean value) { - this.value = value; - } - - public static ExprBooleanValue of(Boolean value) { - return value ? TRUE : FALSE; - } - - @Override - public Object value() { - return value; - } - - @Override - public ExprType type() { - return ExprCoreType.BOOLEAN; - } - - @Override - public Boolean booleanValue() { - return value; - } - - @Override - public String toString() { - return value.toString(); - } - - @Override - public int compare(ExprValue other) { - return Boolean.compare(value, other.booleanValue()); - } - - @Override - public boolean equal(ExprValue other) { - return value.equals(other.booleanValue()); - } - - @Override - public int hashCode() { - return Objects.hashCode(value); - } + private static final ExprBooleanValue TRUE = new ExprBooleanValue(true); + private static final ExprBooleanValue FALSE = new ExprBooleanValue(false); + + private final Boolean value; + + private ExprBooleanValue(Boolean value) { + this.value = value; + } + + public static ExprBooleanValue of(Boolean value) { + return value ? TRUE : FALSE; + } + + @Override + public Object value() { + return value; + } + + @Override + public ExprType type() { + return ExprCoreType.BOOLEAN; + } + + @Override + public Boolean booleanValue() { + return value; + } + + @Override + public String toString() { + return value.toString(); + } + + @Override + public int compare(ExprValue other) { + return Boolean.compare(value, other.booleanValue()); + } + + @Override + public boolean equal(ExprValue other) { + return value.equals(other.booleanValue()); + } + + @Override + public int hashCode() { + return Objects.hashCode(value); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprByteValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprByteValue.java index b39e6e9d7f..d73bbe62bb 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprByteValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprByteValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import org.opensearch.sql.data.type.ExprCoreType; @@ -14,32 +13,32 @@ */ public class ExprByteValue extends AbstractExprNumberValue { - public ExprByteValue(Number value) { - super(value); - } - - @Override - public int compare(ExprValue other) { - return Byte.compare(byteValue(), other.byteValue()); - } - - @Override - public boolean equal(ExprValue other) { - return byteValue().equals(other.byteValue()); - } - - @Override - public Object value() { - return byteValue(); - } - - @Override - public ExprType type() { - return ExprCoreType.BYTE; - } - - @Override - public String toString() { - return value().toString(); - } + public ExprByteValue(Number value) { + super(value); + } + + @Override + public int compare(ExprValue other) { + return Byte.compare(byteValue(), other.byteValue()); + } + + @Override + public boolean equal(ExprValue other) { + return byteValue().equals(other.byteValue()); + } + + @Override + public Object value() { + return byteValue(); + } + + @Override + public ExprType type() { + return ExprCoreType.BYTE; + } + + @Override + public String toString() { + return value().toString(); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprCollectionValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprCollectionValue.java index 1326733263..63609bb746 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprCollectionValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprCollectionValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import com.google.common.base.Objects; @@ -20,64 +19,62 @@ */ @RequiredArgsConstructor public class ExprCollectionValue extends AbstractExprValue { - private final List valueList; + private final List valueList; - @Override - public Object value() { - List results = new ArrayList<>(); - for (ExprValue exprValue : valueList) { - results.add(exprValue.value()); + @Override + public Object value() { + List results = new ArrayList<>(); + for (ExprValue exprValue : valueList) { + results.add(exprValue.value()); + } + return results; } - return results; - } - @Override - public ExprType type() { - return ExprCoreType.ARRAY; - } + @Override + public ExprType type() { + return ExprCoreType.ARRAY; + } - @Override - public List collectionValue() { - return valueList; - } + @Override + public List collectionValue() { + return valueList; + } - @Override - public String toString() { - return valueList.stream() - .map(Object::toString) - .collect(Collectors.joining(", ", "[", "]")); - } + @Override + public String toString() { + return valueList.stream().map(Object::toString).collect(Collectors.joining(", ", "[", "]")); + } - @Override - public boolean equal(ExprValue o) { - if (!(o instanceof ExprCollectionValue)) { - return false; - } else { - ExprCollectionValue other = (ExprCollectionValue) o; - Iterator thisIterator = this.valueList.iterator(); - Iterator otherIterator = other.valueList.iterator(); + @Override + public boolean equal(ExprValue o) { + if (!(o instanceof ExprCollectionValue)) { + return false; + } else { + ExprCollectionValue other = (ExprCollectionValue) o; + Iterator thisIterator = this.valueList.iterator(); + Iterator otherIterator = other.valueList.iterator(); - while (thisIterator.hasNext() && otherIterator.hasNext()) { - ExprValue thisEntry = thisIterator.next(); - ExprValue otherEntry = otherIterator.next(); - if (!thisEntry.equals(otherEntry)) { - return false; + while (thisIterator.hasNext() && otherIterator.hasNext()) { + ExprValue thisEntry = thisIterator.next(); + ExprValue otherEntry = otherIterator.next(); + if (!thisEntry.equals(otherEntry)) { + return false; + } + } + return !(thisIterator.hasNext() || otherIterator.hasNext()); } - } - return !(thisIterator.hasNext() || otherIterator.hasNext()); } - } - /** - * Only compare the size of the list. - */ - @Override - public int compare(ExprValue other) { - return Integer.compare(valueList.size(), other.collectionValue().size()); - } + /** + * Only compare the size of the list. + */ + @Override + public int compare(ExprValue other) { + return Integer.compare(valueList.size(), other.collectionValue().size()); + } - @Override - public int hashCode() { - return Objects.hashCode(valueList); - } + @Override + public int hashCode() { + return Objects.hashCode(valueList); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprDateValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprDateValue.java index 57ce87df47..c586e09384 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprDateValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprDateValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import static org.opensearch.sql.utils.DateTimeFormatters.DATE_TIME_FORMATTER_VARIABLE_NANOS_OPTIONAL; @@ -28,72 +27,71 @@ @RequiredArgsConstructor public class ExprDateValue extends AbstractExprValue { - private final LocalDate date; - - /** - * Constructor of ExprDateValue. - */ - public ExprDateValue(String date) { - try { - this.date = LocalDate.parse(date, DATE_TIME_FORMATTER_VARIABLE_NANOS_OPTIONAL); - } catch (DateTimeParseException e) { - throw new SemanticCheckException(String.format("date:%s in unsupported format, please use " - + "yyyy-MM-dd", date)); + private final LocalDate date; + + /** + * Constructor of ExprDateValue. + */ + public ExprDateValue(String date) { + try { + this.date = LocalDate.parse(date, DATE_TIME_FORMATTER_VARIABLE_NANOS_OPTIONAL); + } catch (DateTimeParseException e) { + throw new SemanticCheckException(String.format("date:%s in unsupported format, please use " + "yyyy-MM-dd", date)); + } + } + + @Override + public String value() { + return DateTimeFormatter.ISO_LOCAL_DATE.format(date); + } + + @Override + public ExprType type() { + return ExprCoreType.DATE; + } + + @Override + public LocalDate dateValue() { + return date; + } + + @Override + public LocalTime timeValue() { + return LocalTime.of(0, 0, 0); + } + + @Override + public LocalDateTime datetimeValue() { + return LocalDateTime.of(date, timeValue()); + } + + @Override + public Instant timestampValue() { + return ZonedDateTime.of(date, timeValue(), UTC_ZONE_ID).toInstant(); + } + + @Override + public boolean isDateTime() { + return true; + } + + @Override + public String toString() { + return String.format("DATE '%s'", value()); + } + + @Override + public int compare(ExprValue other) { + return date.compareTo(other.dateValue()); + } + + @Override + public boolean equal(ExprValue other) { + return date.equals(other.dateValue()); + } + + @Override + public int hashCode() { + return Objects.hashCode(date); } - } - - @Override - public String value() { - return DateTimeFormatter.ISO_LOCAL_DATE.format(date); - } - - @Override - public ExprType type() { - return ExprCoreType.DATE; - } - - @Override - public LocalDate dateValue() { - return date; - } - - @Override - public LocalTime timeValue() { - return LocalTime.of(0, 0, 0); - } - - @Override - public LocalDateTime datetimeValue() { - return LocalDateTime.of(date, timeValue()); - } - - @Override - public Instant timestampValue() { - return ZonedDateTime.of(date, timeValue(), UTC_ZONE_ID).toInstant(); - } - - @Override - public boolean isDateTime() { - return true; - } - - @Override - public String toString() { - return String.format("DATE '%s'", value()); - } - - @Override - public int compare(ExprValue other) { - return date.compareTo(other.dateValue()); - } - - @Override - public boolean equal(ExprValue other) { - return date.equals(other.dateValue()); - } - - @Override - public int hashCode() { - return Objects.hashCode(date); - } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprDatetimeValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprDatetimeValue.java index 8d40aaf82c..7f0f576659 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprDatetimeValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprDatetimeValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import static org.opensearch.sql.utils.DateTimeFormatters.DATE_TIME_FORMATTER_WITH_TZ; @@ -23,77 +22,79 @@ import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.exception.SemanticCheckException; - @RequiredArgsConstructor public class ExprDatetimeValue extends AbstractExprValue { - private final LocalDateTime datetime; - - /** - * Constructor with datetime string as input. - */ - public ExprDatetimeValue(String datetime) { - try { - this.datetime = LocalDateTime.parse(datetime, DATE_TIME_FORMATTER_WITH_TZ); - } catch (DateTimeParseException e) { - throw new SemanticCheckException(String.format("datetime:%s in unsupported format, please " - + "use yyyy-MM-dd HH:mm:ss[.SSSSSSSSS]", datetime)); + private final LocalDateTime datetime; + + /** + * Constructor with datetime string as input. + */ + public ExprDatetimeValue(String datetime) { + try { + this.datetime = LocalDateTime.parse(datetime, DATE_TIME_FORMATTER_WITH_TZ); + } catch (DateTimeParseException e) { + throw new SemanticCheckException( + String.format("datetime:%s in unsupported format, please " + "use yyyy-MM-dd HH:mm:ss[.SSSSSSSSS]", datetime) + ); + } + } + + @Override + public LocalDateTime datetimeValue() { + return datetime; + } + + @Override + public LocalDate dateValue() { + return datetime.toLocalDate(); + } + + @Override + public LocalTime timeValue() { + return datetime.toLocalTime(); + } + + @Override + public Instant timestampValue() { + return ZonedDateTime.of(datetime, UTC_ZONE_ID).toInstant(); + } + + @Override + public boolean isDateTime() { + return true; + } + + @Override + public int compare(ExprValue other) { + return datetime.compareTo(other.datetimeValue()); + } + + @Override + public boolean equal(ExprValue other) { + return datetime.equals(other.datetimeValue()); + } + + @Override + public String value() { + return String.format( + "%s %s", + DateTimeFormatter.ISO_DATE.format(datetime), + DateTimeFormatter.ISO_TIME.format((datetime.getNano() == 0) ? datetime.truncatedTo(ChronoUnit.SECONDS) : datetime) + ); + } + + @Override + public ExprType type() { + return ExprCoreType.DATETIME; + } + + @Override + public String toString() { + return String.format("DATETIME '%s'", value()); + } + + @Override + public int hashCode() { + return Objects.hashCode(datetime); } - } - - @Override - public LocalDateTime datetimeValue() { - return datetime; - } - - @Override - public LocalDate dateValue() { - return datetime.toLocalDate(); - } - - @Override - public LocalTime timeValue() { - return datetime.toLocalTime(); - } - - @Override - public Instant timestampValue() { - return ZonedDateTime.of(datetime, UTC_ZONE_ID).toInstant(); - } - - @Override - public boolean isDateTime() { - return true; - } - - @Override - public int compare(ExprValue other) { - return datetime.compareTo(other.datetimeValue()); - } - - @Override - public boolean equal(ExprValue other) { - return datetime.equals(other.datetimeValue()); - } - - @Override - public String value() { - return String.format("%s %s", DateTimeFormatter.ISO_DATE.format(datetime), - DateTimeFormatter.ISO_TIME.format((datetime.getNano() == 0) - ? datetime.truncatedTo(ChronoUnit.SECONDS) : datetime)); - } - - @Override - public ExprType type() { - return ExprCoreType.DATETIME; - } - - @Override - public String toString() { - return String.format("DATETIME '%s'", value()); - } - - @Override - public int hashCode() { - return Objects.hashCode(datetime); - } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprDoubleValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprDoubleValue.java index 171b064e68..1ef736187c 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprDoubleValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprDoubleValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import org.opensearch.sql.data.type.ExprCoreType; @@ -14,32 +13,32 @@ */ public class ExprDoubleValue extends AbstractExprNumberValue { - public ExprDoubleValue(Number value) { - super(value); - } - - @Override - public Object value() { - return doubleValue(); - } - - @Override - public ExprType type() { - return ExprCoreType.DOUBLE; - } - - @Override - public String toString() { - return doubleValue().toString(); - } - - @Override - public int compare(ExprValue other) { - return Double.compare(doubleValue(), other.doubleValue()); - } - - @Override - public boolean equal(ExprValue other) { - return doubleValue().equals(other.doubleValue()); - } + public ExprDoubleValue(Number value) { + super(value); + } + + @Override + public Object value() { + return doubleValue(); + } + + @Override + public ExprType type() { + return ExprCoreType.DOUBLE; + } + + @Override + public String toString() { + return doubleValue().toString(); + } + + @Override + public int compare(ExprValue other) { + return Double.compare(doubleValue(), other.doubleValue()); + } + + @Override + public boolean equal(ExprValue other) { + return doubleValue().equals(other.doubleValue()); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprFloatValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprFloatValue.java index dc454b4b50..2238516f5e 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprFloatValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprFloatValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import org.opensearch.sql.data.type.ExprCoreType; @@ -14,32 +13,32 @@ */ public class ExprFloatValue extends AbstractExprNumberValue { - public ExprFloatValue(Number value) { - super(value); - } - - @Override - public Object value() { - return floatValue(); - } - - @Override - public ExprType type() { - return ExprCoreType.FLOAT; - } - - @Override - public String toString() { - return floatValue().toString(); - } - - @Override - public int compare(ExprValue other) { - return Float.compare(floatValue(), other.floatValue()); - } - - @Override - public boolean equal(ExprValue other) { - return floatValue().equals(other.floatValue()); - } + public ExprFloatValue(Number value) { + super(value); + } + + @Override + public Object value() { + return floatValue(); + } + + @Override + public ExprType type() { + return ExprCoreType.FLOAT; + } + + @Override + public String toString() { + return floatValue().toString(); + } + + @Override + public int compare(ExprValue other) { + return Float.compare(floatValue(), other.floatValue()); + } + + @Override + public boolean equal(ExprValue other) { + return floatValue().equals(other.floatValue()); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprIntegerValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprIntegerValue.java index 06947766fc..5d45cfc303 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprIntegerValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprIntegerValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import org.opensearch.sql.data.type.ExprCoreType; @@ -14,32 +13,32 @@ */ public class ExprIntegerValue extends AbstractExprNumberValue { - public ExprIntegerValue(Number value) { - super(value); - } - - @Override - public Object value() { - return integerValue(); - } - - @Override - public ExprType type() { - return ExprCoreType.INTEGER; - } - - @Override - public String toString() { - return integerValue().toString(); - } - - @Override - public int compare(ExprValue other) { - return Integer.compare(integerValue(), other.integerValue()); - } - - @Override - public boolean equal(ExprValue other) { - return integerValue().equals(other.integerValue()); - } + public ExprIntegerValue(Number value) { + super(value); + } + + @Override + public Object value() { + return integerValue(); + } + + @Override + public ExprType type() { + return ExprCoreType.INTEGER; + } + + @Override + public String toString() { + return integerValue().toString(); + } + + @Override + public int compare(ExprValue other) { + return Integer.compare(integerValue(), other.integerValue()); + } + + @Override + public boolean equal(ExprValue other) { + return integerValue().equals(other.integerValue()); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprIntervalValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprIntervalValue.java index 25a3115e8c..26efaa7a63 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprIntervalValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprIntervalValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import java.time.temporal.TemporalAmount; @@ -15,48 +14,43 @@ @RequiredArgsConstructor public class ExprIntervalValue extends AbstractExprValue { - private final TemporalAmount interval; - - @Override - public TemporalAmount intervalValue() { - return interval; - } - - @Override - public int compare(ExprValue other) { - TemporalAmount otherInterval = other.intervalValue(); - if (!interval.getClass().equals(other.intervalValue().getClass())) { - throw new ExpressionEvaluationException( - String.format("invalid to compare intervals with units %s and %s", - unit(), ((ExprIntervalValue) other).unit())); + private final TemporalAmount interval; + + @Override + public TemporalAmount intervalValue() { + return interval; + } + + @Override + public int compare(ExprValue other) { + TemporalAmount otherInterval = other.intervalValue(); + if (!interval.getClass().equals(other.intervalValue().getClass())) { + throw new ExpressionEvaluationException( + String.format("invalid to compare intervals with units %s and %s", unit(), ((ExprIntervalValue) other).unit()) + ); + } + return Long.compare(interval.get(unit()), otherInterval.get(((ExprIntervalValue) other).unit())); + } + + @Override + public boolean equal(ExprValue other) { + return interval.equals(other.intervalValue()); + } + + @Override + public TemporalAmount value() { + return interval; + } + + @Override + public ExprType type() { + return ExprCoreType.INTERVAL; + } + + /** + * Util method to get temporal unit stored locally. + */ + public TemporalUnit unit() { + return interval.getUnits().stream().filter(v -> interval.get(v) != 0).findAny().orElse(interval.getUnits().get(0)); } - return Long.compare( - interval.get(unit()), otherInterval.get(((ExprIntervalValue) other).unit())); - } - - @Override - public boolean equal(ExprValue other) { - return interval.equals(other.intervalValue()); - } - - @Override - public TemporalAmount value() { - return interval; - } - - @Override - public ExprType type() { - return ExprCoreType.INTERVAL; - } - - /** - * Util method to get temporal unit stored locally. - */ - public TemporalUnit unit() { - return interval.getUnits() - .stream() - .filter(v -> interval.get(v) != 0) - .findAny() - .orElse(interval.getUnits().get(0)); - } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprLongValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprLongValue.java index 1df590246c..b63e6fb6a5 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprLongValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprLongValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import org.opensearch.sql.data.type.ExprCoreType; @@ -14,32 +13,32 @@ */ public class ExprLongValue extends AbstractExprNumberValue { - public ExprLongValue(Number value) { - super(value); - } - - @Override - public Object value() { - return longValue(); - } - - @Override - public ExprType type() { - return ExprCoreType.LONG; - } - - @Override - public String toString() { - return longValue().toString(); - } - - @Override - public int compare(ExprValue other) { - return Long.compare(longValue(), other.longValue()); - } - - @Override - public boolean equal(ExprValue other) { - return longValue().equals(other.longValue()); - } + public ExprLongValue(Number value) { + super(value); + } + + @Override + public Object value() { + return longValue(); + } + + @Override + public ExprType type() { + return ExprCoreType.LONG; + } + + @Override + public String toString() { + return longValue().toString(); + } + + @Override + public int compare(ExprValue other) { + return Long.compare(longValue(), other.longValue()); + } + + @Override + public boolean equal(ExprValue other) { + return longValue().equals(other.longValue()); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprMissingValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprMissingValue.java index 9908074773..43246fb74c 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprMissingValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprMissingValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import java.util.Objects; @@ -14,52 +13,50 @@ * Expression Missing Value. */ public class ExprMissingValue extends AbstractExprValue { - private static final ExprMissingValue instance = new ExprMissingValue(); - - private ExprMissingValue() { - } - - public static ExprMissingValue of() { - return instance; - } - - @Override - public Object value() { - return null; - } - - @Override - public ExprType type() { - return ExprCoreType.UNDEFINED; - } - - @Override - public boolean isMissing() { - return true; - } - - @Override - public int compare(ExprValue other) { - throw new IllegalStateException(String.format("[BUG] Unreachable, Comparing with MISSING is " - + "undefined")); - } - - /** - * Missing value is equal to Missing value. - * Notes, this function should only used for Java Object Compare. - */ - @Override - public boolean equal(ExprValue other) { - return other.isMissing(); - } - - @Override - public int hashCode() { - return Objects.hashCode("MISSING"); - } - - @Override - public String toString() { - return "MISSING"; - } + private static final ExprMissingValue instance = new ExprMissingValue(); + + private ExprMissingValue() {} + + public static ExprMissingValue of() { + return instance; + } + + @Override + public Object value() { + return null; + } + + @Override + public ExprType type() { + return ExprCoreType.UNDEFINED; + } + + @Override + public boolean isMissing() { + return true; + } + + @Override + public int compare(ExprValue other) { + throw new IllegalStateException(String.format("[BUG] Unreachable, Comparing with MISSING is " + "undefined")); + } + + /** + * Missing value is equal to Missing value. + * Notes, this function should only used for Java Object Compare. + */ + @Override + public boolean equal(ExprValue other) { + return other.isMissing(); + } + + @Override + public int hashCode() { + return Objects.hashCode("MISSING"); + } + + @Override + public String toString() { + return "MISSING"; + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprNullValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprNullValue.java index 54d4811d33..265450de0b 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprNullValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprNullValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import java.util.Objects; @@ -14,52 +13,50 @@ * Expression Null Value. */ public class ExprNullValue extends AbstractExprValue { - private static final ExprNullValue instance = new ExprNullValue(); - - private ExprNullValue() { - } - - @Override - public int hashCode() { - return Objects.hashCode("NULL"); - } - - @Override - public String toString() { - return "NULL"; - } - - public static ExprNullValue of() { - return instance; - } - - @Override - public Object value() { - return null; - } - - @Override - public ExprType type() { - return ExprCoreType.UNDEFINED; - } - - @Override - public boolean isNull() { - return true; - } - - @Override - public int compare(ExprValue other) { - throw new IllegalStateException( - String.format("[BUG] Unreachable, Comparing with NULL is undefined")); - } - - /** - * NULL value is equal to NULL value. - * Notes, this function should only used for Java Object Compare. - */ - @Override - public boolean equal(ExprValue other) { - return other.isNull(); - } + private static final ExprNullValue instance = new ExprNullValue(); + + private ExprNullValue() {} + + @Override + public int hashCode() { + return Objects.hashCode("NULL"); + } + + @Override + public String toString() { + return "NULL"; + } + + public static ExprNullValue of() { + return instance; + } + + @Override + public Object value() { + return null; + } + + @Override + public ExprType type() { + return ExprCoreType.UNDEFINED; + } + + @Override + public boolean isNull() { + return true; + } + + @Override + public int compare(ExprValue other) { + throw new IllegalStateException(String.format("[BUG] Unreachable, Comparing with NULL is undefined")); + } + + /** + * NULL value is equal to NULL value. + * Notes, this function should only used for Java Object Compare. + */ + @Override + public boolean equal(ExprValue other) { + return other.isNull(); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprShortValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprShortValue.java index 3e5f6858bc..a38c65c075 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprShortValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprShortValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import org.opensearch.sql.data.type.ExprCoreType; @@ -14,32 +13,32 @@ */ public class ExprShortValue extends AbstractExprNumberValue { - public ExprShortValue(Number value) { - super(value); - } - - @Override - public Object value() { - return shortValue(); - } - - @Override - public ExprType type() { - return ExprCoreType.SHORT; - } - - @Override - public String toString() { - return shortValue().toString(); - } - - @Override - public int compare(ExprValue other) { - return Short.compare(shortValue(), other.shortValue()); - } - - @Override - public boolean equal(ExprValue other) { - return shortValue().equals(other.shortValue()); - } + public ExprShortValue(Number value) { + super(value); + } + + @Override + public Object value() { + return shortValue(); + } + + @Override + public ExprType type() { + return ExprCoreType.SHORT; + } + + @Override + public String toString() { + return shortValue().toString(); + } + + @Override + public int compare(ExprValue other) { + return Short.compare(shortValue(), other.shortValue()); + } + + @Override + public boolean equal(ExprValue other) { + return shortValue().equals(other.shortValue()); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprStringValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprStringValue.java index c41c23d6ac..f1c2e068e6 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprStringValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprStringValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import java.time.LocalDate; @@ -20,74 +19,73 @@ */ @RequiredArgsConstructor public class ExprStringValue extends AbstractExprValue { - private final String value; + private final String value; - @Override - public Object value() { - return value; - } + @Override + public Object value() { + return value; + } - @Override - public ExprType type() { - return ExprCoreType.STRING; - } + @Override + public ExprType type() { + return ExprCoreType.STRING; + } - @Override - public String stringValue() { - return value; - } + @Override + public String stringValue() { + return value; + } - @Override - public LocalDateTime datetimeValue() { - try { - return new ExprDatetimeValue(value).datetimeValue(); - } catch (SemanticCheckException e) { - try { - return new ExprDatetimeValue( - LocalDateTime.of(new ExprDateValue(value).dateValue(), LocalTime.of(0, 0, 0))) - .datetimeValue(); - } catch (SemanticCheckException exception) { - throw new SemanticCheckException(String.format("datetime:%s in unsupported format, please " - + "use yyyy-MM-dd HH:mm:ss[.SSSSSSSSS]", value)); - } + @Override + public LocalDateTime datetimeValue() { + try { + return new ExprDatetimeValue(value).datetimeValue(); + } catch (SemanticCheckException e) { + try { + return new ExprDatetimeValue(LocalDateTime.of(new ExprDateValue(value).dateValue(), LocalTime.of(0, 0, 0))).datetimeValue(); + } catch (SemanticCheckException exception) { + throw new SemanticCheckException( + String.format("datetime:%s in unsupported format, please " + "use yyyy-MM-dd HH:mm:ss[.SSSSSSSSS]", value) + ); + } + } } - } - @Override - public LocalDate dateValue() { - try { - return new ExprDatetimeValue(value).dateValue(); - } catch (SemanticCheckException e) { - return new ExprDateValue(value).dateValue(); + @Override + public LocalDate dateValue() { + try { + return new ExprDatetimeValue(value).dateValue(); + } catch (SemanticCheckException e) { + return new ExprDateValue(value).dateValue(); + } } - } - @Override - public LocalTime timeValue() { - try { - return new ExprDatetimeValue(value).timeValue(); - } catch (SemanticCheckException e) { - return new ExprTimeValue(value).timeValue(); + @Override + public LocalTime timeValue() { + try { + return new ExprDatetimeValue(value).timeValue(); + } catch (SemanticCheckException e) { + return new ExprTimeValue(value).timeValue(); + } } - } - @Override - public String toString() { - return String.format("\"%s\"", value); - } + @Override + public String toString() { + return String.format("\"%s\"", value); + } - @Override - public int compare(ExprValue other) { - return value.compareTo(other.stringValue()); - } + @Override + public int compare(ExprValue other) { + return value.compareTo(other.stringValue()); + } - @Override - public boolean equal(ExprValue other) { - return value.equals(other.stringValue()); - } + @Override + public boolean equal(ExprValue other) { + return value.equals(other.stringValue()); + } - @Override - public int hashCode() { - return Objects.hashCode(value); - } + @Override + public int hashCode() { + return Objects.hashCode(value); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprTimeValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprTimeValue.java index db5bf7cb52..6a6bac901f 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprTimeValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprTimeValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import static java.time.format.DateTimeFormatter.ISO_LOCAL_TIME; @@ -29,70 +28,68 @@ @RequiredArgsConstructor public class ExprTimeValue extends AbstractExprValue { - private final LocalTime time; - - /** - * Constructor of ExprTimeValue. - */ - public ExprTimeValue(String time) { - try { - this.time = LocalTime.parse(time, DATE_TIME_FORMATTER_VARIABLE_NANOS_OPTIONAL); - } catch (DateTimeParseException e) { - throw new SemanticCheckException(String.format("time:%s in unsupported format, please use " - + "HH:mm:ss[.SSSSSSSSS]", time)); + private final LocalTime time; + + /** + * Constructor of ExprTimeValue. + */ + public ExprTimeValue(String time) { + try { + this.time = LocalTime.parse(time, DATE_TIME_FORMATTER_VARIABLE_NANOS_OPTIONAL); + } catch (DateTimeParseException e) { + throw new SemanticCheckException(String.format("time:%s in unsupported format, please use " + "HH:mm:ss[.SSSSSSSSS]", time)); + } + } + + @Override + public String value() { + return ISO_LOCAL_TIME.format(time); + } + + @Override + public ExprType type() { + return ExprCoreType.TIME; + } + + @Override + public LocalTime timeValue() { + return time; + } + + public LocalDate dateValue(FunctionProperties functionProperties) { + return LocalDate.now(functionProperties.getQueryStartClock()); + } + + public LocalDateTime datetimeValue(FunctionProperties functionProperties) { + return LocalDateTime.of(dateValue(functionProperties), timeValue()); + } + + public Instant timestampValue(FunctionProperties functionProperties) { + return ZonedDateTime.of(dateValue(functionProperties), timeValue(), UTC_ZONE_ID).toInstant(); + } + + @Override + public boolean isDateTime() { + return true; + } + + @Override + public String toString() { + return String.format("TIME '%s'", value()); + } + + @Override + public int compare(ExprValue other) { + return time.compareTo(other.timeValue()); + } + + @Override + public boolean equal(ExprValue other) { + return time.equals(other.timeValue()); + } + + @Override + public int hashCode() { + return Objects.hashCode(time); } - } - - @Override - public String value() { - return ISO_LOCAL_TIME.format(time); - } - - @Override - public ExprType type() { - return ExprCoreType.TIME; - } - - @Override - public LocalTime timeValue() { - return time; - } - - public LocalDate dateValue(FunctionProperties functionProperties) { - return LocalDate.now(functionProperties.getQueryStartClock()); - } - - public LocalDateTime datetimeValue(FunctionProperties functionProperties) { - return LocalDateTime.of(dateValue(functionProperties), timeValue()); - } - - public Instant timestampValue(FunctionProperties functionProperties) { - return ZonedDateTime.of(dateValue(functionProperties), timeValue(), UTC_ZONE_ID) - .toInstant(); - } - - @Override - public boolean isDateTime() { - return true; - } - - @Override - public String toString() { - return String.format("TIME '%s'", value()); - } - - @Override - public int compare(ExprValue other) { - return time.compareTo(other.timeValue()); - } - - @Override - public boolean equal(ExprValue other) { - return time.equals(other.timeValue()); - } - - @Override - public int hashCode() { - return Objects.hashCode(time); - } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprTimestampValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprTimestampValue.java index d15cee5e71..fe5f6f645a 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprTimestampValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprTimestampValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import static org.opensearch.sql.utils.DateTimeFormatters.DATE_TIME_FORMATTER_VARIABLE_NANOS; @@ -28,77 +27,76 @@ @RequiredArgsConstructor public class ExprTimestampValue extends AbstractExprValue { - private final Instant timestamp; - - /** - * Constructor. - */ - public ExprTimestampValue(String timestamp) { - try { - this.timestamp = LocalDateTime.parse(timestamp, DATE_TIME_FORMATTER_VARIABLE_NANOS) - .atZone(UTC_ZONE_ID) - .toInstant(); - } catch (DateTimeParseException e) { - throw new SemanticCheckException(String.format("timestamp:%s in unsupported format, please " - + "use yyyy-MM-dd HH:mm:ss[.SSSSSSSSS]", timestamp)); + private final Instant timestamp; + + /** + * Constructor. + */ + public ExprTimestampValue(String timestamp) { + try { + this.timestamp = LocalDateTime.parse(timestamp, DATE_TIME_FORMATTER_VARIABLE_NANOS).atZone(UTC_ZONE_ID).toInstant(); + } catch (DateTimeParseException e) { + throw new SemanticCheckException( + String.format("timestamp:%s in unsupported format, please " + "use yyyy-MM-dd HH:mm:ss[.SSSSSSSSS]", timestamp) + ); + } + + } + + @Override + public String value() { + return timestamp.getNano() == 0 + ? DATE_TIME_FORMATTER_WITHOUT_NANO.withZone(UTC_ZONE_ID).format(timestamp.truncatedTo(ChronoUnit.SECONDS)) + : DATE_TIME_FORMATTER_VARIABLE_NANOS.withZone(UTC_ZONE_ID).format(timestamp); + } + + @Override + public ExprType type() { + return ExprCoreType.TIMESTAMP; + } + + @Override + public Instant timestampValue() { + return timestamp; + } + + @Override + public LocalDate dateValue() { + return timestamp.atZone(UTC_ZONE_ID).toLocalDate(); + } + + @Override + public LocalTime timeValue() { + return timestamp.atZone(UTC_ZONE_ID).toLocalTime(); + } + + @Override + public LocalDateTime datetimeValue() { + return timestamp.atZone(UTC_ZONE_ID).toLocalDateTime(); + } + + @Override + public boolean isDateTime() { + return true; + } + + @Override + public String toString() { + return String.format("TIMESTAMP '%s'", value()); + } + + @Override + public int compare(ExprValue other) { + return timestamp.compareTo(other.timestampValue().atZone(UTC_ZONE_ID).toInstant()); } - } - - @Override - public String value() { - return timestamp.getNano() == 0 ? DATE_TIME_FORMATTER_WITHOUT_NANO.withZone(UTC_ZONE_ID) - .format(timestamp.truncatedTo(ChronoUnit.SECONDS)) - : DATE_TIME_FORMATTER_VARIABLE_NANOS.withZone(UTC_ZONE_ID).format(timestamp); - } - - @Override - public ExprType type() { - return ExprCoreType.TIMESTAMP; - } - - @Override - public Instant timestampValue() { - return timestamp; - } - - @Override - public LocalDate dateValue() { - return timestamp.atZone(UTC_ZONE_ID).toLocalDate(); - } - - @Override - public LocalTime timeValue() { - return timestamp.atZone(UTC_ZONE_ID).toLocalTime(); - } - - @Override - public LocalDateTime datetimeValue() { - return timestamp.atZone(UTC_ZONE_ID).toLocalDateTime(); - } - - @Override - public boolean isDateTime() { - return true; - } - - @Override - public String toString() { - return String.format("TIMESTAMP '%s'", value()); - } - - @Override - public int compare(ExprValue other) { - return timestamp.compareTo(other.timestampValue().atZone(UTC_ZONE_ID).toInstant()); - } - - @Override - public boolean equal(ExprValue other) { - return timestamp.equals(other.timestampValue().atZone(UTC_ZONE_ID).toInstant()); - } - - @Override - public int hashCode() { - return Objects.hashCode(timestamp); - } + @Override + public boolean equal(ExprValue other) { + return timestamp.equals(other.timestampValue().atZone(UTC_ZONE_ID).toInstant()); + } + + @Override + public int hashCode() { + return Objects.hashCode(timestamp); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprTupleValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprTupleValue.java index 749de931ee..735c5447ce 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprTupleValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprTupleValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import java.util.Iterator; @@ -24,83 +23,82 @@ @RequiredArgsConstructor public class ExprTupleValue extends AbstractExprValue { - private final LinkedHashMap valueMap; + private final LinkedHashMap valueMap; - public static ExprTupleValue fromExprValueMap(Map map) { - LinkedHashMap linkedHashMap = new LinkedHashMap<>(map); - return new ExprTupleValue(linkedHashMap); - } + public static ExprTupleValue fromExprValueMap(Map map) { + LinkedHashMap linkedHashMap = new LinkedHashMap<>(map); + return new ExprTupleValue(linkedHashMap); + } - @Override - public Object value() { - LinkedHashMap resultMap = new LinkedHashMap<>(); - for (Entry entry : valueMap.entrySet()) { - resultMap.put(entry.getKey(), entry.getValue().value()); + @Override + public Object value() { + LinkedHashMap resultMap = new LinkedHashMap<>(); + for (Entry entry : valueMap.entrySet()) { + resultMap.put(entry.getKey(), entry.getValue().value()); + } + return resultMap; } - return resultMap; - } - @Override - public ExprType type() { - return ExprCoreType.STRUCT; - } + @Override + public ExprType type() { + return ExprCoreType.STRUCT; + } - @Override - public String toString() { - return valueMap.entrySet() - .stream() - .map(entry -> String.format("%s:%s", entry.getKey(), entry.getValue())) - .collect(Collectors.joining(",", "{", "}")); - } + @Override + public String toString() { + return valueMap.entrySet() + .stream() + .map(entry -> String.format("%s:%s", entry.getKey(), entry.getValue())) + .collect(Collectors.joining(",", "{", "}")); + } - @Override - public BindingTuple bindingTuples() { - return new LazyBindingTuple(() -> this); - } + @Override + public BindingTuple bindingTuples() { + return new LazyBindingTuple(() -> this); + } - @Override - public Map tupleValue() { - return valueMap; - } + @Override + public Map tupleValue() { + return valueMap; + } - @Override - public ExprValue keyValue(String key) { - return valueMap.getOrDefault(key, ExprMissingValue.of()); - } + @Override + public ExprValue keyValue(String key) { + return valueMap.getOrDefault(key, ExprMissingValue.of()); + } - /** - * Override the equals method. - * @return true for equal, otherwise false. - */ - public boolean equal(ExprValue o) { - if (!(o instanceof ExprTupleValue)) { - return false; - } else { - ExprTupleValue other = (ExprTupleValue) o; - Iterator> thisIterator = this.valueMap.entrySet().iterator(); - Iterator> otherIterator = other.valueMap.entrySet().iterator(); - while (thisIterator.hasNext() && otherIterator.hasNext()) { - Entry thisEntry = thisIterator.next(); - Entry otherEntry = otherIterator.next(); - if (!(thisEntry.getKey().equals(otherEntry.getKey()) - && thisEntry.getValue().equals(otherEntry.getValue()))) { - return false; + /** + * Override the equals method. + * @return true for equal, otherwise false. + */ + public boolean equal(ExprValue o) { + if (!(o instanceof ExprTupleValue)) { + return false; + } else { + ExprTupleValue other = (ExprTupleValue) o; + Iterator> thisIterator = this.valueMap.entrySet().iterator(); + Iterator> otherIterator = other.valueMap.entrySet().iterator(); + while (thisIterator.hasNext() && otherIterator.hasNext()) { + Entry thisEntry = thisIterator.next(); + Entry otherEntry = otherIterator.next(); + if (!(thisEntry.getKey().equals(otherEntry.getKey()) && thisEntry.getValue().equals(otherEntry.getValue()))) { + return false; + } + } + return !(thisIterator.hasNext() || otherIterator.hasNext()); } - } - return !(thisIterator.hasNext() || otherIterator.hasNext()); } - } - /** - * Only compare the size of the map. - */ - @Override - public int compare(ExprValue other) { - return Integer.compare(valueMap.size(), other.tupleValue().size()); - } + /** + * Only compare the size of the map. + */ + @Override + public int compare(ExprValue other) { + return Integer.compare(valueMap.size(), other.tupleValue().size()); + } - @Override - public int hashCode() { - return Objects.hashCode(valueMap); - } + @Override + public int hashCode() { + return Objects.hashCode(valueMap); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprValue.java index 1ae03de37b..e53d4e6261 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprValue.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import java.io.Serializable; @@ -23,184 +22,169 @@ * The definition of the Expression Value. */ public interface ExprValue extends Serializable, Comparable { - /** - * Get the Object value of the Expression Value. - */ - Object value(); - - /** - * Get the {@link ExprCoreType} of the Expression Value. - */ - ExprType type(); - - /** - * Is null value. - * - * @return true: is null value, otherwise false - */ - default boolean isNull() { - return false; - } - - /** - * Is missing value. - * - * @return true: is missing value, otherwise false - */ - default boolean isMissing() { - return false; - } - - /** - * Is Number value. - * - * @return true: is number value, otherwise false - */ - default boolean isNumber() { - return false; - } - - /** - * Is Datetime value. - * - * @return true: is a datetime value, otherwise false - */ - default boolean isDateTime() { - return false; - } - - /** - * Get the {@link BindingTuple}. - */ - default BindingTuple bindingTuples() { - return BindingTuple.EMPTY; - } - - /** - * Get byte value. - */ - default Byte byteValue() { - throw new ExpressionEvaluationException( - "invalid to get byteValue from value of type " + type()); - } - - /** - * Get short value. - */ - default Short shortValue() { - throw new ExpressionEvaluationException( - "invalid to get shortValue from value of type " + type()); - } - - /** - * Get integer value. - */ - default Integer integerValue() { - throw new ExpressionEvaluationException( - "invalid to get integerValue from value of type " + type()); - } - - /** - * Get long value. - */ - default Long longValue() { - throw new ExpressionEvaluationException( - "invalid to get longValue from value of type " + type()); - } - - /** - * Get float value. - */ - default Float floatValue() { - throw new ExpressionEvaluationException( - "invalid to get floatValue from value of type " + type()); - } - - /** - * Get float value. - */ - default Double doubleValue() { - throw new ExpressionEvaluationException( - "invalid to get doubleValue from value of type " + type()); - } - - /** - * Get string value. - */ - default String stringValue() { - throw new ExpressionEvaluationException( - "invalid to get stringValue from value of type " + type()); - } - - /** - * Get boolean value. - */ - default Boolean booleanValue() { - throw new ExpressionEvaluationException( - "invalid to get booleanValue from value of type " + type()); - } - - /** - * Get timestamp value. - */ - default Instant timestampValue() { - throw new ExpressionEvaluationException( - "invalid to get timestampValue from value of type " + type()); - } - - /** - * Get time value. - */ - default LocalTime timeValue() { - throw new ExpressionEvaluationException( - "invalid to get timeValue from value of type " + type()); - } - - /** - * Get date value. - */ - default LocalDate dateValue() { - throw new ExpressionEvaluationException( - "invalid to get dateValue from value of type " + type()); - } - - /** - * Get datetime value. - */ - default LocalDateTime datetimeValue() { - throw new ExpressionEvaluationException( - "invalid to get datetimeValue from value of type " + type()); - } - - /** - * Get interval value. - */ - default TemporalAmount intervalValue() { - throw new ExpressionEvaluationException( - "invalid to get intervalValue from value of type " + type()); - } - - /** - * Get map value. - */ - default Map tupleValue() { - throw new ExpressionEvaluationException( - "invalid to get tupleValue from value of type " + type()); - } - - /** - * Get collection value. - */ - default List collectionValue() { - throw new ExpressionEvaluationException( - "invalid to get collectionValue from value of type " + type()); - } - - /** - * Get the value specified by key from {@link ExprTupleValue}. - * This method only be implemented in {@link ExprTupleValue}. - */ - default ExprValue keyValue(String key) { - return ExprMissingValue.of(); - } + /** + * Get the Object value of the Expression Value. + */ + Object value(); + + /** + * Get the {@link ExprCoreType} of the Expression Value. + */ + ExprType type(); + + /** + * Is null value. + * + * @return true: is null value, otherwise false + */ + default boolean isNull() { + return false; + } + + /** + * Is missing value. + * + * @return true: is missing value, otherwise false + */ + default boolean isMissing() { + return false; + } + + /** + * Is Number value. + * + * @return true: is number value, otherwise false + */ + default boolean isNumber() { + return false; + } + + /** + * Is Datetime value. + * + * @return true: is a datetime value, otherwise false + */ + default boolean isDateTime() { + return false; + } + + /** + * Get the {@link BindingTuple}. + */ + default BindingTuple bindingTuples() { + return BindingTuple.EMPTY; + } + + /** + * Get byte value. + */ + default Byte byteValue() { + throw new ExpressionEvaluationException("invalid to get byteValue from value of type " + type()); + } + + /** + * Get short value. + */ + default Short shortValue() { + throw new ExpressionEvaluationException("invalid to get shortValue from value of type " + type()); + } + + /** + * Get integer value. + */ + default Integer integerValue() { + throw new ExpressionEvaluationException("invalid to get integerValue from value of type " + type()); + } + + /** + * Get long value. + */ + default Long longValue() { + throw new ExpressionEvaluationException("invalid to get longValue from value of type " + type()); + } + + /** + * Get float value. + */ + default Float floatValue() { + throw new ExpressionEvaluationException("invalid to get floatValue from value of type " + type()); + } + + /** + * Get float value. + */ + default Double doubleValue() { + throw new ExpressionEvaluationException("invalid to get doubleValue from value of type " + type()); + } + + /** + * Get string value. + */ + default String stringValue() { + throw new ExpressionEvaluationException("invalid to get stringValue from value of type " + type()); + } + + /** + * Get boolean value. + */ + default Boolean booleanValue() { + throw new ExpressionEvaluationException("invalid to get booleanValue from value of type " + type()); + } + + /** + * Get timestamp value. + */ + default Instant timestampValue() { + throw new ExpressionEvaluationException("invalid to get timestampValue from value of type " + type()); + } + + /** + * Get time value. + */ + default LocalTime timeValue() { + throw new ExpressionEvaluationException("invalid to get timeValue from value of type " + type()); + } + + /** + * Get date value. + */ + default LocalDate dateValue() { + throw new ExpressionEvaluationException("invalid to get dateValue from value of type " + type()); + } + + /** + * Get datetime value. + */ + default LocalDateTime datetimeValue() { + throw new ExpressionEvaluationException("invalid to get datetimeValue from value of type " + type()); + } + + /** + * Get interval value. + */ + default TemporalAmount intervalValue() { + throw new ExpressionEvaluationException("invalid to get intervalValue from value of type " + type()); + } + + /** + * Get map value. + */ + default Map tupleValue() { + throw new ExpressionEvaluationException("invalid to get tupleValue from value of type " + type()); + } + + /** + * Get collection value. + */ + default List collectionValue() { + throw new ExpressionEvaluationException("invalid to get collectionValue from value of type " + type()); + } + + /** + * Get the value specified by key from {@link ExprTupleValue}. + * This method only be implemented in {@link ExprTupleValue}. + */ + default ExprValue keyValue(String key) { + return ExprMissingValue.of(); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java b/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java index 43a3140ef3..f1be26139e 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java @@ -23,187 +23,186 @@ */ @UtilityClass public class ExprValueUtils { - public static final ExprValue LITERAL_TRUE = ExprBooleanValue.of(true); - public static final ExprValue LITERAL_FALSE = ExprBooleanValue.of(false); - public static final ExprValue LITERAL_NULL = ExprNullValue.of(); - public static final ExprValue LITERAL_MISSING = ExprMissingValue.of(); - - public static ExprValue booleanValue(Boolean value) { - return value ? LITERAL_TRUE : LITERAL_FALSE; - } - - public static ExprValue byteValue(Byte value) { - return new ExprByteValue(value); - } - - public static ExprValue shortValue(Short value) { - return new ExprShortValue(value); - } - - public static ExprValue integerValue(Integer value) { - return new ExprIntegerValue(value); - } - - public static ExprValue doubleValue(Double value) { - return new ExprDoubleValue(value); - } - - public static ExprValue floatValue(Float value) { - return new ExprFloatValue(value); - } - - public static ExprValue longValue(Long value) { - return new ExprLongValue(value); - } - - public static ExprValue stringValue(String value) { - return new ExprStringValue(value); - } - - public static ExprValue intervalValue(TemporalAmount value) { - return new ExprIntervalValue(value); - } - - public static ExprValue dateValue(LocalDate value) { - return new ExprDateValue(value); - } - - public static ExprValue datetimeValue(LocalDateTime value) { - return new ExprDatetimeValue(value); - } - - public static ExprValue timeValue(LocalTime value) { - return new ExprTimeValue(value); - } - - public static ExprValue timestampValue(Instant value) { - return new ExprTimestampValue(value); - } - - /** - * {@link ExprTupleValue} constructor. - */ - public static ExprValue tupleValue(Map map) { - LinkedHashMap valueMap = new LinkedHashMap<>(); - map.forEach((k, v) -> valueMap - .put(k, v instanceof ExprValue ? (ExprValue) v : fromObjectValue(v))); - return new ExprTupleValue(valueMap); - } - - /** - * {@link ExprCollectionValue} constructor. - */ - public static ExprValue collectionValue(List list) { - List valueList = new ArrayList<>(); - list.forEach(o -> valueList.add(fromObjectValue(o))); - return new ExprCollectionValue(valueList); - } - - public static ExprValue missingValue() { - return ExprMissingValue.of(); - } - - public static ExprValue nullValue() { - return ExprNullValue.of(); - } - - /** - * Construct ExprValue from Object. - */ - public static ExprValue fromObjectValue(Object o) { - if (null == o) { - return LITERAL_NULL; - } - if (o instanceof Map) { - return tupleValue((Map) o); - } else if (o instanceof List) { - return collectionValue(((List) o)); - } else if (o instanceof Byte) { - return byteValue((Byte) o); - } else if (o instanceof Short) { - return shortValue((Short) o); - } else if (o instanceof Integer) { - return integerValue((Integer) o); - } else if (o instanceof Long) { - return longValue(((Long) o)); - } else if (o instanceof Boolean) { - return booleanValue((Boolean) o); - } else if (o instanceof Double) { - return doubleValue((Double) o); - } else if (o instanceof String) { - return stringValue((String) o); - } else if (o instanceof Float) { - return floatValue((Float) o); - } else if (o instanceof LocalDate) { - return dateValue((LocalDate) o); - } else if (o instanceof LocalDateTime) { - return datetimeValue((LocalDateTime) o); - } else if (o instanceof LocalTime) { - return timeValue((LocalTime) o); - } else if (o instanceof Instant) { - return timestampValue((Instant) o); - } else if (o instanceof TemporalAmount) { - return intervalValue((TemporalAmount) o); - } else { - throw new ExpressionEvaluationException("unsupported object " + o.getClass()); - } - } - - /** - * Construct ExprValue from Object with ExprCoreType. - */ - public static ExprValue fromObjectValue(Object o, ExprCoreType type) { - switch (type) { - case TIMESTAMP: - return new ExprTimestampValue((String)o); - case DATE: - return new ExprDateValue((String)o); - case TIME: - return new ExprTimeValue((String)o); - case DATETIME: - return new ExprDatetimeValue((String)o); - default: - return fromObjectValue(o); - } - } - - public static Byte getByteValue(ExprValue exprValue) { - return exprValue.byteValue(); - } - - public static Short getShortValue(ExprValue exprValue) { - return exprValue.shortValue(); - } - - public static Integer getIntegerValue(ExprValue exprValue) { - return exprValue.integerValue(); - } - - public static Double getDoubleValue(ExprValue exprValue) { - return exprValue.doubleValue(); - } - - public static Long getLongValue(ExprValue exprValue) { - return exprValue.longValue(); - } - - public static Float getFloatValue(ExprValue exprValue) { - return exprValue.floatValue(); - } - - public static String getStringValue(ExprValue exprValue) { - return exprValue.stringValue(); - } - - public static List getCollectionValue(ExprValue exprValue) { - return exprValue.collectionValue(); - } - - public static Map getTupleValue(ExprValue exprValue) { - return exprValue.tupleValue(); - } - - public static Boolean getBooleanValue(ExprValue exprValue) { - return exprValue.booleanValue(); - } + public static final ExprValue LITERAL_TRUE = ExprBooleanValue.of(true); + public static final ExprValue LITERAL_FALSE = ExprBooleanValue.of(false); + public static final ExprValue LITERAL_NULL = ExprNullValue.of(); + public static final ExprValue LITERAL_MISSING = ExprMissingValue.of(); + + public static ExprValue booleanValue(Boolean value) { + return value ? LITERAL_TRUE : LITERAL_FALSE; + } + + public static ExprValue byteValue(Byte value) { + return new ExprByteValue(value); + } + + public static ExprValue shortValue(Short value) { + return new ExprShortValue(value); + } + + public static ExprValue integerValue(Integer value) { + return new ExprIntegerValue(value); + } + + public static ExprValue doubleValue(Double value) { + return new ExprDoubleValue(value); + } + + public static ExprValue floatValue(Float value) { + return new ExprFloatValue(value); + } + + public static ExprValue longValue(Long value) { + return new ExprLongValue(value); + } + + public static ExprValue stringValue(String value) { + return new ExprStringValue(value); + } + + public static ExprValue intervalValue(TemporalAmount value) { + return new ExprIntervalValue(value); + } + + public static ExprValue dateValue(LocalDate value) { + return new ExprDateValue(value); + } + + public static ExprValue datetimeValue(LocalDateTime value) { + return new ExprDatetimeValue(value); + } + + public static ExprValue timeValue(LocalTime value) { + return new ExprTimeValue(value); + } + + public static ExprValue timestampValue(Instant value) { + return new ExprTimestampValue(value); + } + + /** + * {@link ExprTupleValue} constructor. + */ + public static ExprValue tupleValue(Map map) { + LinkedHashMap valueMap = new LinkedHashMap<>(); + map.forEach((k, v) -> valueMap.put(k, v instanceof ExprValue ? (ExprValue) v : fromObjectValue(v))); + return new ExprTupleValue(valueMap); + } + + /** + * {@link ExprCollectionValue} constructor. + */ + public static ExprValue collectionValue(List list) { + List valueList = new ArrayList<>(); + list.forEach(o -> valueList.add(fromObjectValue(o))); + return new ExprCollectionValue(valueList); + } + + public static ExprValue missingValue() { + return ExprMissingValue.of(); + } + + public static ExprValue nullValue() { + return ExprNullValue.of(); + } + + /** + * Construct ExprValue from Object. + */ + public static ExprValue fromObjectValue(Object o) { + if (null == o) { + return LITERAL_NULL; + } + if (o instanceof Map) { + return tupleValue((Map) o); + } else if (o instanceof List) { + return collectionValue(((List) o)); + } else if (o instanceof Byte) { + return byteValue((Byte) o); + } else if (o instanceof Short) { + return shortValue((Short) o); + } else if (o instanceof Integer) { + return integerValue((Integer) o); + } else if (o instanceof Long) { + return longValue(((Long) o)); + } else if (o instanceof Boolean) { + return booleanValue((Boolean) o); + } else if (o instanceof Double) { + return doubleValue((Double) o); + } else if (o instanceof String) { + return stringValue((String) o); + } else if (o instanceof Float) { + return floatValue((Float) o); + } else if (o instanceof LocalDate) { + return dateValue((LocalDate) o); + } else if (o instanceof LocalDateTime) { + return datetimeValue((LocalDateTime) o); + } else if (o instanceof LocalTime) { + return timeValue((LocalTime) o); + } else if (o instanceof Instant) { + return timestampValue((Instant) o); + } else if (o instanceof TemporalAmount) { + return intervalValue((TemporalAmount) o); + } else { + throw new ExpressionEvaluationException("unsupported object " + o.getClass()); + } + } + + /** + * Construct ExprValue from Object with ExprCoreType. + */ + public static ExprValue fromObjectValue(Object o, ExprCoreType type) { + switch (type) { + case TIMESTAMP: + return new ExprTimestampValue((String) o); + case DATE: + return new ExprDateValue((String) o); + case TIME: + return new ExprTimeValue((String) o); + case DATETIME: + return new ExprDatetimeValue((String) o); + default: + return fromObjectValue(o); + } + } + + public static Byte getByteValue(ExprValue exprValue) { + return exprValue.byteValue(); + } + + public static Short getShortValue(ExprValue exprValue) { + return exprValue.shortValue(); + } + + public static Integer getIntegerValue(ExprValue exprValue) { + return exprValue.integerValue(); + } + + public static Double getDoubleValue(ExprValue exprValue) { + return exprValue.doubleValue(); + } + + public static Long getLongValue(ExprValue exprValue) { + return exprValue.longValue(); + } + + public static Float getFloatValue(ExprValue exprValue) { + return exprValue.floatValue(); + } + + public static String getStringValue(ExprValue exprValue) { + return exprValue.stringValue(); + } + + public static List getCollectionValue(ExprValue exprValue) { + return exprValue.collectionValue(); + } + + public static Map getTupleValue(ExprValue exprValue) { + return exprValue.tupleValue(); + } + + public static Boolean getBooleanValue(ExprValue exprValue) { + return exprValue.booleanValue(); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/type/ExprCoreType.java b/core/src/main/java/org/opensearch/sql/data/type/ExprCoreType.java index 815f94a9df..ea21df7764 100644 --- a/core/src/main/java/org/opensearch/sql/data/type/ExprCoreType.java +++ b/core/src/main/java/org/opensearch/sql/data/type/ExprCoreType.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.type; import com.google.common.collect.ImmutableMap; @@ -19,114 +18,110 @@ * Expression Type. */ public enum ExprCoreType implements ExprType { - /** - * Unknown due to unsupported data type. - */ - UNKNOWN, - - /** - * Undefined type for special literal such as NULL. - * As the root of data type tree, it is compatible with any other type. - * In other word, undefined type is the "narrowest" type. - */ - UNDEFINED, - - /** - * Numbers. - */ - BYTE(UNDEFINED), - SHORT(BYTE), - INTEGER(SHORT), - LONG(INTEGER), - FLOAT(LONG), - DOUBLE(FLOAT), - - /** - * String. - */ - STRING(UNDEFINED), - - /** - * Boolean. - */ - BOOLEAN(STRING), - - /** - * Date. - */ - DATE(STRING), - TIME(STRING), - DATETIME(STRING, DATE, TIME), - TIMESTAMP(STRING, DATETIME), - INTERVAL(UNDEFINED), - - /** - * Struct. - */ - STRUCT(UNDEFINED), - - /** - * Array. - */ - ARRAY(UNDEFINED); - - /** - * Parents (wider/compatible types) of current base type. - */ - private final List parents = new ArrayList<>(); - - /** - * The mapping between Type and legacy JDBC type name. - */ - private static final Map LEGACY_TYPE_NAME_MAPPING = - new ImmutableMap.Builder() - .put(STRUCT, "OBJECT") - .put(ARRAY, "NESTED") - .put(STRING, "KEYWORD") - .build(); - - private static final Set NUMBER_TYPES = - new ImmutableSet.Builder() - .add(BYTE) - .add(SHORT) - .add(INTEGER) - .add(LONG) - .add(FLOAT) - .add(DOUBLE) - .build(); - - ExprCoreType(ExprCoreType... compatibleTypes) { - for (ExprCoreType subType : compatibleTypes) { - subType.parents.add(this); + /** + * Unknown due to unsupported data type. + */ + UNKNOWN, + + /** + * Undefined type for special literal such as NULL. + * As the root of data type tree, it is compatible with any other type. + * In other word, undefined type is the "narrowest" type. + */ + UNDEFINED, + + /** + * Numbers. + */ + BYTE(UNDEFINED), + SHORT(BYTE), + INTEGER(SHORT), + LONG(INTEGER), + FLOAT(LONG), + DOUBLE(FLOAT), + + /** + * String. + */ + STRING(UNDEFINED), + + /** + * Boolean. + */ + BOOLEAN(STRING), + + /** + * Date. + */ + DATE(STRING), + TIME(STRING), + DATETIME(STRING, DATE, TIME), + TIMESTAMP(STRING, DATETIME), + INTERVAL(UNDEFINED), + + /** + * Struct. + */ + STRUCT(UNDEFINED), + + /** + * Array. + */ + ARRAY(UNDEFINED); + + /** + * Parents (wider/compatible types) of current base type. + */ + private final List parents = new ArrayList<>(); + + /** + * The mapping between Type and legacy JDBC type name. + */ + private static final Map LEGACY_TYPE_NAME_MAPPING = new ImmutableMap.Builder().put( + STRUCT, + "OBJECT" + ).put(ARRAY, "NESTED").put(STRING, "KEYWORD").build(); + + private static final Set NUMBER_TYPES = new ImmutableSet.Builder().add(BYTE) + .add(SHORT) + .add(INTEGER) + .add(LONG) + .add(FLOAT) + .add(DOUBLE) + .build(); + + ExprCoreType(ExprCoreType... compatibleTypes) { + for (ExprCoreType subType : compatibleTypes) { + subType.parents.add(this); + } + } + + @Override + public List getParent() { + return parents.isEmpty() ? ExprType.super.getParent() : parents; + } + + @Override + public String typeName() { + return this.name(); + } + + @Override + public String legacyTypeName() { + return LEGACY_TYPE_NAME_MAPPING.getOrDefault(this, this.name()); + } + + /** + * Return all the valid ExprCoreType. + */ + public static List coreTypes() { + return Arrays.stream(ExprCoreType.values()) + .filter(type -> type != UNKNOWN) + .filter(type -> type != UNDEFINED) + .collect(Collectors.toList()); + } + + public static Set numberTypes() { + return NUMBER_TYPES; } - } - - @Override - public List getParent() { - return parents.isEmpty() ? ExprType.super.getParent() : parents; - } - - @Override - public String typeName() { - return this.name(); - } - - @Override - public String legacyTypeName() { - return LEGACY_TYPE_NAME_MAPPING.getOrDefault(this, this.name()); - } - - /** - * Return all the valid ExprCoreType. - */ - public static List coreTypes() { - return Arrays.stream(ExprCoreType.values()) - .filter(type -> type != UNKNOWN) - .filter(type -> type != UNDEFINED) - .collect(Collectors.toList()); - } - - public static Set numberTypes() { - return NUMBER_TYPES; - } } diff --git a/core/src/main/java/org/opensearch/sql/data/type/ExprType.java b/core/src/main/java/org/opensearch/sql/data/type/ExprType.java index 782714ba70..3ba43922f5 100644 --- a/core/src/main/java/org/opensearch/sql/data/type/ExprType.java +++ b/core/src/main/java/org/opensearch/sql/data/type/ExprType.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.type; import static org.opensearch.sql.data.type.ExprCoreType.UNKNOWN; @@ -17,51 +16,51 @@ * The Type of {@link Expression} and {@link ExprValue}. */ public interface ExprType { - /** - * Is compatible with other types. - */ - default boolean isCompatible(ExprType other) { - if (this.equals(other)) { - return true; - } else { - if (other.equals(UNKNOWN)) { - return false; - } - for (ExprType parentTypeOfOther : other.getParent()) { - if (isCompatible(parentTypeOfOther)) { - return true; + /** + * Is compatible with other types. + */ + default boolean isCompatible(ExprType other) { + if (this.equals(other)) { + return true; + } else { + if (other.equals(UNKNOWN)) { + return false; + } + for (ExprType parentTypeOfOther : other.getParent()) { + if (isCompatible(parentTypeOfOther)) { + return true; + } + } + return false; } - } - return false; } - } - /** - * Should cast this type to other type or not. By default, cast is always required - * if the given type is different from this type. - * @param other other data type - * @return true if cast is required, otherwise false - */ - default boolean shouldCast(ExprType other) { - return !this.equals(other); - } + /** + * Should cast this type to other type or not. By default, cast is always required + * if the given type is different from this type. + * @param other other data type + * @return true if cast is required, otherwise false + */ + default boolean shouldCast(ExprType other) { + return !this.equals(other); + } - /** - * Get the parent type. - */ - default List getParent() { - return Arrays.asList(UNKNOWN); - } + /** + * Get the parent type. + */ + default List getParent() { + return Arrays.asList(UNKNOWN); + } - /** - * Get the type name. - */ - String typeName(); + /** + * Get the type name. + */ + String typeName(); - /** - * Get the legacy type name for old engine. - */ - default String legacyTypeName() { - return typeName(); - } + /** + * Get the legacy type name for old engine. + */ + default String legacyTypeName() { + return typeName(); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/type/WideningTypeRule.java b/core/src/main/java/org/opensearch/sql/data/type/WideningTypeRule.java index e1f356782f..72f17314a5 100644 --- a/core/src/main/java/org/opensearch/sql/data/type/WideningTypeRule.java +++ b/core/src/main/java/org/opensearch/sql/data/type/WideningTypeRule.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.type; import static org.opensearch.sql.data.type.ExprCoreType.UNKNOWN; @@ -25,52 +24,49 @@ */ @UtilityClass public class WideningTypeRule { - public static final int IMPOSSIBLE_WIDENING = Integer.MAX_VALUE; - public static final int TYPE_EQUAL = 0; + public static final int IMPOSSIBLE_WIDENING = Integer.MAX_VALUE; + public static final int TYPE_EQUAL = 0; - /** - * The widening distance is calculated from the leaf to root. - * e.g. distance(INTEGER, FLOAT) = 2, but distance(FLOAT, INTEGER) = IMPOSSIBLE_WIDENING - * - * @param type1 widen from type - * @param type2 widen to type - * @return The widening distance when widen one type to another type. - */ - public static int distance(ExprType type1, ExprType type2) { - return distance(type1, type2, TYPE_EQUAL); - } + /** + * The widening distance is calculated from the leaf to root. + * e.g. distance(INTEGER, FLOAT) = 2, but distance(FLOAT, INTEGER) = IMPOSSIBLE_WIDENING + * + * @param type1 widen from type + * @param type2 widen to type + * @return The widening distance when widen one type to another type. + */ + public static int distance(ExprType type1, ExprType type2) { + return distance(type1, type2, TYPE_EQUAL); + } - private static int distance(ExprType type1, ExprType type2, int distance) { - if (type1 == type2) { - return distance; - } else if (type1 == UNKNOWN) { - return IMPOSSIBLE_WIDENING; - } else { - return type1.getParent().stream() - .map(parentOfType1 -> distance(parentOfType1, type2, distance + 1)) - .reduce(Math::min).get(); + private static int distance(ExprType type1, ExprType type2, int distance) { + if (type1 == type2) { + return distance; + } else if (type1 == UNKNOWN) { + return IMPOSSIBLE_WIDENING; + } else { + return type1.getParent().stream().map(parentOfType1 -> distance(parentOfType1, type2, distance + 1)).reduce(Math::min).get(); + } } - } - /** - * The max type among two types. The max is defined as follow - * if type1 could widen to type2, then max is type2, vice versa - * if type1 could't widen to type2 and type2 could't widen to type1, - * then throw {@link ExpressionEvaluationException}. - * - * @param type1 type1 - * @param type2 type2 - * @return the max type among two types. - */ - public static ExprType max(ExprType type1, ExprType type2) { - int type1To2 = distance(type1, type2); - int type2To1 = distance(type2, type1); + /** + * The max type among two types. The max is defined as follow + * if type1 could widen to type2, then max is type2, vice versa + * if type1 could't widen to type2 and type2 could't widen to type1, + * then throw {@link ExpressionEvaluationException}. + * + * @param type1 type1 + * @param type2 type2 + * @return the max type among two types. + */ + public static ExprType max(ExprType type1, ExprType type2) { + int type1To2 = distance(type1, type2); + int type2To1 = distance(type2, type1); - if (type1To2 == Integer.MAX_VALUE && type2To1 == Integer.MAX_VALUE) { - throw new ExpressionEvaluationException( - String.format("no max type of %s and %s ", type1, type2)); - } else { - return type1To2 == Integer.MAX_VALUE ? type1 : type2; + if (type1To2 == Integer.MAX_VALUE && type2To1 == Integer.MAX_VALUE) { + throw new ExpressionEvaluationException(String.format("no max type of %s and %s ", type1, type2)); + } else { + return type1To2 == Integer.MAX_VALUE ? type1 : type2; + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/data/utils/ExprValueOrdering.java b/core/src/main/java/org/opensearch/sql/data/utils/ExprValueOrdering.java index ef390dc53b..4017ce6fb3 100644 --- a/core/src/main/java/org/opensearch/sql/data/utils/ExprValueOrdering.java +++ b/core/src/main/java/org/opensearch/sql/data/utils/ExprValueOrdering.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.utils; import com.google.common.collect.Ordering; @@ -19,23 +18,23 @@ @RequiredArgsConstructor public abstract class ExprValueOrdering implements Comparator { - public static ExprValueOrdering natural() { - return NaturalExprValueOrdering.INSTANCE; - } + public static ExprValueOrdering natural() { + return NaturalExprValueOrdering.INSTANCE; + } - public ExprValueOrdering reverse() { - return new ReverseExprValueOrdering(this); - } + public ExprValueOrdering reverse() { + return new ReverseExprValueOrdering(this); + } - public ExprValueOrdering nullsFirst() { - return new NullsFirstExprValueOrdering(this); - } + public ExprValueOrdering nullsFirst() { + return new NullsFirstExprValueOrdering(this); + } - public ExprValueOrdering nullsLast() { - return new NullsLastExprValueOrdering(this); - } + public ExprValueOrdering nullsLast() { + return new NullsLastExprValueOrdering(this); + } - // Never make these public - static final int LEFT_IS_GREATER = 1; - static final int RIGHT_IS_GREATER = -1; + // Never make these public + static final int LEFT_IS_GREATER = 1; + static final int RIGHT_IS_GREATER = -1; } diff --git a/core/src/main/java/org/opensearch/sql/data/utils/NaturalExprValueOrdering.java b/core/src/main/java/org/opensearch/sql/data/utils/NaturalExprValueOrdering.java index 13c3606f72..7c8ba20cc8 100644 --- a/core/src/main/java/org/opensearch/sql/data/utils/NaturalExprValueOrdering.java +++ b/core/src/main/java/org/opensearch/sql/data/utils/NaturalExprValueOrdering.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.utils; import com.google.common.collect.Ordering; @@ -15,36 +14,36 @@ * org.opensearch.sql.data.model.ExprMissingValue} */ public class NaturalExprValueOrdering extends ExprValueOrdering { - static final ExprValueOrdering INSTANCE = new NaturalExprValueOrdering(); + static final ExprValueOrdering INSTANCE = new NaturalExprValueOrdering(); - private transient ExprValueOrdering nullsFirst; - private transient ExprValueOrdering nullsLast; + private transient ExprValueOrdering nullsFirst; + private transient ExprValueOrdering nullsLast; - @Override - public int compare(ExprValue left, ExprValue right) { - return left.compareTo(right); - } + @Override + public int compare(ExprValue left, ExprValue right) { + return left.compareTo(right); + } - @Override - public ExprValueOrdering nullsFirst() { - ExprValueOrdering result = nullsFirst; - if (result == null) { - result = nullsFirst = super.nullsFirst(); + @Override + public ExprValueOrdering nullsFirst() { + ExprValueOrdering result = nullsFirst; + if (result == null) { + result = nullsFirst = super.nullsFirst(); + } + return result; } - return result; - } - - @Override - public ExprValueOrdering nullsLast() { - ExprValueOrdering result = nullsLast; - if (result == null) { - result = nullsLast = super.nullsLast(); + + @Override + public ExprValueOrdering nullsLast() { + ExprValueOrdering result = nullsLast; + if (result == null) { + result = nullsLast = super.nullsLast(); + } + return result; } - return result; - } - @Override - public ExprValueOrdering reverse() { - return super.reverse(); - } + @Override + public ExprValueOrdering reverse() { + return super.reverse(); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/utils/NullsFirstExprValueOrdering.java b/core/src/main/java/org/opensearch/sql/data/utils/NullsFirstExprValueOrdering.java index 03890bba61..071fd18a4a 100644 --- a/core/src/main/java/org/opensearch/sql/data/utils/NullsFirstExprValueOrdering.java +++ b/core/src/main/java/org/opensearch/sql/data/utils/NullsFirstExprValueOrdering.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.utils; import com.google.common.collect.Ordering; @@ -17,34 +16,34 @@ */ @RequiredArgsConstructor public class NullsFirstExprValueOrdering extends ExprValueOrdering { - private final ExprValueOrdering ordering; - - @Override - public int compare(ExprValue left, ExprValue right) { - if (left == right) { - return 0; - } - if (left.isNull() || left.isMissing()) { - return RIGHT_IS_GREATER; - } - if (right.isNull() || right.isMissing()) { - return LEFT_IS_GREATER; + private final ExprValueOrdering ordering; + + @Override + public int compare(ExprValue left, ExprValue right) { + if (left == right) { + return 0; + } + if (left.isNull() || left.isMissing()) { + return RIGHT_IS_GREATER; + } + if (right.isNull() || right.isMissing()) { + return LEFT_IS_GREATER; + } + return ordering.compare(left, right); } - return ordering.compare(left, right); - } - @Override - public ExprValueOrdering reverse() { - return ordering.reverse().nullsLast(); - } + @Override + public ExprValueOrdering reverse() { + return ordering.reverse().nullsLast(); + } - @Override - public ExprValueOrdering nullsFirst() { - return this; - } + @Override + public ExprValueOrdering nullsFirst() { + return this; + } - @Override - public ExprValueOrdering nullsLast() { - return ordering.nullsLast(); - } + @Override + public ExprValueOrdering nullsLast() { + return ordering.nullsLast(); + } } diff --git a/core/src/main/java/org/opensearch/sql/data/utils/NullsLastExprValueOrdering.java b/core/src/main/java/org/opensearch/sql/data/utils/NullsLastExprValueOrdering.java index 589d4b3043..aec80f7cb9 100644 --- a/core/src/main/java/org/opensearch/sql/data/utils/NullsLastExprValueOrdering.java +++ b/core/src/main/java/org/opensearch/sql/data/utils/NullsLastExprValueOrdering.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.utils; import com.google.common.collect.Ordering; @@ -17,34 +16,34 @@ */ @RequiredArgsConstructor public class NullsLastExprValueOrdering extends ExprValueOrdering { - private final ExprValueOrdering ordering; - - @Override - public int compare(ExprValue left, ExprValue right) { - if (left == right) { - return 0; - } - if (left.isNull() || left.isMissing()) { - return LEFT_IS_GREATER; - } - if (right.isNull() || right.isMissing()) { - return RIGHT_IS_GREATER; + private final ExprValueOrdering ordering; + + @Override + public int compare(ExprValue left, ExprValue right) { + if (left == right) { + return 0; + } + if (left.isNull() || left.isMissing()) { + return LEFT_IS_GREATER; + } + if (right.isNull() || right.isMissing()) { + return RIGHT_IS_GREATER; + } + return ordering.compare(left, right); } - return ordering.compare(left, right); - } - @Override - public ExprValueOrdering reverse() { - return ordering.reverse().nullsFirst(); - } + @Override + public ExprValueOrdering reverse() { + return ordering.reverse().nullsFirst(); + } - @Override - public ExprValueOrdering nullsFirst() { - return ordering.nullsFirst(); - } + @Override + public ExprValueOrdering nullsFirst() { + return ordering.nullsFirst(); + } - @Override - public ExprValueOrdering nullsLast() { - return this; - } + @Override + public ExprValueOrdering nullsLast() { + return this; + } } diff --git a/core/src/main/java/org/opensearch/sql/data/utils/ReverseExprValueOrdering.java b/core/src/main/java/org/opensearch/sql/data/utils/ReverseExprValueOrdering.java index 65fceacf99..a4bbcb3f2b 100644 --- a/core/src/main/java/org/opensearch/sql/data/utils/ReverseExprValueOrdering.java +++ b/core/src/main/java/org/opensearch/sql/data/utils/ReverseExprValueOrdering.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.utils; import com.google.common.collect.Ordering; @@ -17,15 +16,15 @@ */ @RequiredArgsConstructor public class ReverseExprValueOrdering extends ExprValueOrdering { - private final ExprValueOrdering forwardOrder; + private final ExprValueOrdering forwardOrder; - @Override - public int compare(ExprValue left, ExprValue right) { - return forwardOrder.compare(right, left); - } + @Override + public int compare(ExprValue left, ExprValue right) { + return forwardOrder.compare(right, left); + } - @Override - public ExprValueOrdering reverse() { - return forwardOrder; - } + @Override + public ExprValueOrdering reverse() { + return forwardOrder; + } } diff --git a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java index 9167737a70..37e4a69ba5 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java +++ b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java @@ -14,62 +14,59 @@ */ public interface DataSourceService { - /** - * Returns {@link DataSource} corresponding to the DataSource name. - * - * @param dataSourceName Name of the {@link DataSource}. - * @return {@link DataSource}. - */ - DataSource getDataSource(String dataSourceName); + /** + * Returns {@link DataSource} corresponding to the DataSource name. + * + * @param dataSourceName Name of the {@link DataSource}. + * @return {@link DataSource}. + */ + DataSource getDataSource(String dataSourceName); + /** + * Returns all dataSource Metadata objects. The returned objects won't contain + * any of the credential info. + * + * @param isDefaultDataSourceRequired is used to specify + * if default opensearch connector is required in the output list. + * @return set of {@link DataSourceMetadata}. + */ + Set getDataSourceMetadata(boolean isDefaultDataSourceRequired); - /** - * Returns all dataSource Metadata objects. The returned objects won't contain - * any of the credential info. - * - * @param isDefaultDataSourceRequired is used to specify - * if default opensearch connector is required in the output list. - * @return set of {@link DataSourceMetadata}. - */ - Set getDataSourceMetadata(boolean isDefaultDataSourceRequired); + /** + * Returns dataSourceMetadata object with specific name. + * The returned objects won't contain any crendetial info. + * + * @param name name of the {@link DataSource}. + * @return set of {@link DataSourceMetadata}. + */ + DataSourceMetadata getDataSourceMetadata(String name); + /** + * Register {@link DataSource} defined by {@link DataSourceMetadata}. + * + * @param metadata {@link DataSourceMetadata}. + */ + void createDataSource(DataSourceMetadata metadata); - /** - * Returns dataSourceMetadata object with specific name. - * The returned objects won't contain any crendetial info. - * - * @param name name of the {@link DataSource}. - * @return set of {@link DataSourceMetadata}. - */ - DataSourceMetadata getDataSourceMetadata(String name); + /** + * Updates {@link DataSource} corresponding to dataSourceMetadata. + * + * @param dataSourceMetadata {@link DataSourceMetadata}. + */ + void updateDataSource(DataSourceMetadata dataSourceMetadata); - /** - * Register {@link DataSource} defined by {@link DataSourceMetadata}. - * - * @param metadata {@link DataSourceMetadata}. - */ - void createDataSource(DataSourceMetadata metadata); + /** + * Deletes {@link DataSource} corresponding to the DataSource name. + * + * @param dataSourceName name of the {@link DataSource}. + */ + void deleteDataSource(String dataSourceName); - /** - * Updates {@link DataSource} corresponding to dataSourceMetadata. - * - * @param dataSourceMetadata {@link DataSourceMetadata}. - */ - void updateDataSource(DataSourceMetadata dataSourceMetadata); - - - /** - * Deletes {@link DataSource} corresponding to the DataSource name. - * - * @param dataSourceName name of the {@link DataSource}. - */ - void deleteDataSource(String dataSourceName); - - /** - * Returns true {@link Boolean} if datasource with dataSourceName exists - * or else false {@link Boolean}. - * - * @param dataSourceName name of the {@link DataSource}. - */ - Boolean dataSourceExists(String dataSourceName); + /** + * Returns true {@link Boolean} if datasource with dataSourceName exists + * or else false {@link Boolean}. + * + * @param dataSourceName name of the {@link DataSource}. + */ + Boolean dataSourceExists(String dataSourceName); } diff --git a/core/src/main/java/org/opensearch/sql/datasource/model/DataSource.java b/core/src/main/java/org/opensearch/sql/datasource/model/DataSource.java index 5deb460961..a64f143bd3 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/model/DataSource.java +++ b/core/src/main/java/org/opensearch/sql/datasource/model/DataSource.java @@ -20,11 +20,11 @@ @EqualsAndHashCode public class DataSource { - private final String name; + private final String name; - private final DataSourceType connectorType; + private final DataSourceType connectorType; - @EqualsAndHashCode.Exclude - private final StorageEngine storageEngine; + @EqualsAndHashCode.Exclude + private final StorageEngine storageEngine; } diff --git a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceMetadata.java b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceMetadata.java index 7945f8aec3..62c5c57393 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceMetadata.java +++ b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceMetadata.java @@ -5,7 +5,6 @@ package org.opensearch.sql.datasource.model; - import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; import com.fasterxml.jackson.annotation.JsonFormat; @@ -30,25 +29,24 @@ @JsonIgnoreProperties(ignoreUnknown = true) public class DataSourceMetadata { - @JsonProperty - private String name; + @JsonProperty + private String name; - @JsonProperty - @JsonFormat(with = JsonFormat.Feature.ACCEPT_CASE_INSENSITIVE_PROPERTIES) - private DataSourceType connector; + @JsonProperty + @JsonFormat(with = JsonFormat.Feature.ACCEPT_CASE_INSENSITIVE_PROPERTIES) + private DataSourceType connector; - @JsonProperty - private List allowedRoles; + @JsonProperty + private List allowedRoles; - @JsonProperty - private Map properties; + @JsonProperty + private Map properties; - /** - * Default OpenSearch {@link DataSourceMetadata}. Which is used to register default OpenSearch - * {@link DataSource} to {@link DataSourceService}. - */ - public static DataSourceMetadata defaultOpenSearchDataSourceMetadata() { - return new DataSourceMetadata(DEFAULT_DATASOURCE_NAME, - DataSourceType.OPENSEARCH, Collections.emptyList(), ImmutableMap.of()); - } + /** + * Default OpenSearch {@link DataSourceMetadata}. Which is used to register default OpenSearch + * {@link DataSource} to {@link DataSourceService}. + */ + public static DataSourceMetadata defaultOpenSearchDataSourceMetadata() { + return new DataSourceMetadata(DEFAULT_DATASOURCE_NAME, DataSourceType.OPENSEARCH, Collections.emptyList(), ImmutableMap.of()); + } } diff --git a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java index 5010e41942..d3be3d4bfd 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java +++ b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java @@ -6,31 +6,32 @@ package org.opensearch.sql.datasource.model; public enum DataSourceType { - PROMETHEUS("prometheus"), - OPENSEARCH("opensearch"), - SPARK("spark"); - private String text; + PROMETHEUS("prometheus"), + OPENSEARCH("opensearch"), + SPARK("spark"); - DataSourceType(String text) { - this.text = text; - } + private String text; - public String getText() { - return this.text; - } + DataSourceType(String text) { + this.text = text; + } + + public String getText() { + return this.text; + } - /** - * Get DataSourceType from text. - * - * @param text text. - * @return DataSourceType {@link DataSourceType}. - */ - public static DataSourceType fromString(String text) { - for (DataSourceType dataSourceType : DataSourceType.values()) { - if (dataSourceType.text.equalsIgnoreCase(text)) { - return dataSourceType; - } + /** + * Get DataSourceType from text. + * + * @param text text. + * @return DataSourceType {@link DataSourceType}. + */ + public static DataSourceType fromString(String text) { + for (DataSourceType dataSourceType : DataSourceType.values()) { + if (dataSourceType.text.equalsIgnoreCase(text)) { + return dataSourceType; + } + } + throw new IllegalArgumentException("No DataSourceType with text " + text + " found"); } - throw new IllegalArgumentException("No DataSourceType with text " + text + " found"); - } } diff --git a/core/src/main/java/org/opensearch/sql/executor/ExecutionContext.java b/core/src/main/java/org/opensearch/sql/executor/ExecutionContext.java index 8a3162068f..710f987ce9 100644 --- a/core/src/main/java/org/opensearch/sql/executor/ExecutionContext.java +++ b/core/src/main/java/org/opensearch/sql/executor/ExecutionContext.java @@ -13,18 +13,18 @@ * Execution context hold planning related information. */ public class ExecutionContext { - @Getter - private final Optional split; + @Getter + private final Optional split; - public ExecutionContext(Split split) { - this.split = Optional.of(split); - } + public ExecutionContext(Split split) { + this.split = Optional.of(split); + } - private ExecutionContext(Optional split) { - this.split = split; - } + private ExecutionContext(Optional split) { + this.split = split; + } - public static ExecutionContext emptyExecutionContext() { - return new ExecutionContext(Optional.empty()); - } + public static ExecutionContext emptyExecutionContext() { + return new ExecutionContext(Optional.empty()); + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java b/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java index 9465da22c9..c605404e58 100644 --- a/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java +++ b/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.executor; import java.util.List; @@ -22,69 +21,68 @@ */ public interface ExecutionEngine { - /** - * Execute physical plan and call back response listener. - * Todo. deprecated this interface after finalize {@link ExecutionContext}. - * - * @param plan executable physical plan - * @param listener response listener - */ - void execute(PhysicalPlan plan, ResponseListener listener); - - /** - * Execute physical plan with {@link ExecutionContext} and call back response listener. - */ - void execute(PhysicalPlan plan, ExecutionContext context, - ResponseListener listener); + /** + * Execute physical plan and call back response listener. + * Todo. deprecated this interface after finalize {@link ExecutionContext}. + * + * @param plan executable physical plan + * @param listener response listener + */ + void execute(PhysicalPlan plan, ResponseListener listener); - /** - * Explain physical plan and call back response listener. The reason why this has to - * be part of execution engine interface is that the physical plan probably needs to - * be executed to get more info for profiling, such as actual execution time, rows fetched etc. - * - * @param plan physical plan to explain - * @param listener response listener - */ - void explain(PhysicalPlan plan, ResponseListener listener); + /** + * Execute physical plan with {@link ExecutionContext} and call back response listener. + */ + void execute(PhysicalPlan plan, ExecutionContext context, ResponseListener listener); - /** - * Data class that encapsulates ExprValue. - */ - @Data - class QueryResponse { - private final Schema schema; - private final List results; - private final Cursor cursor; - } + /** + * Explain physical plan and call back response listener. The reason why this has to + * be part of execution engine interface is that the physical plan probably needs to + * be executed to get more info for profiling, such as actual execution time, rows fetched etc. + * + * @param plan physical plan to explain + * @param listener response listener + */ + void explain(PhysicalPlan plan, ResponseListener listener); - @Data - class Schema { - private final List columns; + /** + * Data class that encapsulates ExprValue. + */ + @Data + class QueryResponse { + private final Schema schema; + private final List results; + private final Cursor cursor; + } @Data - public static class Column { - private final String name; - private final String alias; - private final ExprType exprType; + class Schema { + private final List columns; + + @Data + public static class Column { + private final String name; + private final String alias; + private final ExprType exprType; + } } - } - /** - * Data class that encapsulates explain result. This can help decouple core engine - * from concrete explain response format. - */ - @Data - class ExplainResponse { - private final ExplainResponseNode root; - } + /** + * Data class that encapsulates explain result. This can help decouple core engine + * from concrete explain response format. + */ + @Data + class ExplainResponse { + private final ExplainResponseNode root; + } - @AllArgsConstructor - @Data - @RequiredArgsConstructor - class ExplainResponseNode { - private final String name; - private Map description; - private List children; - } + @AllArgsConstructor + @Data + @RequiredArgsConstructor + class ExplainResponseNode { + private final String name; + private Map description; + private List children; + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/Explain.java b/core/src/main/java/org/opensearch/sql/executor/Explain.java index 7c16e0b720..3fb7f588f6 100644 --- a/core/src/main/java/org/opensearch/sql/executor/Explain.java +++ b/core/src/main/java/org/opensearch/sql/executor/Explain.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.executor; import com.google.common.collect.ImmutableMap; @@ -38,150 +37,198 @@ /** * Visitor that explains a physical plan to JSON format. */ -public class Explain extends PhysicalPlanNodeVisitor - implements Function { - - @Override - public ExplainResponse apply(PhysicalPlan plan) { - return new ExplainResponse(plan.accept(this, null)); - } - - @Override - public ExplainResponseNode visitProject(ProjectOperator node, Object context) { - return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of( - "fields", node.getProjectList().toString()))); - } - - @Override - public ExplainResponseNode visitFilter(FilterOperator node, Object context) { - return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of( - "conditions", node.getConditions().toString()))); - } - - @Override - public ExplainResponseNode visitSort(SortOperator node, Object context) { - return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of( - "sortList", describeSortList(node.getSortList())))); - } - - @Override - public ExplainResponseNode visitTableScan(TableScanOperator node, Object context) { - return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of( - "request", node.toString()))); - } - - @Override - public ExplainResponseNode visitAggregation(AggregationOperator node, Object context) { - return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of( - "aggregators", node.getAggregatorList().toString(), - "groupBy", node.getGroupByExprList().toString()))); - } - - @Override - public ExplainResponseNode visitWindow(WindowOperator node, Object context) { - return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of( - "function", node.getWindowFunction().toString(), - "definition", ImmutableMap.of( - "partitionBy", node.getWindowDefinition().getPartitionByList().toString(), - "sortList", describeSortList(node.getWindowDefinition().getSortList()))))); - } - - @Override - public ExplainResponseNode visitRename(RenameOperator node, Object context) { - Map renameMappingDescription = - node.getMapping() +public class Explain extends PhysicalPlanNodeVisitor implements Function { + + @Override + public ExplainResponse apply(PhysicalPlan plan) { + return new ExplainResponse(plan.accept(this, null)); + } + + @Override + public ExplainResponseNode visitProject(ProjectOperator node, Object context) { + return explain( + node, + context, + explainNode -> explainNode.setDescription(ImmutableMap.of("fields", node.getProjectList().toString())) + ); + } + + @Override + public ExplainResponseNode visitFilter(FilterOperator node, Object context) { + return explain( + node, + context, + explainNode -> explainNode.setDescription(ImmutableMap.of("conditions", node.getConditions().toString())) + ); + } + + @Override + public ExplainResponseNode visitSort(SortOperator node, Object context) { + return explain( + node, + context, + explainNode -> explainNode.setDescription(ImmutableMap.of("sortList", describeSortList(node.getSortList()))) + ); + } + + @Override + public ExplainResponseNode visitTableScan(TableScanOperator node, Object context) { + return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of("request", node.toString()))); + } + + @Override + public ExplainResponseNode visitAggregation(AggregationOperator node, Object context) { + return explain( + node, + context, + explainNode -> explainNode.setDescription( + ImmutableMap.of("aggregators", node.getAggregatorList().toString(), "groupBy", node.getGroupByExprList().toString()) + ) + ); + } + + @Override + public ExplainResponseNode visitWindow(WindowOperator node, Object context) { + return explain( + node, + context, + explainNode -> explainNode.setDescription( + ImmutableMap.of( + "function", + node.getWindowFunction().toString(), + "definition", + ImmutableMap.of( + "partitionBy", + node.getWindowDefinition().getPartitionByList().toString(), + "sortList", + describeSortList(node.getWindowDefinition().getSortList()) + ) + ) + ) + ); + } + + @Override + public ExplainResponseNode visitRename(RenameOperator node, Object context) { + Map renameMappingDescription = node.getMapping() .entrySet() .stream() - .collect(Collectors.toMap( - e -> e.getKey().toString(), - e -> e.getValue().toString())); - - return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of( - "mapping", renameMappingDescription))); - } - - @Override - public ExplainResponseNode visitRemove(RemoveOperator node, Object context) { - return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of( - "removeList", node.getRemoveList().toString()))); - } - - @Override - public ExplainResponseNode visitEval(EvalOperator node, Object context) { - return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of( - "expressions", convertPairListToMap(node.getExpressionList())))); - } - - @Override - public ExplainResponseNode visitDedupe(DedupeOperator node, Object context) { - return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of( - "dedupeList", node.getDedupeList().toString(), - "allowedDuplication", node.getAllowedDuplication(), - "keepEmpty", node.getKeepEmpty(), - "consecutive", node.getConsecutive()))); - } - - @Override - public ExplainResponseNode visitRareTopN(RareTopNOperator node, Object context) { - return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of( - "commandType", node.getCommandType(), - "noOfResults", node.getNoOfResults(), - "fields", node.getFieldExprList().toString(), - "groupBy", node.getGroupByExprList().toString() - ))); - } - - @Override - public ExplainResponseNode visitValues(ValuesOperator node, Object context) { - return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of( - "values", node.getValues()))); - } - - @Override - public ExplainResponseNode visitLimit(LimitOperator node, Object context) { - return explain(node, context, explanNode -> explanNode.setDescription(ImmutableMap.of( - "limit", node.getLimit(), "offset", node.getOffset()))); - } - - @Override - public ExplainResponseNode visitNested(NestedOperator node, Object context) { - return explain(node, context, explanNode -> explanNode.setDescription(ImmutableMap.of( - "nested", node.getFields()))); - } - - protected ExplainResponseNode explain(PhysicalPlan node, Object context, - Consumer doExplain) { - ExplainResponseNode explainNode = new ExplainResponseNode(getOperatorName(node)); - - List children = new ArrayList<>(); - for (PhysicalPlan child : node.getChild()) { - children.add(child.accept(this, context)); - } - explainNode.setChildren(children); - - doExplain.accept(explainNode); - return explainNode; - } - - private String getOperatorName(PhysicalPlan node) { - return node.getClass().getSimpleName(); - } - - private Map convertPairListToMap(List> pairs) { - return pairs.stream() - .collect(Collectors.toMap( - p -> p.getLeft().toString(), - p -> p.getRight().toString())); - } - - private Map> describeSortList( - List> sortList) { - return sortList.stream() - .collect(Collectors.toMap( - p -> p.getRight().toString(), - p -> ImmutableMap.of( - "sortOrder", p.getLeft().getSortOrder().toString(), - "nullOrder", p.getLeft().getNullOrder().toString()))); - } + .collect(Collectors.toMap(e -> e.getKey().toString(), e -> e.getValue().toString())); + + return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of("mapping", renameMappingDescription))); + } + + @Override + public ExplainResponseNode visitRemove(RemoveOperator node, Object context) { + return explain( + node, + context, + explainNode -> explainNode.setDescription(ImmutableMap.of("removeList", node.getRemoveList().toString())) + ); + } + + @Override + public ExplainResponseNode visitEval(EvalOperator node, Object context) { + return explain( + node, + context, + explainNode -> explainNode.setDescription(ImmutableMap.of("expressions", convertPairListToMap(node.getExpressionList()))) + ); + } + + @Override + public ExplainResponseNode visitDedupe(DedupeOperator node, Object context) { + return explain( + node, + context, + explainNode -> explainNode.setDescription( + ImmutableMap.of( + "dedupeList", + node.getDedupeList().toString(), + "allowedDuplication", + node.getAllowedDuplication(), + "keepEmpty", + node.getKeepEmpty(), + "consecutive", + node.getConsecutive() + ) + ) + ); + } + + @Override + public ExplainResponseNode visitRareTopN(RareTopNOperator node, Object context) { + return explain( + node, + context, + explainNode -> explainNode.setDescription( + ImmutableMap.of( + "commandType", + node.getCommandType(), + "noOfResults", + node.getNoOfResults(), + "fields", + node.getFieldExprList().toString(), + "groupBy", + node.getGroupByExprList().toString() + ) + ) + ); + } + + @Override + public ExplainResponseNode visitValues(ValuesOperator node, Object context) { + return explain(node, context, explainNode -> explainNode.setDescription(ImmutableMap.of("values", node.getValues()))); + } + + @Override + public ExplainResponseNode visitLimit(LimitOperator node, Object context) { + return explain( + node, + context, + explanNode -> explanNode.setDescription(ImmutableMap.of("limit", node.getLimit(), "offset", node.getOffset())) + ); + } + + @Override + public ExplainResponseNode visitNested(NestedOperator node, Object context) { + return explain(node, context, explanNode -> explanNode.setDescription(ImmutableMap.of("nested", node.getFields()))); + } + + protected ExplainResponseNode explain(PhysicalPlan node, Object context, Consumer doExplain) { + ExplainResponseNode explainNode = new ExplainResponseNode(getOperatorName(node)); + + List children = new ArrayList<>(); + for (PhysicalPlan child : node.getChild()) { + children.add(child.accept(this, context)); + } + explainNode.setChildren(children); + + doExplain.accept(explainNode); + return explainNode; + } + + private String getOperatorName(PhysicalPlan node) { + return node.getClass().getSimpleName(); + } + + private Map convertPairListToMap(List> pairs) { + return pairs.stream().collect(Collectors.toMap(p -> p.getLeft().toString(), p -> p.getRight().toString())); + } + + private Map> describeSortList(List> sortList) { + return sortList.stream() + .collect( + Collectors.toMap( + p -> p.getRight().toString(), + p -> ImmutableMap.of( + "sortOrder", + p.getLeft().getSortOrder().toString(), + "nullOrder", + p.getLeft().getNullOrder().toString() + ) + ) + ); + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/QueryId.java b/core/src/main/java/org/opensearch/sql/executor/QueryId.java index 933cb5d82d..dc62879384 100644 --- a/core/src/main/java/org/opensearch/sql/executor/QueryId.java +++ b/core/src/main/java/org/opensearch/sql/executor/QueryId.java @@ -16,21 +16,21 @@ * Query id of {@link AbstractPlan}. */ public class QueryId { - /** - * Query id. - */ - @Getter - private final String queryId; + /** + * Query id. + */ + @Getter + private final String queryId; - /** - * Generate {@link QueryId}. - * @return {@link QueryId}. - */ - public static QueryId queryId() { - return new QueryId(RandomStringUtils.random(10, true, true)); - } + /** + * Generate {@link QueryId}. + * @return {@link QueryId}. + */ + public static QueryId queryId() { + return new QueryId(RandomStringUtils.random(10, true, true)); + } - private QueryId(String queryId) { - this.queryId = queryId; - } + private QueryId(String queryId) { + this.queryId = queryId; + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/QueryManager.java b/core/src/main/java/org/opensearch/sql/executor/QueryManager.java index 5b41d7ce2e..d1068de0d7 100644 --- a/core/src/main/java/org/opensearch/sql/executor/QueryManager.java +++ b/core/src/main/java/org/opensearch/sql/executor/QueryManager.java @@ -16,19 +16,19 @@ */ public interface QueryManager { - /** - * Submit {@link AbstractPlan}. - * @param queryPlan {@link AbstractPlan}. - * @return {@link QueryId}. - */ - QueryId submit(AbstractPlan queryPlan); + /** + * Submit {@link AbstractPlan}. + * @param queryPlan {@link AbstractPlan}. + * @return {@link QueryId}. + */ + QueryId submit(AbstractPlan queryPlan); - /** - * Cancel submitted {@link AbstractPlan} by {@link QueryId}. - * - * @return true indicate successful. - */ - default boolean cancel(QueryId queryId) { - throw new UnsupportedOperationException(); - } + /** + * Cancel submitted {@link AbstractPlan} by {@link QueryId}. + * + * @return true indicate successful. + */ + default boolean cancel(QueryId queryId) { + throw new UnsupportedOperationException(); + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/QueryService.java b/core/src/main/java/org/opensearch/sql/executor/QueryService.java index 94e7081920..7db9eea7f3 100644 --- a/core/src/main/java/org/opensearch/sql/executor/QueryService.java +++ b/core/src/main/java/org/opensearch/sql/executor/QueryService.java @@ -24,79 +24,74 @@ @RequiredArgsConstructor public class QueryService { - private final Analyzer analyzer; + private final Analyzer analyzer; - private final ExecutionEngine executionEngine; + private final ExecutionEngine executionEngine; - private final Planner planner; + private final Planner planner; - /** - * Execute the {@link UnresolvedPlan}, using {@link ResponseListener} to get response. - * Todo. deprecated this interface after finalize {@link PlanContext}. - * - * @param plan {@link UnresolvedPlan} - * @param listener {@link ResponseListener} - */ - public void execute(UnresolvedPlan plan, - ResponseListener listener) { - try { - executePlan(analyze(plan), PlanContext.emptyPlanContext(), listener); - } catch (Exception e) { - listener.onFailure(e); + /** + * Execute the {@link UnresolvedPlan}, using {@link ResponseListener} to get response. + * Todo. deprecated this interface after finalize {@link PlanContext}. + * + * @param plan {@link UnresolvedPlan} + * @param listener {@link ResponseListener} + */ + public void execute(UnresolvedPlan plan, ResponseListener listener) { + try { + executePlan(analyze(plan), PlanContext.emptyPlanContext(), listener); + } catch (Exception e) { + listener.onFailure(e); + } } - } - /** - * Execute the {@link UnresolvedPlan}, with {@link PlanContext} and using {@link ResponseListener} - * to get response. - * Todo. Pass split from PlanContext to ExecutionEngine in following PR. - * - * @param plan {@link LogicalPlan} - * @param planContext {@link PlanContext} - * @param listener {@link ResponseListener} - */ - public void executePlan(LogicalPlan plan, - PlanContext planContext, - ResponseListener listener) { - try { - planContext - .getSplit() - .ifPresentOrElse( - split -> executionEngine.execute(plan(plan), new ExecutionContext(split), listener), - () -> executionEngine.execute( - plan(plan), ExecutionContext.emptyExecutionContext(), listener)); - } catch (Exception e) { - listener.onFailure(e); + /** + * Execute the {@link UnresolvedPlan}, with {@link PlanContext} and using {@link ResponseListener} + * to get response. + * Todo. Pass split from PlanContext to ExecutionEngine in following PR. + * + * @param plan {@link LogicalPlan} + * @param planContext {@link PlanContext} + * @param listener {@link ResponseListener} + */ + public void executePlan(LogicalPlan plan, PlanContext planContext, ResponseListener listener) { + try { + planContext.getSplit() + .ifPresentOrElse( + split -> executionEngine.execute(plan(plan), new ExecutionContext(split), listener), + () -> executionEngine.execute(plan(plan), ExecutionContext.emptyExecutionContext(), listener) + ); + } catch (Exception e) { + listener.onFailure(e); + } } - } - /** - * Explain the query in {@link UnresolvedPlan} using {@link ResponseListener} to - * get and format explain response. - * - * @param plan {@link UnresolvedPlan} - * @param listener {@link ResponseListener} for explain response - */ - public void explain(UnresolvedPlan plan, - ResponseListener listener) { - try { - executionEngine.explain(plan(analyze(plan)), listener); - } catch (Exception e) { - listener.onFailure(e); + /** + * Explain the query in {@link UnresolvedPlan} using {@link ResponseListener} to + * get and format explain response. + * + * @param plan {@link UnresolvedPlan} + * @param listener {@link ResponseListener} for explain response + */ + public void explain(UnresolvedPlan plan, ResponseListener listener) { + try { + executionEngine.explain(plan(analyze(plan)), listener); + } catch (Exception e) { + listener.onFailure(e); + } } - } - /** - * Analyze {@link UnresolvedPlan}. - */ - public LogicalPlan analyze(UnresolvedPlan plan) { - return analyzer.analyze(plan, new AnalysisContext()); - } + /** + * Analyze {@link UnresolvedPlan}. + */ + public LogicalPlan analyze(UnresolvedPlan plan) { + return analyzer.analyze(plan, new AnalysisContext()); + } - /** - * Translate {@link LogicalPlan} to {@link PhysicalPlan}. - */ - public PhysicalPlan plan(LogicalPlan plan) { - return planner.plan(plan); - } + /** + * Translate {@link LogicalPlan} to {@link PhysicalPlan}. + */ + public PhysicalPlan plan(LogicalPlan plan) { + return planner.plan(plan); + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/AbstractPlan.java b/core/src/main/java/org/opensearch/sql/executor/execution/AbstractPlan.java index 1654293c04..9724688ab6 100644 --- a/core/src/main/java/org/opensearch/sql/executor/execution/AbstractPlan.java +++ b/core/src/main/java/org/opensearch/sql/executor/execution/AbstractPlan.java @@ -8,7 +8,6 @@ package org.opensearch.sql.executor.execution; - import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.sql.common.response.ResponseListener; @@ -21,21 +20,21 @@ @RequiredArgsConstructor public abstract class AbstractPlan { - /** - * Uniq query id. - */ - @Getter - private final QueryId queryId; + /** + * Uniq query id. + */ + @Getter + private final QueryId queryId; - /** - * Start query execution. - */ - public abstract void execute(); + /** + * Start query execution. + */ + public abstract void execute(); - /** - * Explain query execution. - * - * @param listener query explain response listener. - */ - public abstract void explain(ResponseListener listener); + /** + * Explain query execution. + * + * @param listener query explain response listener. + */ + public abstract void explain(ResponseListener listener); } diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/CommandPlan.java b/core/src/main/java/org/opensearch/sql/executor/execution/CommandPlan.java index 0ea5266084..ac20d75225 100644 --- a/core/src/main/java/org/opensearch/sql/executor/execution/CommandPlan.java +++ b/core/src/main/java/org/opensearch/sql/executor/execution/CommandPlan.java @@ -20,34 +20,38 @@ */ public class CommandPlan extends AbstractPlan { - /** - * The query plan ast. - */ - protected final UnresolvedPlan plan; - - /** - * Query service. - */ - protected final QueryService queryService; - - protected final ResponseListener listener; - - /** Constructor. */ - public CommandPlan(QueryId queryId, UnresolvedPlan plan, QueryService queryService, - ResponseListener listener) { - super(queryId); - this.plan = plan; - this.queryService = queryService; - this.listener = listener; - } - - @Override - public void execute() { - queryService.execute(plan, listener); - } - - @Override - public void explain(ResponseListener listener) { - throw new UnsupportedOperationException("CommandPlan does not support explain"); - } + /** + * The query plan ast. + */ + protected final UnresolvedPlan plan; + + /** + * Query service. + */ + protected final QueryService queryService; + + protected final ResponseListener listener; + + /** Constructor. */ + public CommandPlan( + QueryId queryId, + UnresolvedPlan plan, + QueryService queryService, + ResponseListener listener + ) { + super(queryId); + this.plan = plan; + this.queryService = queryService; + this.listener = listener; + } + + @Override + public void execute() { + queryService.execute(plan, listener); + } + + @Override + public void explain(ResponseListener listener) { + throw new UnsupportedOperationException("CommandPlan does not support explain"); + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/ExplainPlan.java b/core/src/main/java/org/opensearch/sql/executor/execution/ExplainPlan.java index 8c784f82ed..1f9f938d16 100644 --- a/core/src/main/java/org/opensearch/sql/executor/execution/ExplainPlan.java +++ b/core/src/main/java/org/opensearch/sql/executor/execution/ExplainPlan.java @@ -17,28 +17,26 @@ */ public class ExplainPlan extends AbstractPlan { - private final AbstractPlan plan; - - private final ResponseListener explainListener; - - /** - * Constructor. - */ - public ExplainPlan(QueryId queryId, - AbstractPlan plan, - ResponseListener explainListener) { - super(queryId); - this.plan = plan; - this.explainListener = explainListener; - } - - @Override - public void execute() { - plan.explain(explainListener); - } - - @Override - public void explain(ResponseListener listener) { - throw new UnsupportedOperationException("explain query can not been explained."); - } + private final AbstractPlan plan; + + private final ResponseListener explainListener; + + /** + * Constructor. + */ + public ExplainPlan(QueryId queryId, AbstractPlan plan, ResponseListener explainListener) { + super(queryId); + this.plan = plan; + this.explainListener = explainListener; + } + + @Override + public void execute() { + plan.explain(explainListener); + } + + @Override + public void explain(ResponseListener listener) { + throw new UnsupportedOperationException("explain query can not been explained."); + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlan.java b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlan.java index aeecf3e76f..966a4d43e1 100644 --- a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlan.java +++ b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlan.java @@ -22,63 +22,64 @@ */ public class QueryPlan extends AbstractPlan { - /** - * The query plan ast. - */ - protected final UnresolvedPlan plan; + /** + * The query plan ast. + */ + protected final UnresolvedPlan plan; - /** - * Query service. - */ - protected final QueryService queryService; + /** + * Query service. + */ + protected final QueryService queryService; - protected final ResponseListener listener; + protected final ResponseListener listener; - protected final Optional pageSize; + protected final Optional pageSize; - /** Constructor. */ - public QueryPlan( - QueryId queryId, - UnresolvedPlan plan, - QueryService queryService, - ResponseListener listener) { - super(queryId); - this.plan = plan; - this.queryService = queryService; - this.listener = listener; - this.pageSize = Optional.empty(); - } + /** Constructor. */ + public QueryPlan( + QueryId queryId, + UnresolvedPlan plan, + QueryService queryService, + ResponseListener listener + ) { + super(queryId); + this.plan = plan; + this.queryService = queryService; + this.listener = listener; + this.pageSize = Optional.empty(); + } - /** Constructor with page size. */ - public QueryPlan( - QueryId queryId, - UnresolvedPlan plan, - int pageSize, - QueryService queryService, - ResponseListener listener) { - super(queryId); - this.plan = plan; - this.queryService = queryService; - this.listener = listener; - this.pageSize = Optional.of(pageSize); - } + /** Constructor with page size. */ + public QueryPlan( + QueryId queryId, + UnresolvedPlan plan, + int pageSize, + QueryService queryService, + ResponseListener listener + ) { + super(queryId); + this.plan = plan; + this.queryService = queryService; + this.listener = listener; + this.pageSize = Optional.of(pageSize); + } - @Override - public void execute() { - if (pageSize.isPresent()) { - queryService.execute(new Paginate(pageSize.get(), plan), listener); - } else { - queryService.execute(plan, listener); + @Override + public void execute() { + if (pageSize.isPresent()) { + queryService.execute(new Paginate(pageSize.get(), plan), listener); + } else { + queryService.execute(plan, listener); + } } - } - @Override - public void explain(ResponseListener listener) { - if (pageSize.isPresent()) { - listener.onFailure(new NotImplementedException( - "`explain` feature for paginated requests is not implemented yet.")); - } else { - queryService.explain(plan, listener); + @Override + public void explain(ResponseListener listener) { + if (pageSize.isPresent()) { + listener.onFailure(new NotImplementedException("`explain` feature for paginated requests is not implemented yet.")); + } else { + queryService.explain(plan, listener); + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java index 3273eb3c18..e1606c033d 100644 --- a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java +++ b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java @@ -31,110 +31,98 @@ * QueryExecution Factory. */ @RequiredArgsConstructor -public class QueryPlanFactory - extends AbstractNodeVisitor< - AbstractPlan, - Pair< - Optional>, - Optional>>> { +public class QueryPlanFactory extends AbstractNodeVisitor< + AbstractPlan, + Pair>, Optional>>> { - /** - * Query Service. - */ - private final QueryService queryService; + /** + * Query Service. + */ + private final QueryService queryService; - /** - * NO_CONSUMER_RESPONSE_LISTENER should never be called. It is only used as constructor - * parameter of {@link QueryPlan}. - */ - @VisibleForTesting - protected static final ResponseListener - NO_CONSUMER_RESPONSE_LISTENER = - new ResponseListener<>() { - @Override - public void onResponse(ExecutionEngine.QueryResponse response) { - throw new IllegalStateException( - "[BUG] query response should not sent to unexpected channel"); - } + /** + * NO_CONSUMER_RESPONSE_LISTENER should never be called. It is only used as constructor + * parameter of {@link QueryPlan}. + */ + @VisibleForTesting + protected static final ResponseListener NO_CONSUMER_RESPONSE_LISTENER = new ResponseListener<>() { + @Override + public void onResponse(ExecutionEngine.QueryResponse response) { + throw new IllegalStateException("[BUG] query response should not sent to unexpected channel"); + } - @Override - public void onFailure(Exception e) { - throw new IllegalStateException( - "[BUG] exception response should not sent to unexpected channel"); - } - }; + @Override + public void onFailure(Exception e) { + throw new IllegalStateException("[BUG] exception response should not sent to unexpected channel"); + } + }; - /** - * Create QueryExecution from Statement. - */ - public AbstractPlan create( - Statement statement, - Optional> queryListener, - Optional> explainListener) { - return statement.accept(this, Pair.of(queryListener, explainListener)); - } + /** + * Create QueryExecution from Statement. + */ + public AbstractPlan create( + Statement statement, + Optional> queryListener, + Optional> explainListener + ) { + return statement.accept(this, Pair.of(queryListener, explainListener)); + } - /** - * Creates a QueryPlan from a cursor. - */ - public AbstractPlan create(String cursor, boolean isExplain, - ResponseListener queryResponseListener, - ResponseListener explainListener) { - QueryId queryId = QueryId.queryId(); - var plan = new QueryPlan(queryId, new FetchCursor(cursor), queryService, queryResponseListener); - return isExplain ? new ExplainPlan(queryId, plan, explainListener) : plan; - } + /** + * Creates a QueryPlan from a cursor. + */ + public AbstractPlan create( + String cursor, + boolean isExplain, + ResponseListener queryResponseListener, + ResponseListener explainListener + ) { + QueryId queryId = QueryId.queryId(); + var plan = new QueryPlan(queryId, new FetchCursor(cursor), queryService, queryResponseListener); + return isExplain ? new ExplainPlan(queryId, plan, explainListener) : plan; + } - boolean canConvertToCursor(UnresolvedPlan plan) { - return plan.accept(new CanPaginateVisitor(), null); - } + boolean canConvertToCursor(UnresolvedPlan plan) { + return plan.accept(new CanPaginateVisitor(), null); + } - /** - * Creates a {@link CloseCursor} command on a cursor. - */ - public AbstractPlan createCloseCursor(String cursor, - ResponseListener queryResponseListener) { - return new CommandPlan(QueryId.queryId(), new CloseCursor().attach(new FetchCursor(cursor)), - queryService, queryResponseListener); - } + /** + * Creates a {@link CloseCursor} command on a cursor. + */ + public AbstractPlan createCloseCursor(String cursor, ResponseListener queryResponseListener) { + return new CommandPlan(QueryId.queryId(), new CloseCursor().attach(new FetchCursor(cursor)), queryService, queryResponseListener); + } - @Override - public AbstractPlan visitQuery( - Query node, - Pair>, - Optional>> - context) { - Preconditions.checkArgument( - context.getLeft().isPresent(), "[BUG] query listener must be not null"); + @Override + public AbstractPlan visitQuery( + Query node, + Pair>, Optional>> context + ) { + Preconditions.checkArgument(context.getLeft().isPresent(), "[BUG] query listener must be not null"); - if (node.getFetchSize() > 0) { - if (canConvertToCursor(node.getPlan())) { - return new QueryPlan(QueryId.queryId(), node.getPlan(), node.getFetchSize(), - queryService, - context.getLeft().get()); - } else { - // This should be picked up by the legacy engine. - throw new UnsupportedCursorRequestException(); - } - } else { - return new QueryPlan(QueryId.queryId(), node.getPlan(), queryService, - context.getLeft().get()); + if (node.getFetchSize() > 0) { + if (canConvertToCursor(node.getPlan())) { + return new QueryPlan(QueryId.queryId(), node.getPlan(), node.getFetchSize(), queryService, context.getLeft().get()); + } else { + // This should be picked up by the legacy engine. + throw new UnsupportedCursorRequestException(); + } + } else { + return new QueryPlan(QueryId.queryId(), node.getPlan(), queryService, context.getLeft().get()); + } } - } - @Override - public AbstractPlan visitExplain( - Explain node, - Pair>, - Optional>> - context) { - Preconditions.checkArgument( - context.getRight().isPresent(), "[BUG] explain listener must be not null"); + @Override + public AbstractPlan visitExplain( + Explain node, + Pair>, Optional>> context + ) { + Preconditions.checkArgument(context.getRight().isPresent(), "[BUG] explain listener must be not null"); - return new ExplainPlan( - QueryId.queryId(), - create(node.getStatement(), - Optional.of(NO_CONSUMER_RESPONSE_LISTENER), Optional.empty()), - context.getRight().get()); - } + return new ExplainPlan( + QueryId.queryId(), + create(node.getStatement(), Optional.of(NO_CONSUMER_RESPONSE_LISTENER), Optional.empty()), + context.getRight().get() + ); + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/StreamingQueryPlan.java b/core/src/main/java/org/opensearch/sql/executor/execution/StreamingQueryPlan.java index 9bb37b064c..34adfdbca9 100644 --- a/core/src/main/java/org/opensearch/sql/executor/execution/StreamingQueryPlan.java +++ b/core/src/main/java/org/opensearch/sql/executor/execution/StreamingQueryPlan.java @@ -29,109 +29,106 @@ */ public class StreamingQueryPlan extends QueryPlan { - private static final Logger log = LogManager.getLogger(StreamingQueryPlan.class); - - private final ExecutionStrategy executionStrategy; - - private MicroBatchStreamingExecution streamingExecution; - - /** - * constructor. - */ - public StreamingQueryPlan(QueryId queryId, - UnresolvedPlan plan, - QueryService queryService, - ResponseListener listener, - ExecutionStrategy executionStrategy) { - super(queryId, plan, queryService, listener); - - this.executionStrategy = executionStrategy; - } - - @Override - public void execute() { - try { - LogicalPlan logicalPlan = queryService.analyze(plan); - StreamingSource streamingSource = buildStreamingSource(logicalPlan); - streamingExecution = - new MicroBatchStreamingExecution( - streamingSource, - logicalPlan, - queryService, - new DefaultMetadataLog<>(), - new DefaultMetadataLog<>()); - executionStrategy.execute(streamingExecution::execute); - } catch (UnsupportedOperationException | IllegalArgumentException e) { - listener.onFailure(e); - } catch (InterruptedException e) { - log.error(e); - // todo, update async task status. - } - } + private static final Logger log = LogManager.getLogger(StreamingQueryPlan.class); - interface ExecutionStrategy { - /** - * execute task. - */ - void execute(Runnable task) throws InterruptedException; - } + private final ExecutionStrategy executionStrategy; - /** - * execute task with fixed interval. - * if task run time < interval, trigger next task on next interval. - * if task run time >= interval, trigger next task immediately. - */ - @RequiredArgsConstructor - public static class IntervalTriggerExecution implements ExecutionStrategy { + private MicroBatchStreamingExecution streamingExecution; - private final long intervalInSeconds; + /** + * constructor. + */ + public StreamingQueryPlan( + QueryId queryId, + UnresolvedPlan plan, + QueryService queryService, + ResponseListener listener, + ExecutionStrategy executionStrategy + ) { + super(queryId, plan, queryService, listener); + + this.executionStrategy = executionStrategy; + } @Override - public void execute(Runnable runnable) throws InterruptedException { - while (!Thread.currentThread().isInterrupted()) { + public void execute() { try { - Instant start = Instant.now(); - runnable.run(); - Instant end = Instant.now(); - long took = Duration.between(start, end).toSeconds(); - TimeUnit.SECONDS.sleep(intervalInSeconds > took ? intervalInSeconds - took : 0); + LogicalPlan logicalPlan = queryService.analyze(plan); + StreamingSource streamingSource = buildStreamingSource(logicalPlan); + streamingExecution = new MicroBatchStreamingExecution( + streamingSource, + logicalPlan, + queryService, + new DefaultMetadataLog<>(), + new DefaultMetadataLog<>() + ); + executionStrategy.execute(streamingExecution::execute); + } catch (UnsupportedOperationException | IllegalArgumentException e) { + listener.onFailure(e); } catch (InterruptedException e) { - Thread.currentThread().interrupt(); + log.error(e); + // todo, update async task status. } - } } - } - private StreamingSource buildStreamingSource(LogicalPlan logicalPlan) { - return logicalPlan.accept(new StreamingSourceBuilder(), null); - } + interface ExecutionStrategy { + /** + * execute task. + */ + void execute(Runnable task) throws InterruptedException; + } - static class StreamingSourceBuilder extends LogicalPlanNodeVisitor { - @Override - public StreamingSource visitNode(LogicalPlan plan, Void context) { - List children = plan.getChild(); - if (children.isEmpty()) { - String errorMsg = - String.format( - "Could find relation plan, %s does not have child node.", - plan.getClass().getSimpleName()); - log.error(errorMsg); - throw new IllegalArgumentException(errorMsg); - } - return children.get(0).accept(this, context); + /** + * execute task with fixed interval. + * if task run time < interval, trigger next task on next interval. + * if task run time >= interval, trigger next task immediately. + */ + @RequiredArgsConstructor + public static class IntervalTriggerExecution implements ExecutionStrategy { + + private final long intervalInSeconds; + + @Override + public void execute(Runnable runnable) throws InterruptedException { + while (!Thread.currentThread().isInterrupted()) { + try { + Instant start = Instant.now(); + runnable.run(); + Instant end = Instant.now(); + long took = Duration.between(start, end).toSeconds(); + TimeUnit.SECONDS.sleep(intervalInSeconds > took ? intervalInSeconds - took : 0); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + } } - @Override - public StreamingSource visitRelation(LogicalRelation plan, Void context) { - try { - return plan.getTable().asStreamingSource(); - } catch (UnsupportedOperationException e) { - String errorMsg = - String.format( - "table %s could not been used as streaming source.", plan.getRelationName()); - log.error(errorMsg); - throw new UnsupportedOperationException(errorMsg); - } + private StreamingSource buildStreamingSource(LogicalPlan logicalPlan) { + return logicalPlan.accept(new StreamingSourceBuilder(), null); + } + + static class StreamingSourceBuilder extends LogicalPlanNodeVisitor { + @Override + public StreamingSource visitNode(LogicalPlan plan, Void context) { + List children = plan.getChild(); + if (children.isEmpty()) { + String errorMsg = String.format("Could find relation plan, %s does not have child node.", plan.getClass().getSimpleName()); + log.error(errorMsg); + throw new IllegalArgumentException(errorMsg); + } + return children.get(0).accept(this, context); + } + + @Override + public StreamingSource visitRelation(LogicalRelation plan, Void context) { + try { + return plan.getTable().asStreamingSource(); + } catch (UnsupportedOperationException e) { + String errorMsg = String.format("table %s could not been used as streaming source.", plan.getRelationName()); + log.error(errorMsg); + throw new UnsupportedOperationException(errorMsg); + } + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java b/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java index e304c132bd..4b76aca3b9 100644 --- a/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java +++ b/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java @@ -57,204 +57,200 @@ */ public class CanPaginateVisitor extends AbstractNodeVisitor { - @Override - public Boolean visitRelation(Relation node, Object context) { - if (!node.getChild().isEmpty()) { - // Relation instance should never have a child, but check just in case. - return Boolean.FALSE; - } - - return Boolean.TRUE; - } - - protected Boolean canPaginate(Node node, Object context) { - var childList = node.getChild(); - if (childList != null) { - return childList.stream().allMatch(n -> n.accept(this, context)); - } - return Boolean.TRUE; - } - - // Only column references in ORDER BY clause are supported in pagination, - // because expressions can't be pushed down due to #1471. - // https://github.com/opensearch-project/sql/issues/1471 - @Override - public Boolean visitSort(Sort node, Object context) { - return node.getSortList().stream().allMatch(f -> f.getField() instanceof QualifiedName - && visitField(f, context)) - && canPaginate(node, context); - } - - // For queries with WHERE clause: - @Override - public Boolean visitFilter(Filter node, Object context) { - return canPaginate(node, context) && node.getCondition().accept(this, context); - } - - // Queries with GROUP BY clause are not supported - @Override - public Boolean visitAggregation(Aggregation node, Object context) { - return Boolean.FALSE; - } - - // For queries without FROM clause: - @Override - public Boolean visitValues(Values node, Object context) { - return Boolean.TRUE; - } - - // Queries with LIMIT/OFFSET clauses are unsupported - @Override - public Boolean visitLimit(Limit node, Object context) { - return Boolean.FALSE; - } - - @Override - public Boolean visitLiteral(Literal node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitField(Field node, Object context) { - return canPaginate(node, context) && node.getFieldArgs().stream() - .allMatch(n -> n.accept(this, context)); - } - - @Override - public Boolean visitAlias(Alias node, Object context) { - return canPaginate(node, context) && node.getDelegated().accept(this, context); - } - - @Override - public Boolean visitAllFields(AllFields node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitQualifiedName(QualifiedName node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitEqualTo(EqualTo node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitRelevanceFieldList(RelevanceFieldList node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitInterval(Interval node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitCompare(Compare node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitNot(Not node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitOr(Or node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitAnd(And node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitArgument(Argument node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitXor(Xor node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitFunction(Function node, Object context) { - // https://github.com/opensearch-project/sql/issues/1718 - if (node.getFuncName() - .equalsIgnoreCase(BuiltinFunctionName.NESTED.getName().getFunctionName())) { - return Boolean.FALSE; - } - return canPaginate(node, context); - } - - @Override - public Boolean visitIn(In node, Object context) { - return canPaginate(node, context) && node.getValueList().stream() - .allMatch(n -> n.accept(this, context)); - } - - @Override - public Boolean visitBetween(Between node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitCase(Case node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitWhen(When node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitCast(Cast node, Object context) { - return canPaginate(node, context) && node.getConvertedType().accept(this, context); - } - - @Override - public Boolean visitHighlightFunction(HighlightFunction node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitUnresolvedArgument(UnresolvedArgument node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitUnresolvedAttribute(UnresolvedAttribute node, Object context) { - return canPaginate(node, context); - } - - @Override - public Boolean visitChildren(Node node, Object context) { - // for all not listed (= unchecked) - false - return Boolean.FALSE; - } - - @Override - public Boolean visitWindowFunction(WindowFunction node, Object context) { - // don't support in-memory aggregation - // SELECT max(age) OVER (PARTITION BY city) ... - return Boolean.FALSE; - } - - @Override - public Boolean visitProject(Project node, Object context) { - if (!node.getProjectList().stream().allMatch(n -> n.accept(this, context))) { - return Boolean.FALSE; - } - - var children = node.getChild(); - if (children.size() != 1) { - return Boolean.FALSE; - } - - return children.get(0).accept(this, context); - } + @Override + public Boolean visitRelation(Relation node, Object context) { + if (!node.getChild().isEmpty()) { + // Relation instance should never have a child, but check just in case. + return Boolean.FALSE; + } + + return Boolean.TRUE; + } + + protected Boolean canPaginate(Node node, Object context) { + var childList = node.getChild(); + if (childList != null) { + return childList.stream().allMatch(n -> n.accept(this, context)); + } + return Boolean.TRUE; + } + + // Only column references in ORDER BY clause are supported in pagination, + // because expressions can't be pushed down due to #1471. + // https://github.com/opensearch-project/sql/issues/1471 + @Override + public Boolean visitSort(Sort node, Object context) { + return node.getSortList().stream().allMatch(f -> f.getField() instanceof QualifiedName && visitField(f, context)) + && canPaginate(node, context); + } + + // For queries with WHERE clause: + @Override + public Boolean visitFilter(Filter node, Object context) { + return canPaginate(node, context) && node.getCondition().accept(this, context); + } + + // Queries with GROUP BY clause are not supported + @Override + public Boolean visitAggregation(Aggregation node, Object context) { + return Boolean.FALSE; + } + + // For queries without FROM clause: + @Override + public Boolean visitValues(Values node, Object context) { + return Boolean.TRUE; + } + + // Queries with LIMIT/OFFSET clauses are unsupported + @Override + public Boolean visitLimit(Limit node, Object context) { + return Boolean.FALSE; + } + + @Override + public Boolean visitLiteral(Literal node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitField(Field node, Object context) { + return canPaginate(node, context) && node.getFieldArgs().stream().allMatch(n -> n.accept(this, context)); + } + + @Override + public Boolean visitAlias(Alias node, Object context) { + return canPaginate(node, context) && node.getDelegated().accept(this, context); + } + + @Override + public Boolean visitAllFields(AllFields node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitQualifiedName(QualifiedName node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitEqualTo(EqualTo node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitRelevanceFieldList(RelevanceFieldList node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitInterval(Interval node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitCompare(Compare node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitNot(Not node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitOr(Or node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitAnd(And node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitArgument(Argument node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitXor(Xor node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitFunction(Function node, Object context) { + // https://github.com/opensearch-project/sql/issues/1718 + if (node.getFuncName().equalsIgnoreCase(BuiltinFunctionName.NESTED.getName().getFunctionName())) { + return Boolean.FALSE; + } + return canPaginate(node, context); + } + + @Override + public Boolean visitIn(In node, Object context) { + return canPaginate(node, context) && node.getValueList().stream().allMatch(n -> n.accept(this, context)); + } + + @Override + public Boolean visitBetween(Between node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitCase(Case node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitWhen(When node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitCast(Cast node, Object context) { + return canPaginate(node, context) && node.getConvertedType().accept(this, context); + } + + @Override + public Boolean visitHighlightFunction(HighlightFunction node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitUnresolvedArgument(UnresolvedArgument node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitUnresolvedAttribute(UnresolvedAttribute node, Object context) { + return canPaginate(node, context); + } + + @Override + public Boolean visitChildren(Node node, Object context) { + // for all not listed (= unchecked) - false + return Boolean.FALSE; + } + + @Override + public Boolean visitWindowFunction(WindowFunction node, Object context) { + // don't support in-memory aggregation + // SELECT max(age) OVER (PARTITION BY city) ... + return Boolean.FALSE; + } + + @Override + public Boolean visitProject(Project node, Object context) { + if (!node.getProjectList().stream().allMatch(n -> n.accept(this, context))) { + return Boolean.FALSE; + } + + var children = node.getChild(); + if (children.size() != 1) { + return Boolean.FALSE; + } + + return children.get(0).accept(this, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/pagination/Cursor.java b/core/src/main/java/org/opensearch/sql/executor/pagination/Cursor.java index bb320f5c67..dc73bbf378 100644 --- a/core/src/main/java/org/opensearch/sql/executor/pagination/Cursor.java +++ b/core/src/main/java/org/opensearch/sql/executor/pagination/Cursor.java @@ -12,12 +12,12 @@ @EqualsAndHashCode @RequiredArgsConstructor public class Cursor { - public static final Cursor None = new Cursor(null); + public static final Cursor None = new Cursor(null); - @Getter - private final String data; + @Getter + private final String data; - public String toString() { - return data; - } + public String toString() { + return data; + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java b/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java index 07cf174d73..dbfe44df69 100644 --- a/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java +++ b/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java @@ -29,99 +29,96 @@ */ @RequiredArgsConstructor public class PlanSerializer { - public static final String CURSOR_PREFIX = "n:"; + public static final String CURSOR_PREFIX = "n:"; - private final StorageEngine engine; + private final StorageEngine engine; - - /** - * Converts a physical plan tree to a cursor. - */ - public Cursor convertToCursor(PhysicalPlan plan) { - try { - return new Cursor(CURSOR_PREFIX - + serialize(((SerializablePlan) plan).getPlanForSerialization())); - // ClassCastException thrown when a plan in the tree doesn't implement SerializablePlan - } catch (NotSerializableException | ClassCastException | NoCursorException e) { - return Cursor.None; + /** + * Converts a physical plan tree to a cursor. + */ + public Cursor convertToCursor(PhysicalPlan plan) { + try { + return new Cursor(CURSOR_PREFIX + serialize(((SerializablePlan) plan).getPlanForSerialization())); + // ClassCastException thrown when a plan in the tree doesn't implement SerializablePlan + } catch (NotSerializableException | ClassCastException | NoCursorException e) { + return Cursor.None; + } } - } - /** - * Serializes and compresses the object. - * @param object The object. - * @return Encoded binary data. - */ - protected String serialize(Serializable object) throws NotSerializableException { - try { - ByteArrayOutputStream output = new ByteArrayOutputStream(); - ObjectOutputStream objectOutput = new ObjectOutputStream(output); - objectOutput.writeObject(object); - objectOutput.flush(); + /** + * Serializes and compresses the object. + * @param object The object. + * @return Encoded binary data. + */ + protected String serialize(Serializable object) throws NotSerializableException { + try { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject(object); + objectOutput.flush(); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - // GZIP provides 35-45%, lzma from apache commons-compress has few % better compression - GZIPOutputStream gzip = new GZIPOutputStream(out) { { - this.def.setLevel(Deflater.BEST_COMPRESSION); - } }; - gzip.write(output.toByteArray()); - gzip.close(); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + // GZIP provides 35-45%, lzma from apache commons-compress has few % better compression + GZIPOutputStream gzip = new GZIPOutputStream(out) { + { + this.def.setLevel(Deflater.BEST_COMPRESSION); + } + }; + gzip.write(output.toByteArray()); + gzip.close(); - return HashCode.fromBytes(out.toByteArray()).toString(); - } catch (NotSerializableException e) { - throw e; - } catch (IOException e) { - throw new IllegalStateException("Failed to serialize: " + object, e); + return HashCode.fromBytes(out.toByteArray()).toString(); + } catch (NotSerializableException e) { + throw e; + } catch (IOException e) { + throw new IllegalStateException("Failed to serialize: " + object, e); + } } - } - /** - * Decompresses and deserializes the binary data. - * @param code Encoded binary data. - * @return An object. - */ - protected Serializable deserialize(String code) { - try { - GZIPInputStream gzip = new GZIPInputStream( - new ByteArrayInputStream(HashCode.fromString(code).asBytes())); - ObjectInputStream objectInput = new CursorDeserializationStream( - new ByteArrayInputStream(gzip.readAllBytes())); - return (Serializable) objectInput.readObject(); - } catch (Exception e) { - throw new IllegalStateException("Failed to deserialize object", e); + /** + * Decompresses and deserializes the binary data. + * @param code Encoded binary data. + * @return An object. + */ + protected Serializable deserialize(String code) { + try { + GZIPInputStream gzip = new GZIPInputStream(new ByteArrayInputStream(HashCode.fromString(code).asBytes())); + ObjectInputStream objectInput = new CursorDeserializationStream(new ByteArrayInputStream(gzip.readAllBytes())); + return (Serializable) objectInput.readObject(); + } catch (Exception e) { + throw new IllegalStateException("Failed to deserialize object", e); + } } - } - /** - * Converts a cursor to a physical plan tree. - */ - public PhysicalPlan convertToPlan(String cursor) { - if (!cursor.startsWith(CURSOR_PREFIX)) { - throw new UnsupportedOperationException("Unsupported cursor"); - } - try { - return (PhysicalPlan) deserialize(cursor.substring(CURSOR_PREFIX.length())); - } catch (Exception e) { - throw new UnsupportedOperationException("Unsupported cursor", e); + /** + * Converts a cursor to a physical plan tree. + */ + public PhysicalPlan convertToPlan(String cursor) { + if (!cursor.startsWith(CURSOR_PREFIX)) { + throw new UnsupportedOperationException("Unsupported cursor"); + } + try { + return (PhysicalPlan) deserialize(cursor.substring(CURSOR_PREFIX.length())); + } catch (Exception e) { + throw new UnsupportedOperationException("Unsupported cursor", e); + } } - } - /** - * This function is used in testing only, to get access to {@link CursorDeserializationStream}. - */ - public CursorDeserializationStream getCursorDeserializationStream(InputStream in) - throws IOException { - return new CursorDeserializationStream(in); - } - - public class CursorDeserializationStream extends ObjectInputStream { - public CursorDeserializationStream(InputStream in) throws IOException { - super(in); + /** + * This function is used in testing only, to get access to {@link CursorDeserializationStream}. + */ + public CursorDeserializationStream getCursorDeserializationStream(InputStream in) throws IOException { + return new CursorDeserializationStream(in); } - @Override - public Object resolveObject(Object obj) throws IOException { - return obj.equals("engine") ? engine : obj; + public class CursorDeserializationStream extends ObjectInputStream { + public CursorDeserializationStream(InputStream in) throws IOException { + super(in); + } + + @Override + public Object resolveObject(Object obj) throws IOException { + return obj.equals("engine") ? engine : obj; + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/executor/streaming/Batch.java b/core/src/main/java/org/opensearch/sql/executor/streaming/Batch.java index cd7d7dae5a..05f375aff7 100644 --- a/core/src/main/java/org/opensearch/sql/executor/streaming/Batch.java +++ b/core/src/main/java/org/opensearch/sql/executor/streaming/Batch.java @@ -13,5 +13,5 @@ */ @Data public class Batch { - private final Split split; + private final Split split; } diff --git a/core/src/main/java/org/opensearch/sql/executor/streaming/DefaultMetadataLog.java b/core/src/main/java/org/opensearch/sql/executor/streaming/DefaultMetadataLog.java index e439d93f6c..16f66474b2 100644 --- a/core/src/main/java/org/opensearch/sql/executor/streaming/DefaultMetadataLog.java +++ b/core/src/main/java/org/opensearch/sql/executor/streaming/DefaultMetadataLog.java @@ -24,53 +24,53 @@ */ public class DefaultMetadataLog implements MetadataLog { - private static final long MIN_ACCEPTABLE_ID = 0L; + private static final long MIN_ACCEPTABLE_ID = 0L; - private SortedMap metadataMap = new TreeMap<>(); + private SortedMap metadataMap = new TreeMap<>(); - @Override - public boolean add(Long batchId, T metadata) { - Preconditions.checkArgument(batchId >= MIN_ACCEPTABLE_ID, "batch id must large or equal 0"); + @Override + public boolean add(Long batchId, T metadata) { + Preconditions.checkArgument(batchId >= MIN_ACCEPTABLE_ID, "batch id must large or equal 0"); - if (metadataMap.containsKey(batchId)) { - return false; + if (metadataMap.containsKey(batchId)) { + return false; + } + metadataMap.put(batchId, metadata); + return true; } - metadataMap.put(batchId, metadata); - return true; - } - @Override - public Optional get(Long batchId) { - if (!metadataMap.containsKey(batchId)) { - return Optional.empty(); - } else { - return Optional.of(metadataMap.get(batchId)); + @Override + public Optional get(Long batchId) { + if (!metadataMap.containsKey(batchId)) { + return Optional.empty(); + } else { + return Optional.of(metadataMap.get(batchId)); + } } - } - @Override - public List get(Optional startBatchId, Optional endBatchId) { - if (startBatchId.isEmpty() && endBatchId.isEmpty()) { - return new ArrayList<>(metadataMap.values()); - } else { - Long s = startBatchId.orElse(MIN_ACCEPTABLE_ID); - Long e = endBatchId.map(i -> i + 1).orElse(Long.MAX_VALUE); - return new ArrayList<>(metadataMap.subMap(s, e).values()); + @Override + public List get(Optional startBatchId, Optional endBatchId) { + if (startBatchId.isEmpty() && endBatchId.isEmpty()) { + return new ArrayList<>(metadataMap.values()); + } else { + Long s = startBatchId.orElse(MIN_ACCEPTABLE_ID); + Long e = endBatchId.map(i -> i + 1).orElse(Long.MAX_VALUE); + return new ArrayList<>(metadataMap.subMap(s, e).values()); + } } - } - @Override - public Optional> getLatest() { - if (metadataMap.isEmpty()) { - return Optional.empty(); - } else { - Long latestId = metadataMap.lastKey(); - return Optional.of(Pair.of(latestId, metadataMap.get(latestId))); + @Override + public Optional> getLatest() { + if (metadataMap.isEmpty()) { + return Optional.empty(); + } else { + Long latestId = metadataMap.lastKey(); + return Optional.of(Pair.of(latestId, metadataMap.get(latestId))); + } } - } - @Override - public void purge(Long batchId) { - metadataMap.headMap(batchId).clear(); - } + @Override + public void purge(Long batchId) { + metadataMap.headMap(batchId).clear(); + } } diff --git a/core/src/main/java/org/opensearch/sql/executor/streaming/MetadataLog.java b/core/src/main/java/org/opensearch/sql/executor/streaming/MetadataLog.java index d6bb9bacd6..f1bc4091d0 100644 --- a/core/src/main/java/org/opensearch/sql/executor/streaming/MetadataLog.java +++ b/core/src/main/java/org/opensearch/sql/executor/streaming/MetadataLog.java @@ -19,43 +19,43 @@ */ public interface MetadataLog { - /** - * add metadata to WAL. - * - * @param id metadata index in WAL. - * @param metadata metadata. - * @return true if add success, otherwise return false. - */ - boolean add(Long id, T metadata); - - /** - * get metadata from WAL. - * - * @param id metadata index in WAL. - * @return metadata. - */ - Optional get(Long id); - - /** - * Return metadata for id between [startId, endId]. - * - * @param startId If startId is empty, return all metadata before endId (inclusive). - * @param endId If end is empty, return all batches after endId (inclusive). - * @return a list of metadata sorted by id (nature order). - */ - List get(Optional startId, Optional endId); - - /** - * Get latest batchId and metadata. - * - * @return pair of id and metadata if not empty. - */ - Optional> getLatest(); - - /** - * Remove all the metadata less then id (exclusive). - * - * @param id smallest batchId should keep. - */ - void purge(Long id); + /** + * add metadata to WAL. + * + * @param id metadata index in WAL. + * @param metadata metadata. + * @return true if add success, otherwise return false. + */ + boolean add(Long id, T metadata); + + /** + * get metadata from WAL. + * + * @param id metadata index in WAL. + * @return metadata. + */ + Optional get(Long id); + + /** + * Return metadata for id between [startId, endId]. + * + * @param startId If startId is empty, return all metadata before endId (inclusive). + * @param endId If end is empty, return all batches after endId (inclusive). + * @return a list of metadata sorted by id (nature order). + */ + List get(Optional startId, Optional endId); + + /** + * Get latest batchId and metadata. + * + * @return pair of id and metadata if not empty. + */ + Optional> getLatest(); + + /** + * Remove all the metadata less then id (exclusive). + * + * @param id smallest batchId should keep. + */ + void purge(Long id); } diff --git a/core/src/main/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecution.java b/core/src/main/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecution.java index c31ed18c57..dd00f3dff7 100644 --- a/core/src/main/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecution.java +++ b/core/src/main/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecution.java @@ -25,106 +25,103 @@ */ public class MicroBatchStreamingExecution { - private static final Logger log = LogManager.getLogger(MicroBatchStreamingExecution.class); - - static final long INITIAL_LATEST_BATCH_ID = -1L; - - private final StreamingSource source; - - private final LogicalPlan batchPlan; - - private final QueryService queryService; - - /** - * A write-ahead-log that records the offsets that are present in each batch. In order to ensure - * that a given batch will always consist of the same data, we write to this log before any - * processing is done. Thus, the Nth record in this log indicated data that is currently being - * processed and the N-1th entry indicates which offsets have been durably committed to the sink. - */ - private final MetadataLog offsetLog; - - /** keep track the latest commit batchId. */ - private final MetadataLog committedLog; - - /** - * Constructor. - */ - public MicroBatchStreamingExecution( - StreamingSource source, - LogicalPlan batchPlan, - QueryService queryService, - MetadataLog offsetLog, - MetadataLog committedLog) { - this.source = source; - this.batchPlan = batchPlan; - this.queryService = queryService; - // todo. add offsetLog and committedLog offset recovery. - this.offsetLog = offsetLog; - this.committedLog = committedLog; - } - - /** - * Pull the {@link Batch} from {@link StreamingSource} and execute the {@link Batch}. - */ - public void execute() { - Long latestBatchId = offsetLog.getLatest().map(Pair::getKey).orElse(INITIAL_LATEST_BATCH_ID); - Long latestCommittedBatchId = - committedLog.getLatest().map(Pair::getKey).orElse(INITIAL_LATEST_BATCH_ID); - Optional committedOffset = offsetLog.get(latestCommittedBatchId); - AtomicLong currentBatchId = new AtomicLong(INITIAL_LATEST_BATCH_ID); - - if (latestBatchId.equals(latestCommittedBatchId)) { - // there are no unhandled Offset. - currentBatchId.set(latestCommittedBatchId + 1L); - } else { - Preconditions.checkArgument( - latestBatchId.equals(latestCommittedBatchId + 1L), - "[BUG] Expected latestBatchId - latestCommittedBatchId = 0 or 1, " - + "but latestBatchId=%d, latestCommittedBatchId=%d", - latestBatchId, - latestCommittedBatchId); - - // latestBatchId is not committed yet. - currentBatchId.set(latestBatchId); + private static final Logger log = LogManager.getLogger(MicroBatchStreamingExecution.class); + + static final long INITIAL_LATEST_BATCH_ID = -1L; + + private final StreamingSource source; + + private final LogicalPlan batchPlan; + + private final QueryService queryService; + + /** + * A write-ahead-log that records the offsets that are present in each batch. In order to ensure + * that a given batch will always consist of the same data, we write to this log before any + * processing is done. Thus, the Nth record in this log indicated data that is currently being + * processed and the N-1th entry indicates which offsets have been durably committed to the sink. + */ + private final MetadataLog offsetLog; + + /** keep track the latest commit batchId. */ + private final MetadataLog committedLog; + + /** + * Constructor. + */ + public MicroBatchStreamingExecution( + StreamingSource source, + LogicalPlan batchPlan, + QueryService queryService, + MetadataLog offsetLog, + MetadataLog committedLog + ) { + this.source = source; + this.batchPlan = batchPlan; + this.queryService = queryService; + // todo. add offsetLog and committedLog offset recovery. + this.offsetLog = offsetLog; + this.committedLog = committedLog; } - Optional availableOffsets = source.getLatestOffset(); - if (hasNewData(availableOffsets, committedOffset)) { - Batch batch = source.getBatch(committedOffset, availableOffsets.get()); - offsetLog.add(currentBatchId.get(), availableOffsets.get()); - queryService.executePlan( - batchPlan, - new PlanContext(batch.getSplit()), - new ResponseListener<>() { - @Override - public void onResponse(ExecutionEngine.QueryResponse response) { - long finalBatchId = currentBatchId.get(); - Offset finalAvailableOffsets = availableOffsets.get(); - committedLog.add(finalBatchId, finalAvailableOffsets); - } - - @Override - public void onFailure(Exception e) { - log.error("streaming processing failed. source = {} {}", source, e); - } - }); + /** + * Pull the {@link Batch} from {@link StreamingSource} and execute the {@link Batch}. + */ + public void execute() { + Long latestBatchId = offsetLog.getLatest().map(Pair::getKey).orElse(INITIAL_LATEST_BATCH_ID); + Long latestCommittedBatchId = committedLog.getLatest().map(Pair::getKey).orElse(INITIAL_LATEST_BATCH_ID); + Optional committedOffset = offsetLog.get(latestCommittedBatchId); + AtomicLong currentBatchId = new AtomicLong(INITIAL_LATEST_BATCH_ID); + + if (latestBatchId.equals(latestCommittedBatchId)) { + // there are no unhandled Offset. + currentBatchId.set(latestCommittedBatchId + 1L); + } else { + Preconditions.checkArgument( + latestBatchId.equals(latestCommittedBatchId + 1L), + "[BUG] Expected latestBatchId - latestCommittedBatchId = 0 or 1, " + "but latestBatchId=%d, latestCommittedBatchId=%d", + latestBatchId, + latestCommittedBatchId + ); + + // latestBatchId is not committed yet. + currentBatchId.set(latestBatchId); + } + + Optional availableOffsets = source.getLatestOffset(); + if (hasNewData(availableOffsets, committedOffset)) { + Batch batch = source.getBatch(committedOffset, availableOffsets.get()); + offsetLog.add(currentBatchId.get(), availableOffsets.get()); + queryService.executePlan(batchPlan, new PlanContext(batch.getSplit()), new ResponseListener<>() { + @Override + public void onResponse(ExecutionEngine.QueryResponse response) { + long finalBatchId = currentBatchId.get(); + Offset finalAvailableOffsets = availableOffsets.get(); + committedLog.add(finalBatchId, finalAvailableOffsets); + } + + @Override + public void onFailure(Exception e) { + log.error("streaming processing failed. source = {} {}", source, e); + } + }); + } } - } - - private boolean hasNewData(Optional availableOffsets, Optional committedOffset) { - if (availableOffsets.equals(committedOffset)) { - log.debug("source does not have new data, exit. source = {}", source); - return false; - } else { - Preconditions.checkArgument( - availableOffsets.isPresent(), "[BUG] available offsets must be no empty"); - - log.debug( - "source has new data. source = {}, availableOffsets:{}, committedOffset:{}", - source, - availableOffsets, - committedOffset); - return true; + + private boolean hasNewData(Optional availableOffsets, Optional committedOffset) { + if (availableOffsets.equals(committedOffset)) { + log.debug("source does not have new data, exit. source = {}", source); + return false; + } else { + Preconditions.checkArgument(availableOffsets.isPresent(), "[BUG] available offsets must be no empty"); + + log.debug( + "source has new data. source = {}, availableOffsets:{}, committedOffset:{}", + source, + availableOffsets, + committedOffset + ); + return true; + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/executor/streaming/Offset.java b/core/src/main/java/org/opensearch/sql/executor/streaming/Offset.java index 00f040e437..b3c81fc361 100644 --- a/core/src/main/java/org/opensearch/sql/executor/streaming/Offset.java +++ b/core/src/main/java/org/opensearch/sql/executor/streaming/Offset.java @@ -13,5 +13,5 @@ @Data public class Offset { - private final Long offset; + private final Long offset; } diff --git a/core/src/main/java/org/opensearch/sql/executor/streaming/StreamingSource.java b/core/src/main/java/org/opensearch/sql/executor/streaming/StreamingSource.java index ebd3fa714b..fe5675b951 100644 --- a/core/src/main/java/org/opensearch/sql/executor/streaming/StreamingSource.java +++ b/core/src/main/java/org/opensearch/sql/executor/streaming/StreamingSource.java @@ -11,19 +11,19 @@ * Streaming source. */ public interface StreamingSource { - /** - * Get current {@link Offset} of stream data. - * - * @return empty if the stream does not has new data. - */ - Optional getLatestOffset(); + /** + * Get current {@link Offset} of stream data. + * + * @return empty if the stream does not has new data. + */ + Optional getLatestOffset(); - /** - * Get a {@link Batch} from source between (start, end]. - * - * @param start start offset. - * @param end end offset. - * @return @link Batch}. - */ - Batch getBatch(Optional start, Offset end); + /** + * Get a {@link Batch} from source between (start, end]. + * + * @param start start offset. + * @param end end offset. + * @return @link Batch}. + */ + Batch getBatch(Optional start, Offset end); } diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 3f1897e483..e74229f9c2 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression; import java.util.Arrays; @@ -29,947 +28,906 @@ public class DSL { - private DSL() { - } - - public static LiteralExpression literal(Byte value) { - return new LiteralExpression(ExprValueUtils.byteValue(value)); - } - - public static LiteralExpression literal(Short value) { - return new LiteralExpression(new ExprShortValue(value)); - } - - public static LiteralExpression literal(Integer value) { - return new LiteralExpression(ExprValueUtils.integerValue(value)); - } - - public static LiteralExpression literal(Long value) { - return new LiteralExpression(ExprValueUtils.longValue(value)); - } - - public static LiteralExpression literal(Float value) { - return new LiteralExpression(ExprValueUtils.floatValue(value)); - } - - public static LiteralExpression literal(Double value) { - return new LiteralExpression(ExprValueUtils.doubleValue(value)); - } - - public static LiteralExpression literal(String value) { - return new LiteralExpression(ExprValueUtils.stringValue(value)); - } - - public static LiteralExpression literal(Boolean value) { - return new LiteralExpression(ExprValueUtils.booleanValue(value)); - } - - public static LiteralExpression literal(ExprValue value) { - return new LiteralExpression(value); - } - - /** - * Wrap a number to {@link LiteralExpression}. - */ - public static LiteralExpression literal(Number value) { - if (value instanceof Integer) { - return new LiteralExpression(ExprValueUtils.integerValue(value.intValue())); - } else if (value instanceof Long) { - return new LiteralExpression(ExprValueUtils.longValue(value.longValue())); - } else if (value instanceof Float) { - return new LiteralExpression(ExprValueUtils.floatValue(value.floatValue())); - } else { - return new LiteralExpression(ExprValueUtils.doubleValue(value.doubleValue())); - } - } - - public static ReferenceExpression ref(String ref, ExprType type) { - return new ReferenceExpression(ref, type); - } - - /** - * Wrap a named expression if not yet. The intent is that different languages may use - * Alias or not when building AST. This caused either named or unnamed expression - * is resolved by analyzer. To make unnamed expression acceptable for logical project, - * it is required to wrap it by named expression here before passing to logical project. - * - * @param expression expression - * @return expression if named already or expression wrapped by named expression. - */ - public static NamedExpression named(Expression expression) { - if (expression instanceof NamedExpression) { - return (NamedExpression) expression; - } - if (expression instanceof ParseExpression) { - return named(((ParseExpression) expression).getIdentifier().valueOf().stringValue(), - expression); - } - return named(expression.toString(), expression); - } - - public static NamedExpression named(String name, Expression expression) { - return new NamedExpression(name, expression); - } - - public static NamedExpression named(String name, Expression expression, String alias) { - return new NamedExpression(name, expression, alias); - } - - public static NamedAggregator named(String name, Aggregator aggregator) { - return new NamedAggregator(name, aggregator); - } - - public static NamedArgumentExpression namedArgument(String argName, Expression value) { - return new NamedArgumentExpression(argName, value); - } - - public static NamedArgumentExpression namedArgument(String name, String value) { - return namedArgument(name, literal(value)); - } - - public static GrokExpression grok(Expression sourceField, Expression pattern, - Expression identifier) { - return new GrokExpression(sourceField, pattern, identifier); - } - - public static RegexExpression regex(Expression sourceField, Expression pattern, - Expression identifier) { - return new RegexExpression(sourceField, pattern, identifier); - } - - public static PatternsExpression patterns(Expression sourceField, Expression pattern, - Expression identifier) { - return new PatternsExpression(sourceField, pattern, identifier); - } - - public static SpanExpression span(Expression field, Expression value, String unit) { - return new SpanExpression(field, value, SpanUnit.of(unit)); - } - - public static FunctionExpression abs(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.ABS, expressions); - } - - public static FunctionExpression add(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.ADD, expressions); - } - - public static FunctionExpression addFunction(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.ADDFUNCTION, expressions); - } - - public static FunctionExpression ceil(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.CEIL, expressions); - } - - public static FunctionExpression ceiling(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.CEILING, expressions); - } - - public static FunctionExpression conv(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.CONV, expressions); - } - - public static FunctionExpression crc32(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.CRC32, expressions); - } - - public static FunctionExpression divide(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.DIVIDE, expressions); - } - - public static FunctionExpression divideFunction(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.DIVIDEFUNCTION, expressions); - } - - public static FunctionExpression euler(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.E, expressions); - } - - public static FunctionExpression exp(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.EXP, expressions); - } - - public static FunctionExpression expm1(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.EXPM1, expressions); - } - - public static FunctionExpression floor(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.FLOOR, expressions); - } + private DSL() {} - public static FunctionExpression ln(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.LN, expressions); - } + public static LiteralExpression literal(Byte value) { + return new LiteralExpression(ExprValueUtils.byteValue(value)); + } - public static FunctionExpression log(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.LOG, expressions); - } + public static LiteralExpression literal(Short value) { + return new LiteralExpression(new ExprShortValue(value)); + } - public static FunctionExpression log10(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.LOG10, expressions); - } + public static LiteralExpression literal(Integer value) { + return new LiteralExpression(ExprValueUtils.integerValue(value)); + } - public static FunctionExpression log2(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.LOG2, expressions); - } + public static LiteralExpression literal(Long value) { + return new LiteralExpression(ExprValueUtils.longValue(value)); + } - public static FunctionExpression mod(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MOD, expressions); - } + public static LiteralExpression literal(Float value) { + return new LiteralExpression(ExprValueUtils.floatValue(value)); + } - public static FunctionExpression modulus(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MODULUS, expressions); - } + public static LiteralExpression literal(Double value) { + return new LiteralExpression(ExprValueUtils.doubleValue(value)); + } - public static FunctionExpression modulusFunction(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MODULUSFUNCTION, expressions); - } + public static LiteralExpression literal(String value) { + return new LiteralExpression(ExprValueUtils.stringValue(value)); + } - public static FunctionExpression multiply(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MULTIPLY, expressions); - } + public static LiteralExpression literal(Boolean value) { + return new LiteralExpression(ExprValueUtils.booleanValue(value)); + } - public static FunctionExpression multiplyFunction(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MULTIPLYFUNCTION, expressions); - } + public static LiteralExpression literal(ExprValue value) { + return new LiteralExpression(value); + } - public static FunctionExpression pi(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.PI, expressions); - } + /** + * Wrap a number to {@link LiteralExpression}. + */ + public static LiteralExpression literal(Number value) { + if (value instanceof Integer) { + return new LiteralExpression(ExprValueUtils.integerValue(value.intValue())); + } else if (value instanceof Long) { + return new LiteralExpression(ExprValueUtils.longValue(value.longValue())); + } else if (value instanceof Float) { + return new LiteralExpression(ExprValueUtils.floatValue(value.floatValue())); + } else { + return new LiteralExpression(ExprValueUtils.doubleValue(value.doubleValue())); + } + } - public static FunctionExpression pow(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.POW, expressions); - } + public static ReferenceExpression ref(String ref, ExprType type) { + return new ReferenceExpression(ref, type); + } - public static FunctionExpression power(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.POWER, expressions); - } + /** + * Wrap a named expression if not yet. The intent is that different languages may use + * Alias or not when building AST. This caused either named or unnamed expression + * is resolved by analyzer. To make unnamed expression acceptable for logical project, + * it is required to wrap it by named expression here before passing to logical project. + * + * @param expression expression + * @return expression if named already or expression wrapped by named expression. + */ + public static NamedExpression named(Expression expression) { + if (expression instanceof NamedExpression) { + return (NamedExpression) expression; + } + if (expression instanceof ParseExpression) { + return named(((ParseExpression) expression).getIdentifier().valueOf().stringValue(), expression); + } + return named(expression.toString(), expression); + } - public static FunctionExpression rand(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.RAND, expressions); - } + public static NamedExpression named(String name, Expression expression) { + return new NamedExpression(name, expression); + } - public static FunctionExpression rint(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.RINT, expressions); - } + public static NamedExpression named(String name, Expression expression, String alias) { + return new NamedExpression(name, expression, alias); + } - public static FunctionExpression round(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.ROUND, expressions); - } + public static NamedAggregator named(String name, Aggregator aggregator) { + return new NamedAggregator(name, aggregator); + } - public static FunctionExpression sign(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.SIGN, expressions); - } + public static NamedArgumentExpression namedArgument(String argName, Expression value) { + return new NamedArgumentExpression(argName, value); + } - public static FunctionExpression signum(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.SIGNUM, expressions); - } + public static NamedArgumentExpression namedArgument(String name, String value) { + return namedArgument(name, literal(value)); + } - public static FunctionExpression sinh(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.SINH, expressions); + public static GrokExpression grok(Expression sourceField, Expression pattern, Expression identifier) { + return new GrokExpression(sourceField, pattern, identifier); + } - } + public static RegexExpression regex(Expression sourceField, Expression pattern, Expression identifier) { + return new RegexExpression(sourceField, pattern, identifier); + } - public static FunctionExpression sqrt(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.SQRT, expressions); - } + public static PatternsExpression patterns(Expression sourceField, Expression pattern, Expression identifier) { + return new PatternsExpression(sourceField, pattern, identifier); + } - public static FunctionExpression cbrt(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.CBRT, expressions); - } + public static SpanExpression span(Expression field, Expression value, String unit) { + return new SpanExpression(field, value, SpanUnit.of(unit)); + } - public static FunctionExpression position(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.POSITION, expressions); - } - - public static FunctionExpression truncate(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.TRUNCATE, expressions); - } - - public static FunctionExpression acos(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.ACOS, expressions); - } - - public static FunctionExpression asin(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.ASIN, expressions); - } - - public static FunctionExpression atan(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.ATAN, expressions); - } - - public static FunctionExpression atan2(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.ATAN2, expressions); - } - - public static FunctionExpression cos(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.COS, expressions); - } - - public static FunctionExpression cosh(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.COSH, expressions); - } - - public static FunctionExpression cot(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.COT, expressions); - } - - public static FunctionExpression degrees(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.DEGREES, expressions); - } - - public static FunctionExpression radians(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.RADIANS, expressions); - } - - public static FunctionExpression sin(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.SIN, expressions); - } - - public static FunctionExpression subtract(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.SUBTRACT, expressions); - } - - public static FunctionExpression subtractFunction(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.SUBTRACTFUNCTION, expressions); - } - - public static FunctionExpression tan(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.TAN, expressions); - } - - public static FunctionExpression convert_tz(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.CONVERT_TZ, expressions); - } - - public static FunctionExpression date(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.DATE, expressions); - } - - public static FunctionExpression datetime(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.DATETIME, expressions); - } - - public static FunctionExpression date_add(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.DATE_ADD, expressions); - } - - public static FunctionExpression day(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.DAY, expressions); - } - - public static FunctionExpression dayname(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.DAYNAME, expressions); - } - - public static FunctionExpression dayofmonth( - FunctionProperties functionProperties, - Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.DAYOFMONTH, expressions); - } - - public static FunctionExpression dayofweek( - FunctionProperties functionProperties, Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.DAYOFWEEK, expressions); - } - - public static FunctionExpression dayofyear(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.DAYOFYEAR, expressions); - } - - public static FunctionExpression day_of_month( - FunctionProperties functionProperties, - Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.DAY_OF_MONTH, expressions); - } - - public static FunctionExpression day_of_year( - FunctionProperties functionProperties, Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.DAY_OF_YEAR, expressions); - } - - public static FunctionExpression day_of_week( - FunctionProperties functionProperties, Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.DAY_OF_WEEK, expressions); - } - - public static FunctionExpression extract(FunctionProperties functionProperties, - Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.EXTRACT, expressions); - } - - public static FunctionExpression extract(Expression... expressions) { - return extract(FunctionProperties.None, expressions); - } - - public static FunctionExpression from_days(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.FROM_DAYS, expressions); - } - - public static FunctionExpression get_format(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.GET_FORMAT, expressions); - } - - public static FunctionExpression hour(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.HOUR, expressions); - } - - public static FunctionExpression hour_of_day(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.HOUR_OF_DAY, expressions); - } - - public static FunctionExpression last_day(FunctionProperties functionProperties, - Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.LAST_DAY, expressions); - } - - public static FunctionExpression microsecond(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MICROSECOND, expressions); - } - - public static FunctionExpression minute(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MINUTE, expressions); - } - - public static FunctionExpression minute_of_day(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MINUTE_OF_DAY, expressions); - } - - public static FunctionExpression minute_of_hour(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MINUTE_OF_HOUR, expressions); - } - - public static FunctionExpression month(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MONTH, expressions); - } - - public static FunctionExpression month_of_year( - FunctionProperties functionProperties, Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.MONTH_OF_YEAR, expressions); - } - - public static FunctionExpression monthname(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.MONTHNAME, expressions); - } - - public static FunctionExpression quarter(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.QUARTER, expressions); - } - - public static FunctionExpression second(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.SECOND, expressions); - } - - public static FunctionExpression second_of_minute(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.SECOND_OF_MINUTE, expressions); - } - - public static FunctionExpression time(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.TIME, expressions); - } - - public static FunctionExpression time_to_sec(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.TIME_TO_SEC, expressions); - } - - public static FunctionExpression timestamp(Expression... expressions) { - return timestamp(FunctionProperties.None, expressions); - } - - public static FunctionExpression timestamp(FunctionProperties functionProperties, - Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.TIMESTAMP, expressions); - } - - public static FunctionExpression date_format( - FunctionProperties functionProperties, - Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.DATE_FORMAT, expressions); - } - - public static FunctionExpression to_days(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.TO_DAYS, expressions); - } - - public static FunctionExpression to_seconds(FunctionProperties functionProperties, - Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.TO_SECONDS, expressions); - } - - public static FunctionExpression to_seconds(Expression... expressions) { - return to_seconds(FunctionProperties.None, expressions); - } - - public static FunctionExpression week( - FunctionProperties functionProperties, Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.WEEK, expressions); - } - - public static FunctionExpression weekday(FunctionProperties functionProperties, - Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.WEEKDAY, expressions); - } - - public static FunctionExpression weekofyear( - FunctionProperties functionProperties, Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.WEEKOFYEAR, expressions); - } - - public static FunctionExpression week_of_year( - FunctionProperties functionProperties, Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.WEEK_OF_YEAR, expressions); - } - - public static FunctionExpression year(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.YEAR, expressions); - } - - public static FunctionExpression yearweek( - FunctionProperties functionProperties, Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.YEARWEEK, expressions); - } - - public static FunctionExpression str_to_date(FunctionProperties functionProperties, - Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.STR_TO_DATE, expressions); - } - - public static FunctionExpression sec_to_time(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.SEC_TO_TIME, expressions); - } - - public static FunctionExpression substr(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.SUBSTR, expressions); - } - - public static FunctionExpression substring(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.SUBSTR, expressions); - } - - public static FunctionExpression ltrim(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.LTRIM, expressions); - } - - public static FunctionExpression rtrim(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.RTRIM, expressions); - } - - public static FunctionExpression trim(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.TRIM, expressions); - } - - public static FunctionExpression upper(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.UPPER, expressions); - } - - public static FunctionExpression lower(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.LOWER, expressions); - } - - public static FunctionExpression regexp(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.REGEXP, expressions); - } - - public static FunctionExpression concat(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.CONCAT, expressions); - } - - public static FunctionExpression concat_ws(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.CONCAT_WS, expressions); - } - - public static FunctionExpression length(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.LENGTH, expressions); - } - - public static FunctionExpression strcmp(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.STRCMP, expressions); - } - - public static FunctionExpression right(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.RIGHT, expressions); - } - - public static FunctionExpression left(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.LEFT, expressions); - } - - public static FunctionExpression ascii(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.ASCII, expressions); - } - - public static FunctionExpression locate(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.LOCATE, expressions); - } - - public static FunctionExpression replace(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.REPLACE, expressions); - } - - public static FunctionExpression reverse(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.REVERSE, expressions); - } - - public static FunctionExpression and(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.AND, expressions); - } - - public static FunctionExpression or(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.OR, expressions); - } - - public static FunctionExpression xor(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.XOR, expressions); - } - - public static FunctionExpression nested(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.NESTED, expressions); - } - - public static FunctionExpression not(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.NOT, expressions); - } - - public static FunctionExpression equal(FunctionProperties fp, Expression... expressions) { - return compile(fp, BuiltinFunctionName.EQUAL, expressions); - } + public static FunctionExpression abs(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.ABS, expressions); + } - public static FunctionExpression equal(Expression... expressions) { - return equal(FunctionProperties.None, expressions); - } + public static FunctionExpression add(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.ADD, expressions); + } - public static FunctionExpression notequal(FunctionProperties fp, Expression... expressions) { - return compile(fp, BuiltinFunctionName.NOTEQUAL, expressions); - } + public static FunctionExpression addFunction(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.ADDFUNCTION, expressions); + } - public static FunctionExpression notequal(Expression... expressions) { - return notequal(FunctionProperties.None, expressions); - } + public static FunctionExpression ceil(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.CEIL, expressions); + } - public static FunctionExpression less(FunctionProperties fp, Expression... expressions) { - return compile(fp, BuiltinFunctionName.LESS, expressions); - } + public static FunctionExpression ceiling(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.CEILING, expressions); + } - public static FunctionExpression less(Expression... expressions) { - return less(FunctionProperties.None, expressions); - } + public static FunctionExpression conv(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.CONV, expressions); + } - public static FunctionExpression lte(FunctionProperties fp, Expression... expressions) { - return compile(fp, BuiltinFunctionName.LTE, expressions); - } + public static FunctionExpression crc32(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.CRC32, expressions); + } - public static FunctionExpression lte(Expression... expressions) { - return lte(FunctionProperties.None, expressions); - } + public static FunctionExpression divide(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.DIVIDE, expressions); + } - public static FunctionExpression greater(FunctionProperties fp, Expression... expressions) { - return compile(fp, BuiltinFunctionName.GREATER, expressions); - } + public static FunctionExpression divideFunction(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.DIVIDEFUNCTION, expressions); + } - public static FunctionExpression greater(Expression... expressions) { - return greater(FunctionProperties.None, expressions); - } + public static FunctionExpression euler(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.E, expressions); + } - public static FunctionExpression gte(FunctionProperties fp, Expression... expressions) { - return compile(fp, BuiltinFunctionName.GTE, expressions); - } + public static FunctionExpression exp(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.EXP, expressions); + } - public static FunctionExpression gte(Expression... expressions) { - return gte(FunctionProperties.None, expressions); - } + public static FunctionExpression expm1(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.EXPM1, expressions); + } - public static FunctionExpression like(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.LIKE, expressions); - } + public static FunctionExpression floor(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.FLOOR, expressions); + } - public static FunctionExpression notLike(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.NOT_LIKE, expressions); - } + public static FunctionExpression ln(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.LN, expressions); + } - public static Aggregator avg(Expression... expressions) { - return aggregate(BuiltinFunctionName.AVG, expressions); - } + public static FunctionExpression log(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.LOG, expressions); + } - public static Aggregator sum(Expression... expressions) { - return aggregate(BuiltinFunctionName.SUM, expressions); - } + public static FunctionExpression log10(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.LOG10, expressions); + } - public static Aggregator count(Expression... expressions) { - return aggregate(BuiltinFunctionName.COUNT, expressions); - } + public static FunctionExpression log2(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.LOG2, expressions); + } - public static Aggregator distinctCount(Expression... expressions) { - return count(expressions).distinct(true); - } + public static FunctionExpression mod(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MOD, expressions); + } - public static Aggregator varSamp(Expression... expressions) { - return aggregate(BuiltinFunctionName.VARSAMP, expressions); - } + public static FunctionExpression modulus(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MODULUS, expressions); + } - public static Aggregator varPop(Expression... expressions) { - return aggregate(BuiltinFunctionName.VARPOP, expressions); - } + public static FunctionExpression modulusFunction(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MODULUSFUNCTION, expressions); + } - public static Aggregator stddevSamp(Expression... expressions) { - return aggregate(BuiltinFunctionName.STDDEV_SAMP, expressions); - } + public static FunctionExpression multiply(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MULTIPLY, expressions); + } - public static Aggregator stddevPop(Expression... expressions) { - return aggregate(BuiltinFunctionName.STDDEV_POP, expressions); - } + public static FunctionExpression multiplyFunction(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MULTIPLYFUNCTION, expressions); + } - public static Aggregator take(Expression... expressions) { - return aggregate(BuiltinFunctionName.TAKE, expressions); - } + public static FunctionExpression pi(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.PI, expressions); + } - public static RankingWindowFunction rowNumber() { - return compile(FunctionProperties.None, BuiltinFunctionName.ROW_NUMBER); - } + public static FunctionExpression pow(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.POW, expressions); + } - public static RankingWindowFunction rank() { - return compile(FunctionProperties.None, BuiltinFunctionName.RANK); - } + public static FunctionExpression power(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.POWER, expressions); + } - public static RankingWindowFunction denseRank() { - return compile(FunctionProperties.None, BuiltinFunctionName.DENSE_RANK); - } + public static FunctionExpression rand(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.RAND, expressions); + } - public static Aggregator min(Expression... expressions) { - return aggregate(BuiltinFunctionName.MIN, expressions); - } + public static FunctionExpression rint(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.RINT, expressions); + } + + public static FunctionExpression round(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.ROUND, expressions); + } + + public static FunctionExpression sign(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SIGN, expressions); + } + + public static FunctionExpression signum(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SIGNUM, expressions); + } + + public static FunctionExpression sinh(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SINH, expressions); + + } + + public static FunctionExpression sqrt(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SQRT, expressions); + } - public static Aggregator max(Expression... expressions) { - return aggregate(BuiltinFunctionName.MAX, expressions); - } + public static FunctionExpression cbrt(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.CBRT, expressions); + } + + public static FunctionExpression position(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.POSITION, expressions); + } - private static Aggregator aggregate(BuiltinFunctionName functionName, Expression... expressions) { - return compile(FunctionProperties.None, functionName, expressions); - } + public static FunctionExpression truncate(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.TRUNCATE, expressions); + } - public static FunctionExpression isnull(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.ISNULL, expressions); - } + public static FunctionExpression acos(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.ACOS, expressions); + } - public static FunctionExpression is_null(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.IS_NULL, expressions); - } + public static FunctionExpression asin(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.ASIN, expressions); + } - public static FunctionExpression isnotnull(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.IS_NOT_NULL, expressions); - } + public static FunctionExpression atan(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.ATAN, expressions); + } - public static FunctionExpression ifnull(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.IFNULL, expressions); - } + public static FunctionExpression atan2(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.ATAN2, expressions); + } - public static FunctionExpression nullif(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.NULLIF, expressions); - } + public static FunctionExpression cos(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.COS, expressions); + } - public static FunctionExpression iffunction(Expression... expressions) { - return compile(FunctionProperties.None, BuiltinFunctionName.IF, expressions); - } + public static FunctionExpression cosh(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.COSH, expressions); + } - public static Expression cases(Expression defaultResult, - WhenClause... whenClauses) { - return new CaseClause(Arrays.asList(whenClauses), defaultResult); - } + public static FunctionExpression cot(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.COT, expressions); + } - public static WhenClause when(Expression condition, Expression result) { - return new WhenClause(condition, result); - } + public static FunctionExpression degrees(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.DEGREES, expressions); + } - public static FunctionExpression interval(Expression value, Expression unit) { - return compile(FunctionProperties.None, BuiltinFunctionName.INTERVAL, value, unit); - } + public static FunctionExpression radians(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.RADIANS, expressions); + } + + public static FunctionExpression sin(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SIN, expressions); + } + + public static FunctionExpression subtract(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SUBTRACT, expressions); + } + + public static FunctionExpression subtractFunction(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SUBTRACTFUNCTION, expressions); + } + + public static FunctionExpression tan(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.TAN, expressions); + } + + public static FunctionExpression convert_tz(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.CONVERT_TZ, expressions); + } + + public static FunctionExpression date(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.DATE, expressions); + } + + public static FunctionExpression datetime(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.DATETIME, expressions); + } + + public static FunctionExpression date_add(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.DATE_ADD, expressions); + } + + public static FunctionExpression day(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.DAY, expressions); + } + + public static FunctionExpression dayname(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.DAYNAME, expressions); + } + + public static FunctionExpression dayofmonth(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.DAYOFMONTH, expressions); + } + + public static FunctionExpression dayofweek(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.DAYOFWEEK, expressions); + } + + public static FunctionExpression dayofyear(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.DAYOFYEAR, expressions); + } + + public static FunctionExpression day_of_month(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.DAY_OF_MONTH, expressions); + } + + public static FunctionExpression day_of_year(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.DAY_OF_YEAR, expressions); + } + + public static FunctionExpression day_of_week(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.DAY_OF_WEEK, expressions); + } + + public static FunctionExpression extract(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.EXTRACT, expressions); + } + + public static FunctionExpression extract(Expression... expressions) { + return extract(FunctionProperties.None, expressions); + } + + public static FunctionExpression from_days(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.FROM_DAYS, expressions); + } + + public static FunctionExpression get_format(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.GET_FORMAT, expressions); + } + + public static FunctionExpression hour(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.HOUR, expressions); + } + + public static FunctionExpression hour_of_day(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.HOUR_OF_DAY, expressions); + } + + public static FunctionExpression last_day(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.LAST_DAY, expressions); + } + + public static FunctionExpression microsecond(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MICROSECOND, expressions); + } + + public static FunctionExpression minute(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MINUTE, expressions); + } + + public static FunctionExpression minute_of_day(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MINUTE_OF_DAY, expressions); + } + + public static FunctionExpression minute_of_hour(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MINUTE_OF_HOUR, expressions); + } + + public static FunctionExpression month(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MONTH, expressions); + } + + public static FunctionExpression month_of_year(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.MONTH_OF_YEAR, expressions); + } + + public static FunctionExpression monthname(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.MONTHNAME, expressions); + } + + public static FunctionExpression quarter(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.QUARTER, expressions); + } + + public static FunctionExpression second(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SECOND, expressions); + } + + public static FunctionExpression second_of_minute(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SECOND_OF_MINUTE, expressions); + } + + public static FunctionExpression time(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.TIME, expressions); + } + + public static FunctionExpression time_to_sec(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.TIME_TO_SEC, expressions); + } + + public static FunctionExpression timestamp(Expression... expressions) { + return timestamp(FunctionProperties.None, expressions); + } + + public static FunctionExpression timestamp(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.TIMESTAMP, expressions); + } + + public static FunctionExpression date_format(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.DATE_FORMAT, expressions); + } + + public static FunctionExpression to_days(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.TO_DAYS, expressions); + } + + public static FunctionExpression to_seconds(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.TO_SECONDS, expressions); + } + + public static FunctionExpression to_seconds(Expression... expressions) { + return to_seconds(FunctionProperties.None, expressions); + } + + public static FunctionExpression week(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.WEEK, expressions); + } + + public static FunctionExpression weekday(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.WEEKDAY, expressions); + } + + public static FunctionExpression weekofyear(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.WEEKOFYEAR, expressions); + } + + public static FunctionExpression week_of_year(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.WEEK_OF_YEAR, expressions); + } + + public static FunctionExpression year(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.YEAR, expressions); + } + + public static FunctionExpression yearweek(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.YEARWEEK, expressions); + } + + public static FunctionExpression str_to_date(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.STR_TO_DATE, expressions); + } + + public static FunctionExpression sec_to_time(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SEC_TO_TIME, expressions); + } + + public static FunctionExpression substr(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SUBSTR, expressions); + } + + public static FunctionExpression substring(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.SUBSTR, expressions); + } + + public static FunctionExpression ltrim(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.LTRIM, expressions); + } + + public static FunctionExpression rtrim(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.RTRIM, expressions); + } + + public static FunctionExpression trim(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.TRIM, expressions); + } + + public static FunctionExpression upper(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.UPPER, expressions); + } + + public static FunctionExpression lower(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.LOWER, expressions); + } + + public static FunctionExpression regexp(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.REGEXP, expressions); + } + + public static FunctionExpression concat(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.CONCAT, expressions); + } - public static FunctionExpression castString(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_STRING, value); - } - - public static FunctionExpression castByte(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_BYTE, value); - } - - public static FunctionExpression castShort(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_SHORT, value); - } - - public static FunctionExpression castInt(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_INT, value); - } - - public static FunctionExpression castLong(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_LONG, value); - } - - public static FunctionExpression castFloat(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_FLOAT, value); - } - - public static FunctionExpression castDouble(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_DOUBLE, value); - } - - public static FunctionExpression castBoolean(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_BOOLEAN, value); - } - - public static FunctionExpression castDate(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_DATE, value); - } - - public static FunctionExpression castTime(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_TIME, value); - } - - public static FunctionExpression castTimestamp(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_TIMESTAMP, value); - } - - public static FunctionExpression castDatetime(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_DATETIME, value); - } - - public static FunctionExpression typeof(Expression value) { - return compile(FunctionProperties.None, BuiltinFunctionName.TYPEOF, value); - } - - public static FunctionExpression match(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.MATCH, args); - } - - public static FunctionExpression match_phrase(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_PHRASE, args); - } - - public static FunctionExpression match_phrase_prefix(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_PHRASE_PREFIX, args); - } - - public static FunctionExpression multi_match(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.MULTI_MATCH, args); - } - - public static FunctionExpression simple_query_string(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.SIMPLE_QUERY_STRING, args); - } - - public static FunctionExpression query(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.QUERY, args); - } - - public static FunctionExpression query_string(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.QUERY_STRING, args); - } - - public static FunctionExpression match_bool_prefix(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_BOOL_PREFIX, args); - } - - public static FunctionExpression wildcard_query(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.WILDCARD_QUERY, args); - } - - public static FunctionExpression score(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.SCORE, args); - } - - public static FunctionExpression scorequery(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.SCOREQUERY, args); - } - - public static FunctionExpression score_query(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.SCORE_QUERY, args); - } - - public static FunctionExpression now(FunctionProperties functionProperties, - Expression... args) { - return compile(functionProperties, BuiltinFunctionName.NOW, args); - } - - public static FunctionExpression current_timestamp(FunctionProperties functionProperties, - Expression... args) { - return compile(functionProperties, BuiltinFunctionName.CURRENT_TIMESTAMP, args); - } - - public static FunctionExpression localtimestamp(FunctionProperties functionProperties, - Expression... args) { - return compile(functionProperties, BuiltinFunctionName.LOCALTIMESTAMP, args); - } - - public static FunctionExpression localtime(FunctionProperties functionProperties, - Expression... args) { - return compile(functionProperties, BuiltinFunctionName.LOCALTIME, args); - } - - public static FunctionExpression sysdate(FunctionProperties functionProperties, - Expression... args) { - return compile(functionProperties, BuiltinFunctionName.SYSDATE, args); - } - - public static FunctionExpression curtime(FunctionProperties functionProperties, - Expression... args) { - return compile(functionProperties, BuiltinFunctionName.CURTIME, args); - } - - public static FunctionExpression current_time(FunctionProperties functionProperties, - Expression... args) { - return compile(functionProperties, BuiltinFunctionName.CURRENT_TIME, args); - } - - public static FunctionExpression curdate(FunctionProperties functionProperties, - Expression... args) { - return compile(functionProperties, BuiltinFunctionName.CURDATE, args); - } - - public static FunctionExpression current_date(FunctionProperties functionProperties, - Expression... args) { - return compile(functionProperties, BuiltinFunctionName.CURRENT_DATE, args); - } - - public static FunctionExpression time_format(FunctionProperties functionProperties, - Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.TIME_FORMAT, expressions); - } - - public static FunctionExpression timestampadd(Expression... expressions) { - return timestampadd(FunctionProperties.None, expressions); - } - - public static FunctionExpression timestampadd(FunctionProperties functionProperties, - Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.TIMESTAMPADD, expressions); - } - - public static FunctionExpression timestampdiff(FunctionProperties functionProperties, - Expression... expressions) { - return compile(functionProperties, BuiltinFunctionName.TIMESTAMPDIFF, expressions); - } - - - public static FunctionExpression utc_date(FunctionProperties functionProperties, - Expression... args) { - return compile(functionProperties, BuiltinFunctionName.UTC_DATE, args); - } - - public static FunctionExpression utc_time(FunctionProperties functionProperties, - Expression... args) { - return compile(functionProperties, BuiltinFunctionName.UTC_TIME, args); - } - - public static FunctionExpression utc_timestamp(FunctionProperties functionProperties, - Expression... args) { - return compile(functionProperties, BuiltinFunctionName.UTC_TIMESTAMP, args); - - } - - @SuppressWarnings("unchecked") - private static - T compile(FunctionProperties functionProperties, - BuiltinFunctionName bfn, Expression... args) { - return (T) BuiltinFunctionRepository.getInstance().compile(functionProperties, - bfn.getName(), Arrays.asList(args)); - } + public static FunctionExpression concat_ws(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.CONCAT_WS, expressions); + } + + public static FunctionExpression length(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.LENGTH, expressions); + } + + public static FunctionExpression strcmp(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.STRCMP, expressions); + } + + public static FunctionExpression right(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.RIGHT, expressions); + } + + public static FunctionExpression left(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.LEFT, expressions); + } + + public static FunctionExpression ascii(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.ASCII, expressions); + } + + public static FunctionExpression locate(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.LOCATE, expressions); + } + + public static FunctionExpression replace(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.REPLACE, expressions); + } + + public static FunctionExpression reverse(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.REVERSE, expressions); + } + + public static FunctionExpression and(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.AND, expressions); + } + + public static FunctionExpression or(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.OR, expressions); + } + + public static FunctionExpression xor(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.XOR, expressions); + } + + public static FunctionExpression nested(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.NESTED, expressions); + } + + public static FunctionExpression not(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.NOT, expressions); + } + + public static FunctionExpression equal(FunctionProperties fp, Expression... expressions) { + return compile(fp, BuiltinFunctionName.EQUAL, expressions); + } + + public static FunctionExpression equal(Expression... expressions) { + return equal(FunctionProperties.None, expressions); + } + + public static FunctionExpression notequal(FunctionProperties fp, Expression... expressions) { + return compile(fp, BuiltinFunctionName.NOTEQUAL, expressions); + } + + public static FunctionExpression notequal(Expression... expressions) { + return notequal(FunctionProperties.None, expressions); + } + + public static FunctionExpression less(FunctionProperties fp, Expression... expressions) { + return compile(fp, BuiltinFunctionName.LESS, expressions); + } + + public static FunctionExpression less(Expression... expressions) { + return less(FunctionProperties.None, expressions); + } + + public static FunctionExpression lte(FunctionProperties fp, Expression... expressions) { + return compile(fp, BuiltinFunctionName.LTE, expressions); + } + + public static FunctionExpression lte(Expression... expressions) { + return lte(FunctionProperties.None, expressions); + } + + public static FunctionExpression greater(FunctionProperties fp, Expression... expressions) { + return compile(fp, BuiltinFunctionName.GREATER, expressions); + } + + public static FunctionExpression greater(Expression... expressions) { + return greater(FunctionProperties.None, expressions); + } + + public static FunctionExpression gte(FunctionProperties fp, Expression... expressions) { + return compile(fp, BuiltinFunctionName.GTE, expressions); + } + + public static FunctionExpression gte(Expression... expressions) { + return gte(FunctionProperties.None, expressions); + } + + public static FunctionExpression like(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.LIKE, expressions); + } + + public static FunctionExpression notLike(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.NOT_LIKE, expressions); + } + + public static Aggregator avg(Expression... expressions) { + return aggregate(BuiltinFunctionName.AVG, expressions); + } + + public static Aggregator sum(Expression... expressions) { + return aggregate(BuiltinFunctionName.SUM, expressions); + } + + public static Aggregator count(Expression... expressions) { + return aggregate(BuiltinFunctionName.COUNT, expressions); + } + + public static Aggregator distinctCount(Expression... expressions) { + return count(expressions).distinct(true); + } + + public static Aggregator varSamp(Expression... expressions) { + return aggregate(BuiltinFunctionName.VARSAMP, expressions); + } + + public static Aggregator varPop(Expression... expressions) { + return aggregate(BuiltinFunctionName.VARPOP, expressions); + } + + public static Aggregator stddevSamp(Expression... expressions) { + return aggregate(BuiltinFunctionName.STDDEV_SAMP, expressions); + } + + public static Aggregator stddevPop(Expression... expressions) { + return aggregate(BuiltinFunctionName.STDDEV_POP, expressions); + } + + public static Aggregator take(Expression... expressions) { + return aggregate(BuiltinFunctionName.TAKE, expressions); + } + + public static RankingWindowFunction rowNumber() { + return compile(FunctionProperties.None, BuiltinFunctionName.ROW_NUMBER); + } + + public static RankingWindowFunction rank() { + return compile(FunctionProperties.None, BuiltinFunctionName.RANK); + } + + public static RankingWindowFunction denseRank() { + return compile(FunctionProperties.None, BuiltinFunctionName.DENSE_RANK); + } + + public static Aggregator min(Expression... expressions) { + return aggregate(BuiltinFunctionName.MIN, expressions); + } + + public static Aggregator max(Expression... expressions) { + return aggregate(BuiltinFunctionName.MAX, expressions); + } + + private static Aggregator aggregate(BuiltinFunctionName functionName, Expression... expressions) { + return compile(FunctionProperties.None, functionName, expressions); + } + + public static FunctionExpression isnull(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.ISNULL, expressions); + } + + public static FunctionExpression is_null(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.IS_NULL, expressions); + } + + public static FunctionExpression isnotnull(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.IS_NOT_NULL, expressions); + } + + public static FunctionExpression ifnull(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.IFNULL, expressions); + } + + public static FunctionExpression nullif(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.NULLIF, expressions); + } + + public static FunctionExpression iffunction(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.IF, expressions); + } + + public static Expression cases(Expression defaultResult, WhenClause... whenClauses) { + return new CaseClause(Arrays.asList(whenClauses), defaultResult); + } + + public static WhenClause when(Expression condition, Expression result) { + return new WhenClause(condition, result); + } + + public static FunctionExpression interval(Expression value, Expression unit) { + return compile(FunctionProperties.None, BuiltinFunctionName.INTERVAL, value, unit); + } + + public static FunctionExpression castString(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_STRING, value); + } + + public static FunctionExpression castByte(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_BYTE, value); + } + + public static FunctionExpression castShort(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_SHORT, value); + } + + public static FunctionExpression castInt(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_INT, value); + } + + public static FunctionExpression castLong(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_LONG, value); + } + + public static FunctionExpression castFloat(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_FLOAT, value); + } + + public static FunctionExpression castDouble(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_DOUBLE, value); + } + + public static FunctionExpression castBoolean(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_BOOLEAN, value); + } + + public static FunctionExpression castDate(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_DATE, value); + } + + public static FunctionExpression castTime(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_TIME, value); + } + + public static FunctionExpression castTimestamp(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_TIMESTAMP, value); + } + + public static FunctionExpression castDatetime(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_DATETIME, value); + } + + public static FunctionExpression typeof(Expression value) { + return compile(FunctionProperties.None, BuiltinFunctionName.TYPEOF, value); + } + + public static FunctionExpression match(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.MATCH, args); + } + + public static FunctionExpression match_phrase(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_PHRASE, args); + } + + public static FunctionExpression match_phrase_prefix(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_PHRASE_PREFIX, args); + } + + public static FunctionExpression multi_match(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.MULTI_MATCH, args); + } + + public static FunctionExpression simple_query_string(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.SIMPLE_QUERY_STRING, args); + } + + public static FunctionExpression query(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.QUERY, args); + } + + public static FunctionExpression query_string(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.QUERY_STRING, args); + } + + public static FunctionExpression match_bool_prefix(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_BOOL_PREFIX, args); + } + + public static FunctionExpression wildcard_query(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.WILDCARD_QUERY, args); + } + + public static FunctionExpression score(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.SCORE, args); + } + + public static FunctionExpression scorequery(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.SCOREQUERY, args); + } + + public static FunctionExpression score_query(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.SCORE_QUERY, args); + } + + public static FunctionExpression now(FunctionProperties functionProperties, Expression... args) { + return compile(functionProperties, BuiltinFunctionName.NOW, args); + } + + public static FunctionExpression current_timestamp(FunctionProperties functionProperties, Expression... args) { + return compile(functionProperties, BuiltinFunctionName.CURRENT_TIMESTAMP, args); + } + + public static FunctionExpression localtimestamp(FunctionProperties functionProperties, Expression... args) { + return compile(functionProperties, BuiltinFunctionName.LOCALTIMESTAMP, args); + } + + public static FunctionExpression localtime(FunctionProperties functionProperties, Expression... args) { + return compile(functionProperties, BuiltinFunctionName.LOCALTIME, args); + } + + public static FunctionExpression sysdate(FunctionProperties functionProperties, Expression... args) { + return compile(functionProperties, BuiltinFunctionName.SYSDATE, args); + } + + public static FunctionExpression curtime(FunctionProperties functionProperties, Expression... args) { + return compile(functionProperties, BuiltinFunctionName.CURTIME, args); + } + + public static FunctionExpression current_time(FunctionProperties functionProperties, Expression... args) { + return compile(functionProperties, BuiltinFunctionName.CURRENT_TIME, args); + } + + public static FunctionExpression curdate(FunctionProperties functionProperties, Expression... args) { + return compile(functionProperties, BuiltinFunctionName.CURDATE, args); + } + + public static FunctionExpression current_date(FunctionProperties functionProperties, Expression... args) { + return compile(functionProperties, BuiltinFunctionName.CURRENT_DATE, args); + } + + public static FunctionExpression time_format(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.TIME_FORMAT, expressions); + } + + public static FunctionExpression timestampadd(Expression... expressions) { + return timestampadd(FunctionProperties.None, expressions); + } + + public static FunctionExpression timestampadd(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.TIMESTAMPADD, expressions); + } + + public static FunctionExpression timestampdiff(FunctionProperties functionProperties, Expression... expressions) { + return compile(functionProperties, BuiltinFunctionName.TIMESTAMPDIFF, expressions); + } + + public static FunctionExpression utc_date(FunctionProperties functionProperties, Expression... args) { + return compile(functionProperties, BuiltinFunctionName.UTC_DATE, args); + } + + public static FunctionExpression utc_time(FunctionProperties functionProperties, Expression... args) { + return compile(functionProperties, BuiltinFunctionName.UTC_TIME, args); + } + + public static FunctionExpression utc_timestamp(FunctionProperties functionProperties, Expression... args) { + return compile(functionProperties, BuiltinFunctionName.UTC_TIMESTAMP, args); + + } + + @SuppressWarnings("unchecked") + private static T compile( + FunctionProperties functionProperties, + BuiltinFunctionName bfn, + Expression... args + ) { + return (T) BuiltinFunctionRepository.getInstance().compile(functionProperties, bfn.getName(), Arrays.asList(args)); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/Expression.java b/core/src/main/java/org/opensearch/sql/expression/Expression.java index 25a8173efa..8778af6a69 100644 --- a/core/src/main/java/org/opensearch/sql/expression/Expression.java +++ b/core/src/main/java/org/opensearch/sql/expression/Expression.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression; import java.io.Serializable; @@ -16,31 +15,31 @@ */ public interface Expression extends Serializable { - /** - * Evaluate the value of expression that does not depend on value environment. - */ - default ExprValue valueOf() { - return valueOf(null); - } - - /** - * Evaluate the value of expression in the value environment. - */ - ExprValue valueOf(Environment valueEnv); - - /** - * The type of the expression. - */ - ExprType type(); - - /** - * Accept a visitor to visit current expression node. - * @param visitor visitor - * @param context context - * @param result type - * @param context type - * @return result accumulated by visitor when visiting - */ - T accept(ExpressionNodeVisitor visitor, C context); + /** + * Evaluate the value of expression that does not depend on value environment. + */ + default ExprValue valueOf() { + return valueOf(null); + } + + /** + * Evaluate the value of expression in the value environment. + */ + ExprValue valueOf(Environment valueEnv); + + /** + * The type of the expression. + */ + ExprType type(); + + /** + * Accept a visitor to visit current expression node. + * @param visitor visitor + * @param context context + * @param result type + * @param context type + * @return result accumulated by visitor when visiting + */ + T accept(ExpressionNodeVisitor visitor, C context); } diff --git a/core/src/main/java/org/opensearch/sql/expression/ExpressionNodeVisitor.java b/core/src/main/java/org/opensearch/sql/expression/ExpressionNodeVisitor.java index e3d4e38674..a6aefeeca5 100644 --- a/core/src/main/java/org/opensearch/sql/expression/ExpressionNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/expression/ExpressionNodeVisitor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression; import org.opensearch.sql.expression.aggregation.Aggregator; @@ -20,81 +19,81 @@ */ public abstract class ExpressionNodeVisitor { - public T visitNode(Expression node, C context) { - return null; - } - - /** - * Visit children nodes in function arguments. - * @param node function node - * @param context context - * @return result - */ - public T visitChildren(FunctionImplementation node, C context) { - T result = defaultResult(); - - for (Expression child : node.getArguments()) { - T childResult = child.accept(this, context); - result = aggregateResult(result, childResult); + public T visitNode(Expression node, C context) { + return null; + } + + /** + * Visit children nodes in function arguments. + * @param node function node + * @param context context + * @return result + */ + public T visitChildren(FunctionImplementation node, C context) { + T result = defaultResult(); + + for (Expression child : node.getArguments()) { + T childResult = child.accept(this, context); + result = aggregateResult(result, childResult); + } + return result; + } + + private T defaultResult() { + return null; + } + + private T aggregateResult(T aggregate, T nextResult) { + return nextResult; + } + + public T visitLiteral(LiteralExpression node, C context) { + return visitNode(node, context); + } + + public T visitNamed(NamedExpression node, C context) { + return node.getDelegated().accept(this, context); + } + + public T visitHighlight(HighlightExpression node, C context) { + return visitNode(node, context); + } + + public T visitReference(ReferenceExpression node, C context) { + return visitNode(node, context); + } + + public T visitParse(ParseExpression node, C context) { + return visitNode(node, context); + } + + public T visitFunction(FunctionExpression node, C context) { + return visitChildren(node, context); + } + + public T visitAggregator(Aggregator node, C context) { + return visitChildren(node, context); + } + + public T visitNamedAggregator(NamedAggregator node, C context) { + return visitChildren(node, context); + } + + /** + * Call visitFunction() by default rather than visitChildren(). + * This makes CASE/WHEN able to be handled: + * 1) by visitFunction() if not overwritten: ex. FilterQueryBuilder + * 2) by visitCase/When() otherwise if any special logic: ex. ExprReferenceOptimizer + */ + public T visitCase(CaseClause node, C context) { + return visitFunction(node, context); + } + + public T visitWhen(WhenClause node, C context) { + return visitFunction(node, context); + } + + public T visitNamedArgument(NamedArgumentExpression node, C context) { + return visitNode(node, context); } - return result; - } - - private T defaultResult() { - return null; - } - - private T aggregateResult(T aggregate, T nextResult) { - return nextResult; - } - - public T visitLiteral(LiteralExpression node, C context) { - return visitNode(node, context); - } - - public T visitNamed(NamedExpression node, C context) { - return node.getDelegated().accept(this, context); - } - - public T visitHighlight(HighlightExpression node, C context) { - return visitNode(node, context); - } - - public T visitReference(ReferenceExpression node, C context) { - return visitNode(node, context); - } - - public T visitParse(ParseExpression node, C context) { - return visitNode(node, context); - } - - public T visitFunction(FunctionExpression node, C context) { - return visitChildren(node, context); - } - - public T visitAggregator(Aggregator node, C context) { - return visitChildren(node, context); - } - - public T visitNamedAggregator(NamedAggregator node, C context) { - return visitChildren(node, context); - } - - /** - * Call visitFunction() by default rather than visitChildren(). - * This makes CASE/WHEN able to be handled: - * 1) by visitFunction() if not overwritten: ex. FilterQueryBuilder - * 2) by visitCase/When() otherwise if any special logic: ex. ExprReferenceOptimizer - */ - public T visitCase(CaseClause node, C context) { - return visitFunction(node, context); - } - - public T visitWhen(WhenClause node, C context) { - return visitFunction(node, context); - } - - public T visitNamedArgument(NamedArgumentExpression node, C context) { - return visitNode(node, context); - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/FunctionExpression.java b/core/src/main/java/org/opensearch/sql/expression/FunctionExpression.java index 2a695f26e6..2985fd4d57 100644 --- a/core/src/main/java/org/opensearch/sql/expression/FunctionExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/FunctionExpression.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression; import java.util.List; @@ -21,15 +20,15 @@ @RequiredArgsConstructor @ToString public abstract class FunctionExpression implements Expression, FunctionImplementation { - @Getter - private final FunctionName functionName; + @Getter + private final FunctionName functionName; - @Getter - private final List arguments; + @Getter + private final List arguments; - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitFunction(this, context); - } + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitFunction(this, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java b/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java index 804c38a6f7..deccacddb6 100644 --- a/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java @@ -25,78 +25,77 @@ */ @Getter public class HighlightExpression extends FunctionExpression { - private final Expression highlightField; - private final ExprType type; + private final Expression highlightField; + private final ExprType type; - /** - * HighlightExpression Constructor. - * @param highlightField : Highlight field for expression. - */ - public HighlightExpression(Expression highlightField) { - super(BuiltinFunctionName.HIGHLIGHT.getName(), List.of(highlightField)); - this.highlightField = highlightField; - this.type = this.highlightField.toString().contains("*") - ? ExprCoreType.STRUCT : ExprCoreType.ARRAY; - } - - /** - * Return collection value matching highlight field. - * @param valueEnv : Dataset to parse value from. - * @return : collection value of highlight fields. - */ - @Override - public ExprValue valueOf(Environment valueEnv) { - String refName = "_highlight"; - // Not a wilcard expression - if (this.type == ExprCoreType.ARRAY) { - refName += "." + StringUtils.unquoteText(getHighlightField().toString()); + /** + * HighlightExpression Constructor. + * @param highlightField : Highlight field for expression. + */ + public HighlightExpression(Expression highlightField) { + super(BuiltinFunctionName.HIGHLIGHT.getName(), List.of(highlightField)); + this.highlightField = highlightField; + this.type = this.highlightField.toString().contains("*") ? ExprCoreType.STRUCT : ExprCoreType.ARRAY; } - ExprValue value = valueEnv.resolve(DSL.ref(refName, ExprCoreType.STRING)); - // In the event of multiple returned highlights and wildcard being - // used in conjunction with other highlight calls, we need to ensure - // only wildcard regex matching is mapped to wildcard call. - if (this.type == ExprCoreType.STRUCT && value.type() == ExprCoreType.STRUCT) { - value = new ExprTupleValue( - new LinkedHashMap(value.tupleValue() - .entrySet() - .stream() - .filter(s -> matchesHighlightRegex(s.getKey(), - StringUtils.unquoteText(highlightField.toString()))) - .collect(Collectors.toMap( - e -> e.getKey(), - e -> e.getValue())))); - if (value.tupleValue().isEmpty()) { - value = ExprValueUtils.missingValue(); - } - } + /** + * Return collection value matching highlight field. + * @param valueEnv : Dataset to parse value from. + * @return : collection value of highlight fields. + */ + @Override + public ExprValue valueOf(Environment valueEnv) { + String refName = "_highlight"; + // Not a wilcard expression + if (this.type == ExprCoreType.ARRAY) { + refName += "." + StringUtils.unquoteText(getHighlightField().toString()); + } + ExprValue value = valueEnv.resolve(DSL.ref(refName, ExprCoreType.STRING)); - return value; - } + // In the event of multiple returned highlights and wildcard being + // used in conjunction with other highlight calls, we need to ensure + // only wildcard regex matching is mapped to wildcard call. + if (this.type == ExprCoreType.STRUCT && value.type() == ExprCoreType.STRUCT) { + value = new ExprTupleValue( + new LinkedHashMap( + value.tupleValue() + .entrySet() + .stream() + .filter(s -> matchesHighlightRegex(s.getKey(), StringUtils.unquoteText(highlightField.toString()))) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())) + ) + ); + if (value.tupleValue().isEmpty()) { + value = ExprValueUtils.missingValue(); + } + } - /** - * Get type for HighlightExpression. - * @return : Expression type. - */ - @Override - public ExprType type() { - return this.type; - } + return value; + } - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitHighlight(this, context); - } + /** + * Get type for HighlightExpression. + * @return : Expression type. + */ + @Override + public ExprType type() { + return this.type; + } - /** - * Check if field matches the wildcard pattern used in highlight query. - * @param field Highlight selected field for query - * @param pattern Wildcard regex to match field against - * @return True if field matches wildcard pattern - */ - private boolean matchesHighlightRegex(String field, String pattern) { - Pattern p = Pattern.compile(pattern.replace("*", ".*")); - Matcher matcher = p.matcher(field); - return matcher.matches(); - } + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitHighlight(this, context); + } + + /** + * Check if field matches the wildcard pattern used in highlight query. + * @param field Highlight selected field for query + * @param pattern Wildcard regex to match field against + * @return True if field matches wildcard pattern + */ + private boolean matchesHighlightRegex(String field, String pattern) { + Pattern p = Pattern.compile(pattern.replace("*", ".*")); + Matcher matcher = p.matcher(field); + return matcher.matches(); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/LiteralExpression.java b/core/src/main/java/org/opensearch/sql/expression/LiteralExpression.java index adb8e197d1..41e90daa0f 100644 --- a/core/src/main/java/org/opensearch/sql/expression/LiteralExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/LiteralExpression.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression; import lombok.EqualsAndHashCode; @@ -18,25 +17,25 @@ @EqualsAndHashCode @RequiredArgsConstructor public class LiteralExpression implements Expression { - private final ExprValue exprValue; - - @Override - public ExprValue valueOf(Environment env) { - return exprValue; - } - - @Override - public ExprType type() { - return exprValue.type(); - } - - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitLiteral(this, context); - } - - @Override - public String toString() { - return exprValue.toString(); - } + private final ExprValue exprValue; + + @Override + public ExprValue valueOf(Environment env) { + return exprValue; + } + + @Override + public ExprType type() { + return exprValue.type(); + } + + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitLiteral(this, context); + } + + @Override + public String toString() { + return exprValue.toString(); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/NamedArgumentExpression.java b/core/src/main/java/org/opensearch/sql/expression/NamedArgumentExpression.java index 0f4601f1bf..f22e372aca 100644 --- a/core/src/main/java/org/opensearch/sql/expression/NamedArgumentExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/NamedArgumentExpression.java @@ -21,21 +21,21 @@ @EqualsAndHashCode @ToString public class NamedArgumentExpression implements Expression { - private final String argName; - private final Expression value; + private final String argName; + private final Expression value; - @Override - public ExprValue valueOf(Environment valueEnv) { - return value.valueOf(valueEnv); - } + @Override + public ExprValue valueOf(Environment valueEnv) { + return value.valueOf(valueEnv); + } - @Override - public ExprType type() { - return value.type(); - } + @Override + public ExprType type() { + return value.type(); + } - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitNamedArgument(this, context); - } + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitNamedArgument(this, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/NamedExpression.java b/core/src/main/java/org/opensearch/sql/expression/NamedExpression.java index 26996eb93d..cd5a75604d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/NamedExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/NamedExpression.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression; import com.google.common.base.Strings; @@ -26,47 +25,47 @@ @RequiredArgsConstructor public class NamedExpression implements Expression { - /** - * Expression name. - */ - private final String name; + /** + * Expression name. + */ + private final String name; - /** - * Expression that being named. - */ - private final Expression delegated; + /** + * Expression that being named. + */ + private final Expression delegated; - /** - * Optional alias. - */ - private String alias; + /** + * Optional alias. + */ + private String alias; - @Override - public ExprValue valueOf(Environment valueEnv) { - return delegated.valueOf(valueEnv); - } + @Override + public ExprValue valueOf(Environment valueEnv) { + return delegated.valueOf(valueEnv); + } - @Override - public ExprType type() { - return delegated.type(); - } + @Override + public ExprType type() { + return delegated.type(); + } - /** - * Get expression name using name or its alias (if it's present). - * @return expression name - */ - public String getNameOrAlias() { - return Strings.isNullOrEmpty(alias) ? name : alias; - } + /** + * Get expression name using name or its alias (if it's present). + * @return expression name + */ + public String getNameOrAlias() { + return Strings.isNullOrEmpty(alias) ? name : alias; + } - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitNamed(this, context); - } + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitNamed(this, context); + } - @Override - public String toString() { - return getNameOrAlias(); - } + @Override + public String toString() { + return getNameOrAlias(); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/ReferenceExpression.java b/core/src/main/java/org/opensearch/sql/expression/ReferenceExpression.java index 3c5b2af23c..e7a75d2032 100644 --- a/core/src/main/java/org/opensearch/sql/expression/ReferenceExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/ReferenceExpression.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression; import static org.opensearch.sql.utils.ExpressionUtils.PATH_SEP; @@ -22,95 +21,95 @@ @EqualsAndHashCode @RequiredArgsConstructor public class ReferenceExpression implements Expression { - @Getter - private final String attr; - - @Getter - private final List paths; + @Getter + private final String attr; - private final ExprType type; + @Getter + private final List paths; - /** - * Constructor of ReferenceExpression. - * @param ref the field name. e.g. addr.state/addr. - * @param type type. - */ - public ReferenceExpression(String ref, ExprType type) { - this.attr = ref; - // Todo. the define of paths need to be redefined after adding multiple index/variable support. - this.paths = Arrays.asList(ref.split("\\.")); - this.type = type; - } + private final ExprType type; - @Override - public ExprValue valueOf(Environment env) { - return env.resolve(this); - } + /** + * Constructor of ReferenceExpression. + * @param ref the field name. e.g. addr.state/addr. + * @param type type. + */ + public ReferenceExpression(String ref, ExprType type) { + this.attr = ref; + // Todo. the define of paths need to be redefined after adding multiple index/variable support. + this.paths = Arrays.asList(ref.split("\\.")); + this.type = type; + } - @Override - public ExprType type() { - return type; - } + @Override + public ExprValue valueOf(Environment env) { + return env.resolve(this); + } - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitReference(this, context); - } + @Override + public ExprType type() { + return type; + } - @Override - public String toString() { - return attr; - } + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitReference(this, context); + } - /** - * Resolve the ExprValue from {@link ExprTupleValue} using paths. - * Considering the following sample data. - * { - * "name": "bob smith" - * "project.year": 1990, - * "project": { - * "year": "2020" - * } - * "address": { - * "state": "WA", - * "city": "seattle", - * "project.year": 1990 - * } - * "address.local": { - * "state": "WA", - * } - * } - * The paths could be - * 1. top level, e.g. "name", which will be resolved as "bob smith" - * 2. multiple paths, e.g. "name.address.state", which will be resolved as "WA" - * 3. special case, the "." is the path separator, but it is possible that the path include - * ".", for handling this use case, we define the resolve rule as bellow, e.g. "project.year" is - * resolved as 1990 instead of 2020. Note. This logic only applied top level none object field. - * e.g. "address.local.state" been resolved to Missing. but "address.project.year" could been - * resolved as 1990. - * - *

Resolve Rule - * 1. Resolve the full name by combine the paths("x"."y"."z") as whole ("x.y.z"). - * 2. Resolve the path recursively through ExprValue. - * - * @param value {@link ExprTupleValue}. - * @return {@link ExprTupleValue}. - */ - public ExprValue resolve(ExprTupleValue value) { - return resolve(value, paths); - } + @Override + public String toString() { + return attr; + } - private ExprValue resolve(ExprValue value, List paths) { - ExprValue wholePathValue = value.keyValue(String.join(PATH_SEP, paths)); - // For array types only first index currently supported. - if (value.type().equals(ExprCoreType.ARRAY)) { - wholePathValue = value.collectionValue().get(0).keyValue(paths.get(0)); + /** + * Resolve the ExprValue from {@link ExprTupleValue} using paths. + * Considering the following sample data. + * { + * "name": "bob smith" + * "project.year": 1990, + * "project": { + * "year": "2020" + * } + * "address": { + * "state": "WA", + * "city": "seattle", + * "project.year": 1990 + * } + * "address.local": { + * "state": "WA", + * } + * } + * The paths could be + * 1. top level, e.g. "name", which will be resolved as "bob smith" + * 2. multiple paths, e.g. "name.address.state", which will be resolved as "WA" + * 3. special case, the "." is the path separator, but it is possible that the path include + * ".", for handling this use case, we define the resolve rule as bellow, e.g. "project.year" is + * resolved as 1990 instead of 2020. Note. This logic only applied top level none object field. + * e.g. "address.local.state" been resolved to Missing. but "address.project.year" could been + * resolved as 1990. + * + *

Resolve Rule + * 1. Resolve the full name by combine the paths("x"."y"."z") as whole ("x.y.z"). + * 2. Resolve the path recursively through ExprValue. + * + * @param value {@link ExprTupleValue}. + * @return {@link ExprTupleValue}. + */ + public ExprValue resolve(ExprTupleValue value) { + return resolve(value, paths); } - if (!wholePathValue.isMissing() || paths.size() == 1) { - return wholePathValue; - } else { - return resolve(value.keyValue(paths.get(0)), paths.subList(1, paths.size())); + private ExprValue resolve(ExprValue value, List paths) { + ExprValue wholePathValue = value.keyValue(String.join(PATH_SEP, paths)); + // For array types only first index currently supported. + if (value.type().equals(ExprCoreType.ARRAY)) { + wholePathValue = value.collectionValue().get(0).keyValue(paths.get(0)); + } + + if (!wholePathValue.isMissing() || paths.size() == 1) { + return wholePathValue; + } else { + return resolve(value.keyValue(paths.get(0)), paths.subList(1, paths.size())); + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java index 345c6c00dd..9bee65420f 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.aggregation; import org.opensearch.sql.data.model.ExprValue; @@ -13,8 +12,8 @@ * Maintain the state when {@link Aggregator} iterate on the {@link BindingTuple}. */ public interface AggregationState { - /** - * Get {@link ExprValue} result. - */ - ExprValue result(); + /** + * Get {@link ExprValue} result. + */ + ExprValue result(); } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java index a122ea6540..4476b547e4 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.aggregation; import java.util.List; @@ -33,76 +32,74 @@ */ @EqualsAndHashCode @RequiredArgsConstructor -public abstract class Aggregator - implements FunctionImplementation, Expression { - @Getter - private final FunctionName functionName; - @Getter - private final List arguments; - protected final ExprCoreType returnType; - @Setter - @Getter - @Accessors(fluent = true) - protected Expression condition; - @Setter - @Getter - @Accessors(fluent = true) - protected Boolean distinct = false; +public abstract class Aggregator implements FunctionImplementation, Expression { + @Getter + private final FunctionName functionName; + @Getter + private final List arguments; + protected final ExprCoreType returnType; + @Setter + @Getter + @Accessors(fluent = true) + protected Expression condition; + @Setter + @Getter + @Accessors(fluent = true) + protected Boolean distinct = false; - /** - * Create an {@link AggregationState} which will be used for aggregation. - */ - public abstract S create(); + /** + * Create an {@link AggregationState} which will be used for aggregation. + */ + public abstract S create(); - /** - * Iterate on {@link ExprValue}. - * @param value {@link ExprValue} - * @param state {@link AggregationState} - * @return {@link AggregationState} - */ - protected abstract S iterate(ExprValue value, S state); + /** + * Iterate on {@link ExprValue}. + * @param value {@link ExprValue} + * @param state {@link AggregationState} + * @return {@link AggregationState} + */ + protected abstract S iterate(ExprValue value, S state); - /** - * Let the aggregator iterate on the {@link BindingTuple} - * To filter out ExprValues that are missing, null or cannot satisfy {@link #condition} - * Before the specific aggregator iterating ExprValue in the tuple. - * - * @param tuple {@link BindingTuple} - * @param state {@link AggregationState} - * @return {@link AggregationState} - */ - public S iterate(BindingTuple tuple, S state) { - ExprValue value = getArguments().get(0).valueOf(tuple); - if (value.isNull() || value.isMissing() || !conditionValue(tuple)) { - return state; + /** + * Let the aggregator iterate on the {@link BindingTuple} + * To filter out ExprValues that are missing, null or cannot satisfy {@link #condition} + * Before the specific aggregator iterating ExprValue in the tuple. + * + * @param tuple {@link BindingTuple} + * @param state {@link AggregationState} + * @return {@link AggregationState} + */ + public S iterate(BindingTuple tuple, S state) { + ExprValue value = getArguments().get(0).valueOf(tuple); + if (value.isNull() || value.isMissing() || !conditionValue(tuple)) { + return state; + } + return iterate(value, state); } - return iterate(value, state); - } - @Override - public ExprValue valueOf(Environment valueEnv) { - throw new ExpressionEvaluationException( - String.format("can't evaluate on aggregator: %s", functionName)); - } + @Override + public ExprValue valueOf(Environment valueEnv) { + throw new ExpressionEvaluationException(String.format("can't evaluate on aggregator: %s", functionName)); + } - @Override - public ExprType type() { - return returnType; - } + @Override + public ExprType type() { + return returnType; + } - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitAggregator(this, context); - } + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitAggregator(this, context); + } - /** - * Util method to get value of condition in aggregation filter. - */ - public boolean conditionValue(BindingTuple tuple) { - if (condition == null) { - return true; + /** + * Util method to get value of condition in aggregation filter. + */ + public boolean conditionValue(BindingTuple tuple) { + if (condition == null) { + return true; + } + return ExprValueUtils.getBooleanValue(condition.valueOf(tuple)); } - return ExprValueUtils.getBooleanValue(condition.valueOf(tuple)); - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index a24eeca1c1..ad47359cdf 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.aggregation; import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; @@ -44,174 +43,234 @@ */ @UtilityClass public class AggregatorFunction { - /** - * Register Aggregation Function. - * - * @param repository {@link BuiltinFunctionRepository}. - */ - public static void register(BuiltinFunctionRepository repository) { - repository.register(avg()); - repository.register(sum()); - repository.register(count()); - repository.register(min()); - repository.register(max()); - repository.register(varSamp()); - repository.register(varPop()); - repository.register(stddevSamp()); - repository.register(stddevPop()); - repository.register(take()); - } + /** + * Register Aggregation Function. + * + * @param repository {@link BuiltinFunctionRepository}. + */ + public static void register(BuiltinFunctionRepository repository) { + repository.register(avg()); + repository.register(sum()); + repository.register(count()); + repository.register(min()); + repository.register(max()); + repository.register(varSamp()); + repository.register(varPop()); + repository.register(stddevSamp()); + repository.register(stddevPop()); + repository.register(take()); + } - private static DefaultFunctionResolver avg() { - FunctionName functionName = BuiltinFunctionName.AVG.getName(); - return new DefaultFunctionResolver( - functionName, - new ImmutableMap.Builder() - .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - (functionProperties, arguments) -> new AvgAggregator(arguments, DOUBLE)) - .put(new FunctionSignature(functionName, Collections.singletonList(DATE)), - (functionProperties, arguments) -> new AvgAggregator(arguments, DATE)) - .put(new FunctionSignature(functionName, Collections.singletonList(DATETIME)), - (functionProperties, arguments) -> new AvgAggregator(arguments, DATETIME)) - .put(new FunctionSignature(functionName, Collections.singletonList(TIME)), - (functionProperties, arguments) -> new AvgAggregator(arguments, TIME)) - .put(new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), - (functionProperties, arguments) -> new AvgAggregator(arguments, TIMESTAMP)) - .build() - ); - } + private static DefaultFunctionResolver avg() { + FunctionName functionName = BuiltinFunctionName.AVG.getName(); + return new DefaultFunctionResolver( + functionName, + new ImmutableMap.Builder().put( + new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + (functionProperties, arguments) -> new AvgAggregator(arguments, DOUBLE) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(DATE)), + (functionProperties, arguments) -> new AvgAggregator(arguments, DATE) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(DATETIME)), + (functionProperties, arguments) -> new AvgAggregator(arguments, DATETIME) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(TIME)), + (functionProperties, arguments) -> new AvgAggregator(arguments, TIME) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), + (functionProperties, arguments) -> new AvgAggregator(arguments, TIMESTAMP) + ) + .build() + ); + } - private static DefaultFunctionResolver count() { - FunctionName functionName = BuiltinFunctionName.COUNT.getName(); - DefaultFunctionResolver functionResolver = new DefaultFunctionResolver(functionName, - ExprCoreType.coreTypes().stream().collect(Collectors.toMap( - type -> new FunctionSignature(functionName, Collections.singletonList(type)), - type -> (functionProperties, arguments) -> new CountAggregator(arguments, INTEGER)))); - return functionResolver; - } + private static DefaultFunctionResolver count() { + FunctionName functionName = BuiltinFunctionName.COUNT.getName(); + DefaultFunctionResolver functionResolver = new DefaultFunctionResolver( + functionName, + ExprCoreType.coreTypes() + .stream() + .collect( + Collectors.toMap( + type -> new FunctionSignature(functionName, Collections.singletonList(type)), + type -> (functionProperties, arguments) -> new CountAggregator(arguments, INTEGER) + ) + ) + ); + return functionResolver; + } - private static DefaultFunctionResolver sum() { - FunctionName functionName = BuiltinFunctionName.SUM.getName(); - return new DefaultFunctionResolver( - functionName, - new ImmutableMap.Builder() - .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), - (functionProperties, arguments) -> new SumAggregator(arguments, INTEGER)) - .put(new FunctionSignature(functionName, Collections.singletonList(LONG)), - (functionProperties, arguments) -> new SumAggregator(arguments, LONG)) - .put(new FunctionSignature(functionName, Collections.singletonList(FLOAT)), - (functionProperties, arguments) -> new SumAggregator(arguments, FLOAT)) - .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - (functionProperties, arguments) -> new SumAggregator(arguments, DOUBLE)) - .build() - ); - } + private static DefaultFunctionResolver sum() { + FunctionName functionName = BuiltinFunctionName.SUM.getName(); + return new DefaultFunctionResolver( + functionName, + new ImmutableMap.Builder().put( + new FunctionSignature(functionName, Collections.singletonList(INTEGER)), + (functionProperties, arguments) -> new SumAggregator(arguments, INTEGER) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(LONG)), + (functionProperties, arguments) -> new SumAggregator(arguments, LONG) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(FLOAT)), + (functionProperties, arguments) -> new SumAggregator(arguments, FLOAT) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + (functionProperties, arguments) -> new SumAggregator(arguments, DOUBLE) + ) + .build() + ); + } - private static DefaultFunctionResolver min() { - FunctionName functionName = BuiltinFunctionName.MIN.getName(); - return new DefaultFunctionResolver( - functionName, - new ImmutableMap.Builder() - .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), - (functionProperties, arguments) -> new MinAggregator(arguments, INTEGER)) - .put(new FunctionSignature(functionName, Collections.singletonList(LONG)), - (functionProperties, arguments) -> new MinAggregator(arguments, LONG)) - .put(new FunctionSignature(functionName, Collections.singletonList(FLOAT)), - (functionProperties, arguments) -> new MinAggregator(arguments, FLOAT)) - .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - (functionProperties, arguments) -> new MinAggregator(arguments, DOUBLE)) - .put(new FunctionSignature(functionName, Collections.singletonList(STRING)), - (functionProperties, arguments) -> new MinAggregator(arguments, STRING)) - .put(new FunctionSignature(functionName, Collections.singletonList(DATE)), - (functionProperties, arguments) -> new MinAggregator(arguments, DATE)) - .put(new FunctionSignature(functionName, Collections.singletonList(DATETIME)), - (functionProperties, arguments) -> new MinAggregator(arguments, DATETIME)) - .put(new FunctionSignature(functionName, Collections.singletonList(TIME)), - (functionProperties, arguments) -> new MinAggregator(arguments, TIME)) - .put(new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), - (functionProperties, arguments) -> new MinAggregator(arguments, TIMESTAMP)) - .build()); - } + private static DefaultFunctionResolver min() { + FunctionName functionName = BuiltinFunctionName.MIN.getName(); + return new DefaultFunctionResolver( + functionName, + new ImmutableMap.Builder().put( + new FunctionSignature(functionName, Collections.singletonList(INTEGER)), + (functionProperties, arguments) -> new MinAggregator(arguments, INTEGER) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(LONG)), + (functionProperties, arguments) -> new MinAggregator(arguments, LONG) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(FLOAT)), + (functionProperties, arguments) -> new MinAggregator(arguments, FLOAT) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + (functionProperties, arguments) -> new MinAggregator(arguments, DOUBLE) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(STRING)), + (functionProperties, arguments) -> new MinAggregator(arguments, STRING) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(DATE)), + (functionProperties, arguments) -> new MinAggregator(arguments, DATE) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(DATETIME)), + (functionProperties, arguments) -> new MinAggregator(arguments, DATETIME) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(TIME)), + (functionProperties, arguments) -> new MinAggregator(arguments, TIME) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), + (functionProperties, arguments) -> new MinAggregator(arguments, TIMESTAMP) + ) + .build() + ); + } - private static DefaultFunctionResolver max() { - FunctionName functionName = BuiltinFunctionName.MAX.getName(); - return new DefaultFunctionResolver( - functionName, - new ImmutableMap.Builder() - .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), - (functionProperties, arguments) -> new MaxAggregator(arguments, INTEGER)) - .put(new FunctionSignature(functionName, Collections.singletonList(LONG)), - (functionProperties, arguments) -> new MaxAggregator(arguments, LONG)) - .put(new FunctionSignature(functionName, Collections.singletonList(FLOAT)), - (functionProperties, arguments) -> new MaxAggregator(arguments, FLOAT)) - .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - (functionProperties, arguments) -> new MaxAggregator(arguments, DOUBLE)) - .put(new FunctionSignature(functionName, Collections.singletonList(STRING)), - (functionProperties, arguments) -> new MaxAggregator(arguments, STRING)) - .put(new FunctionSignature(functionName, Collections.singletonList(DATE)), - (functionProperties, arguments) -> new MaxAggregator(arguments, DATE)) - .put(new FunctionSignature(functionName, Collections.singletonList(DATETIME)), - (functionProperties, arguments) -> new MaxAggregator(arguments, DATETIME)) - .put(new FunctionSignature(functionName, Collections.singletonList(TIME)), - (functionProperties, arguments) -> new MaxAggregator(arguments, TIME)) - .put(new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), - (functionProperties, arguments) -> new MaxAggregator(arguments, TIMESTAMP)) - .build() - ); - } + private static DefaultFunctionResolver max() { + FunctionName functionName = BuiltinFunctionName.MAX.getName(); + return new DefaultFunctionResolver( + functionName, + new ImmutableMap.Builder().put( + new FunctionSignature(functionName, Collections.singletonList(INTEGER)), + (functionProperties, arguments) -> new MaxAggregator(arguments, INTEGER) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(LONG)), + (functionProperties, arguments) -> new MaxAggregator(arguments, LONG) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(FLOAT)), + (functionProperties, arguments) -> new MaxAggregator(arguments, FLOAT) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + (functionProperties, arguments) -> new MaxAggregator(arguments, DOUBLE) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(STRING)), + (functionProperties, arguments) -> new MaxAggregator(arguments, STRING) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(DATE)), + (functionProperties, arguments) -> new MaxAggregator(arguments, DATE) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(DATETIME)), + (functionProperties, arguments) -> new MaxAggregator(arguments, DATETIME) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(TIME)), + (functionProperties, arguments) -> new MaxAggregator(arguments, TIME) + ) + .put( + new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), + (functionProperties, arguments) -> new MaxAggregator(arguments, TIMESTAMP) + ) + .build() + ); + } - private static DefaultFunctionResolver varSamp() { - FunctionName functionName = BuiltinFunctionName.VARSAMP.getName(); - return new DefaultFunctionResolver( - functionName, - new ImmutableMap.Builder() - .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - (functionProperties, arguments) -> varianceSample(arguments, DOUBLE)) - .build() - ); - } + private static DefaultFunctionResolver varSamp() { + FunctionName functionName = BuiltinFunctionName.VARSAMP.getName(); + return new DefaultFunctionResolver( + functionName, + new ImmutableMap.Builder().put( + new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + (functionProperties, arguments) -> varianceSample(arguments, DOUBLE) + ).build() + ); + } - private static DefaultFunctionResolver varPop() { - FunctionName functionName = BuiltinFunctionName.VARPOP.getName(); - return new DefaultFunctionResolver( - functionName, - new ImmutableMap.Builder() - .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - (functionProperties, arguments) -> variancePopulation(arguments, DOUBLE)) - .build() - ); - } + private static DefaultFunctionResolver varPop() { + FunctionName functionName = BuiltinFunctionName.VARPOP.getName(); + return new DefaultFunctionResolver( + functionName, + new ImmutableMap.Builder().put( + new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + (functionProperties, arguments) -> variancePopulation(arguments, DOUBLE) + ).build() + ); + } - private static DefaultFunctionResolver stddevSamp() { - FunctionName functionName = BuiltinFunctionName.STDDEV_SAMP.getName(); - return new DefaultFunctionResolver( - functionName, - new ImmutableMap.Builder() - .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - (functionProperties, arguments) -> stddevSample(arguments, DOUBLE)) - .build() - ); - } + private static DefaultFunctionResolver stddevSamp() { + FunctionName functionName = BuiltinFunctionName.STDDEV_SAMP.getName(); + return new DefaultFunctionResolver( + functionName, + new ImmutableMap.Builder().put( + new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + (functionProperties, arguments) -> stddevSample(arguments, DOUBLE) + ).build() + ); + } - private static DefaultFunctionResolver stddevPop() { - FunctionName functionName = BuiltinFunctionName.STDDEV_POP.getName(); - return new DefaultFunctionResolver( - functionName, - new ImmutableMap.Builder() - .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - (functionProperties, arguments) -> stddevPopulation(arguments, DOUBLE)) - .build() - ); - } + private static DefaultFunctionResolver stddevPop() { + FunctionName functionName = BuiltinFunctionName.STDDEV_POP.getName(); + return new DefaultFunctionResolver( + functionName, + new ImmutableMap.Builder().put( + new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + (functionProperties, arguments) -> stddevPopulation(arguments, DOUBLE) + ).build() + ); + } - private static DefaultFunctionResolver take() { - FunctionName functionName = BuiltinFunctionName.TAKE.getName(); - DefaultFunctionResolver functionResolver = new DefaultFunctionResolver(functionName, - new ImmutableMap.Builder() - .put(new FunctionSignature(functionName, ImmutableList.of(STRING, INTEGER)), - (functionProperties, arguments) -> new TakeAggregator(arguments, ARRAY)) - .build()); - return functionResolver; - } + private static DefaultFunctionResolver take() { + FunctionName functionName = BuiltinFunctionName.TAKE.getName(); + DefaultFunctionResolver functionResolver = new DefaultFunctionResolver( + functionName, + new ImmutableMap.Builder().put( + new FunctionSignature(functionName, ImmutableList.of(STRING, INTEGER)), + (functionProperties, arguments) -> new TakeAggregator(arguments, ARRAY) + ).build() + ); + return functionResolver; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java index a899a6b45b..f709e45349 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.aggregation; import static java.time.temporal.ChronoUnit.MILLIS; @@ -32,160 +31,153 @@ */ public class AvgAggregator extends Aggregator { - /** - * To process by different ways different data types, we need to store the type. - * Input data has the same type as the result. - */ - private final ExprCoreType dataType; - - public AvgAggregator(List arguments, ExprCoreType returnType) { - super(BuiltinFunctionName.AVG.getName(), arguments, returnType); - dataType = returnType; - } - - @Override - public AvgState create() { - switch (dataType) { - case DATE: - return new DateAvgState(); - case DATETIME: - return new DateTimeAvgState(); - case TIMESTAMP: - return new TimestampAvgState(); - case TIME: - return new TimeAvgState(); - case DOUBLE: - return new DoubleAvgState(); - default: //unreachable code - we don't expose signatures for unsupported types - throw new IllegalArgumentException( - String.format("avg aggregation over %s type is not supported", dataType)); - } - } - - @Override - protected AvgState iterate(ExprValue value, AvgState state) { - return state.iterate(value); - } - - @Override - public String toString() { - return String.format(Locale.ROOT, "avg(%s)", format(getArguments())); - } - - /** - * Average State. - */ - protected abstract static class AvgState implements AggregationState { - protected ExprValue count; - protected ExprValue total; - - AvgState() { - this.count = new ExprIntegerValue(0); - this.total = new ExprDoubleValue(0D); - } - - @Override - public abstract ExprValue result(); - - protected AvgState iterate(ExprValue value) { - count = DSL.add(DSL.literal(count), DSL.literal(1)).valueOf(); - return this; - } - } + /** + * To process by different ways different data types, we need to store the type. + * Input data has the same type as the result. + */ + private final ExprCoreType dataType; - protected static class DoubleAvgState extends AvgState { - @Override - public ExprValue result() { - if (0 == count.integerValue()) { - return ExprNullValue.of(); - } - return DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf(); + public AvgAggregator(List arguments, ExprCoreType returnType) { + super(BuiltinFunctionName.AVG.getName(), arguments, returnType); + dataType = returnType; } @Override - protected AvgState iterate(ExprValue value) { - total = DSL.add(DSL.literal(total), DSL.literal(value)).valueOf(); - return super.iterate(value); + public AvgState create() { + switch (dataType) { + case DATE: + return new DateAvgState(); + case DATETIME: + return new DateTimeAvgState(); + case TIMESTAMP: + return new TimestampAvgState(); + case TIME: + return new TimeAvgState(); + case DOUBLE: + return new DoubleAvgState(); + default: // unreachable code - we don't expose signatures for unsupported types + throw new IllegalArgumentException(String.format("avg aggregation over %s type is not supported", dataType)); + } } - } - protected static class DateAvgState extends AvgState { @Override - public ExprValue result() { - if (0 == count.integerValue()) { - return ExprNullValue.of(); - } - - return new ExprDateValue( - new ExprTimestampValue(Instant.ofEpochMilli( - DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf().longValue())) - .dateValue()); + protected AvgState iterate(ExprValue value, AvgState state) { + return state.iterate(value); } @Override - protected AvgState iterate(ExprValue value) { - total = DSL.add(DSL.literal(total), DSL.literal(value.timestampValue().toEpochMilli())) - .valueOf(); - return super.iterate(value); + public String toString() { + return String.format(Locale.ROOT, "avg(%s)", format(getArguments())); } - } - protected static class DateTimeAvgState extends AvgState { - @Override - public ExprValue result() { - if (0 == count.integerValue()) { - return ExprNullValue.of(); - } - - return new ExprDatetimeValue( - new ExprTimestampValue(Instant.ofEpochMilli( - DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf().longValue())) - .datetimeValue()); + /** + * Average State. + */ + protected abstract static class AvgState implements AggregationState { + protected ExprValue count; + protected ExprValue total; + + AvgState() { + this.count = new ExprIntegerValue(0); + this.total = new ExprDoubleValue(0D); + } + + @Override + public abstract ExprValue result(); + + protected AvgState iterate(ExprValue value) { + count = DSL.add(DSL.literal(count), DSL.literal(1)).valueOf(); + return this; + } } - @Override - protected AvgState iterate(ExprValue value) { - total = DSL.add(DSL.literal(total), DSL.literal(value.timestampValue().toEpochMilli())) - .valueOf(); - return super.iterate(value); + protected static class DoubleAvgState extends AvgState { + @Override + public ExprValue result() { + if (0 == count.integerValue()) { + return ExprNullValue.of(); + } + return DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf(); + } + + @Override + protected AvgState iterate(ExprValue value) { + total = DSL.add(DSL.literal(total), DSL.literal(value)).valueOf(); + return super.iterate(value); + } } - } - - protected static class TimestampAvgState extends AvgState { - @Override - public ExprValue result() { - if (0 == count.integerValue()) { - return ExprNullValue.of(); - } - return new ExprTimestampValue(Instant.ofEpochMilli( - DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf().longValue())); + protected static class DateAvgState extends AvgState { + @Override + public ExprValue result() { + if (0 == count.integerValue()) { + return ExprNullValue.of(); + } + + return new ExprDateValue( + new ExprTimestampValue(Instant.ofEpochMilli(DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf().longValue())) + .dateValue() + ); + } + + @Override + protected AvgState iterate(ExprValue value) { + total = DSL.add(DSL.literal(total), DSL.literal(value.timestampValue().toEpochMilli())).valueOf(); + return super.iterate(value); + } } - @Override - protected AvgState iterate(ExprValue value) { - total = DSL.add(DSL.literal(total), DSL.literal(value.timestampValue().toEpochMilli())) - .valueOf(); - return super.iterate(value); + protected static class DateTimeAvgState extends AvgState { + @Override + public ExprValue result() { + if (0 == count.integerValue()) { + return ExprNullValue.of(); + } + + return new ExprDatetimeValue( + new ExprTimestampValue(Instant.ofEpochMilli(DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf().longValue())) + .datetimeValue() + ); + } + + @Override + protected AvgState iterate(ExprValue value) { + total = DSL.add(DSL.literal(total), DSL.literal(value.timestampValue().toEpochMilli())).valueOf(); + return super.iterate(value); + } } - } - - protected static class TimeAvgState extends AvgState { - @Override - public ExprValue result() { - if (0 == count.integerValue()) { - return ExprNullValue.of(); - } - return new ExprTimeValue(LocalTime.MIN.plus( - DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf().longValue(), MILLIS)); + protected static class TimestampAvgState extends AvgState { + @Override + public ExprValue result() { + if (0 == count.integerValue()) { + return ExprNullValue.of(); + } + + return new ExprTimestampValue(Instant.ofEpochMilli(DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf().longValue())); + } + + @Override + protected AvgState iterate(ExprValue value) { + total = DSL.add(DSL.literal(total), DSL.literal(value.timestampValue().toEpochMilli())).valueOf(); + return super.iterate(value); + } } - @Override - protected AvgState iterate(ExprValue value) { - total = DSL.add(DSL.literal(total), - DSL.literal(MILLIS.between(LocalTime.MIN, value.timeValue()))).valueOf(); - return super.iterate(value); + protected static class TimeAvgState extends AvgState { + @Override + public ExprValue result() { + if (0 == count.integerValue()) { + return ExprNullValue.of(); + } + + return new ExprTimeValue(LocalTime.MIN.plus(DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf().longValue(), MILLIS)); + } + + @Override + protected AvgState iterate(ExprValue value) { + total = DSL.add(DSL.literal(total), DSL.literal(MILLIS.between(LocalTime.MIN, value.timeValue()))).valueOf(); + return super.iterate(value); + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index 813842cadc..acd1dd99a5 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.aggregation; import static org.opensearch.sql.utils.ExpressionUtils.format; @@ -21,57 +20,57 @@ public class CountAggregator extends Aggregator { - public CountAggregator(List arguments, ExprCoreType returnType) { - super(BuiltinFunctionName.COUNT.getName(), arguments, returnType); - } + public CountAggregator(List arguments, ExprCoreType returnType) { + super(BuiltinFunctionName.COUNT.getName(), arguments, returnType); + } - @Override - public CountAggregator.CountState create() { - return distinct ? new DistinctCountState() : new CountState(); - } + @Override + public CountAggregator.CountState create() { + return distinct ? new DistinctCountState() : new CountState(); + } - @Override - protected CountState iterate(ExprValue value, CountState state) { - state.count(value); - return state; - } + @Override + protected CountState iterate(ExprValue value, CountState state) { + state.count(value); + return state; + } - @Override - public String toString() { - return distinct - ? String.format(Locale.ROOT, "count(distinct %s)", format(getArguments())) - : String.format(Locale.ROOT, "count(%s)", format(getArguments())); - } + @Override + public String toString() { + return distinct + ? String.format(Locale.ROOT, "count(distinct %s)", format(getArguments())) + : String.format(Locale.ROOT, "count(%s)", format(getArguments())); + } - /** - * Count State. - */ - protected static class CountState implements AggregationState { - protected int count; + /** + * Count State. + */ + protected static class CountState implements AggregationState { + protected int count; - CountState() { - this.count = 0; - } + CountState() { + this.count = 0; + } - public void count(ExprValue value) { - count++; - } + public void count(ExprValue value) { + count++; + } - @Override - public ExprValue result() { - return ExprValueUtils.integerValue(count); + @Override + public ExprValue result() { + return ExprValueUtils.integerValue(count); + } } - } - protected static class DistinctCountState extends CountState { - private final Set distinctValues = new HashSet<>(); + protected static class DistinctCountState extends CountState { + private final Set distinctValues = new HashSet<>(); - @Override - public void count(ExprValue value) { - if (!distinctValues.contains(value)) { - distinctValues.add(value); - count++; - } + @Override + public void count(ExprValue value) { + if (!distinctValues.contains(value)) { + distinctValues.add(value); + count++; + } + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java index e9123c0ac2..5e8e4b931d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.aggregation; import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_NULL; @@ -17,42 +16,40 @@ public class MaxAggregator extends Aggregator { - public MaxAggregator(List arguments, ExprCoreType returnType) { - super(BuiltinFunctionName.MAX.getName(), arguments, returnType); - } + public MaxAggregator(List arguments, ExprCoreType returnType) { + super(BuiltinFunctionName.MAX.getName(), arguments, returnType); + } - @Override - public MaxState create() { - return new MaxState(); - } + @Override + public MaxState create() { + return new MaxState(); + } - @Override - protected MaxState iterate(ExprValue value, MaxState state) { - state.max(value); - return state; - } + @Override + protected MaxState iterate(ExprValue value, MaxState state) { + state.max(value); + return state; + } - @Override - public String toString() { - return String.format("max(%s)", format(getArguments())); - } + @Override + public String toString() { + return String.format("max(%s)", format(getArguments())); + } - protected static class MaxState implements AggregationState { - private ExprValue maxResult; + protected static class MaxState implements AggregationState { + private ExprValue maxResult; - MaxState() { - maxResult = LITERAL_NULL; - } + MaxState() { + maxResult = LITERAL_NULL; + } - public void max(ExprValue value) { - maxResult = maxResult.isNull() ? value - : (maxResult.compareTo(value) > 0) - ? maxResult : value; - } + public void max(ExprValue value) { + maxResult = maxResult.isNull() ? value : (maxResult.compareTo(value) > 0) ? maxResult : value; + } - @Override - public ExprValue result() { - return maxResult; + @Override + public ExprValue result() { + return maxResult; + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java index 897fe857ff..63bee9285b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.aggregation; import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_NULL; @@ -21,43 +20,40 @@ */ public class MinAggregator extends Aggregator { - public MinAggregator(List arguments, ExprCoreType returnType) { - super(BuiltinFunctionName.MIN.getName(), arguments, returnType); - } - + public MinAggregator(List arguments, ExprCoreType returnType) { + super(BuiltinFunctionName.MIN.getName(), arguments, returnType); + } - @Override - public MinState create() { - return new MinState(); - } + @Override + public MinState create() { + return new MinState(); + } - @Override - protected MinState iterate(ExprValue value, MinState state) { - state.min(value); - return state; - } + @Override + protected MinState iterate(ExprValue value, MinState state) { + state.min(value); + return state; + } - @Override - public String toString() { - return String.format("min(%s)", format(getArguments())); - } + @Override + public String toString() { + return String.format("min(%s)", format(getArguments())); + } - protected static class MinState implements AggregationState { - private ExprValue minResult; + protected static class MinState implements AggregationState { + private ExprValue minResult; - MinState() { - minResult = LITERAL_NULL; - } + MinState() { + minResult = LITERAL_NULL; + } - public void min(ExprValue value) { - minResult = minResult.isNull() ? value - : (minResult.compareTo(value) < 0) - ? minResult : value; - } + public void min(ExprValue value) { + minResult = minResult.isNull() ? value : (minResult.compareTo(value) < 0) ? minResult : value; + } - @Override - public ExprValue result() { - return minResult; + @Override + public ExprValue result() { + return minResult; + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java index 510c5d1e45..14446f8b81 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.aggregation; import lombok.EqualsAndHashCode; @@ -19,61 +18,59 @@ @EqualsAndHashCode(callSuper = false) public class NamedAggregator extends Aggregator { - /** - * Aggregator name. - */ - private final String name; + /** + * Aggregator name. + */ + private final String name; - /** - * Aggregator that being named. - */ - @Getter - private final Aggregator delegated; + /** + * Aggregator that being named. + */ + @Getter + private final Aggregator delegated; - /** - * NamedAggregator. - * The aggregator properties {@link #condition} and {@link #distinct} - * are inherited by named aggregator to avoid errors introduced by the property inconsistency. - * - * @param name name - * @param delegated delegated - */ - public NamedAggregator( - String name, - Aggregator delegated) { - super(delegated.getFunctionName(), delegated.getArguments(), delegated.returnType); - this.name = name; - this.delegated = delegated; - this.condition = delegated.condition; - this.distinct = delegated.distinct; - } + /** + * NamedAggregator. + * The aggregator properties {@link #condition} and {@link #distinct} + * are inherited by named aggregator to avoid errors introduced by the property inconsistency. + * + * @param name name + * @param delegated delegated + */ + public NamedAggregator(String name, Aggregator delegated) { + super(delegated.getFunctionName(), delegated.getArguments(), delegated.returnType); + this.name = name; + this.delegated = delegated; + this.condition = delegated.condition; + this.distinct = delegated.distinct; + } - @Override - public AggregationState create() { - return delegated.create(); - } + @Override + public AggregationState create() { + return delegated.create(); + } - @Override - protected AggregationState iterate(ExprValue value, AggregationState state) { - return delegated.iterate(value, state); - } + @Override + protected AggregationState iterate(ExprValue value, AggregationState state) { + return delegated.iterate(value, state); + } - /** - * Get expression name using name or its alias (if it's present). - * @return expression name - */ - public String getName() { - return name; - } + /** + * Get expression name using name or its alias (if it's present). + * @return expression name + */ + public String getName() { + return name; + } - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitNamedAggregator(this, context); - } + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitNamedAggregator(this, context); + } - @Override - public String toString() { - return getName(); - } + @Override + public String toString() { + return getName(); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java index 0cd8494449..5d74864bbe 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java @@ -31,80 +31,70 @@ */ public class StdDevAggregator extends Aggregator { - private final boolean isSampleStdDev; - - /** - * Build Population Variance {@link VarianceAggregator}. - */ - public static Aggregator stddevPopulation(List arguments, - ExprCoreType returnType) { - return new StdDevAggregator(false, arguments, returnType); - } - - /** - * Build Sample Variance {@link VarianceAggregator}. - */ - public static Aggregator stddevSample(List arguments, - ExprCoreType returnType) { - return new StdDevAggregator(true, arguments, returnType); - } - - /** - * VarianceAggregator constructor. - * - * @param isSampleStdDev true for sample standard deviation aggregator, false for population - * standard deviation aggregator. - * @param arguments aggregator arguments. - * @param returnType aggregator return types. - */ - public StdDevAggregator( - Boolean isSampleStdDev, List arguments, ExprCoreType returnType) { - super( - isSampleStdDev - ? BuiltinFunctionName.STDDEV_SAMP.getName() - : BuiltinFunctionName.STDDEV_POP.getName(), - arguments, - returnType); - this.isSampleStdDev = isSampleStdDev; - } - - @Override - public StdDevAggregator.StdDevState create() { - return new StdDevAggregator.StdDevState(isSampleStdDev); - } - - @Override - protected StdDevAggregator.StdDevState iterate(ExprValue value, - StdDevAggregator.StdDevState state) { - state.evaluate(value); - return state; - } - - @Override - public String toString() { - return StringUtils.format( - "%s(%s)", isSampleStdDev ? "stddev_samp" : "stddev_pop", format(getArguments())); - } - - protected static class StdDevState implements AggregationState { - - private final StandardDeviation standardDeviation; - - private final List values = new ArrayList<>(); - - public StdDevState(boolean isSampleStdDev) { - this.standardDeviation = new StandardDeviation(isSampleStdDev); + private final boolean isSampleStdDev; + + /** + * Build Population Variance {@link VarianceAggregator}. + */ + public static Aggregator stddevPopulation(List arguments, ExprCoreType returnType) { + return new StdDevAggregator(false, arguments, returnType); + } + + /** + * Build Sample Variance {@link VarianceAggregator}. + */ + public static Aggregator stddevSample(List arguments, ExprCoreType returnType) { + return new StdDevAggregator(true, arguments, returnType); + } + + /** + * VarianceAggregator constructor. + * + * @param isSampleStdDev true for sample standard deviation aggregator, false for population + * standard deviation aggregator. + * @param arguments aggregator arguments. + * @param returnType aggregator return types. + */ + public StdDevAggregator(Boolean isSampleStdDev, List arguments, ExprCoreType returnType) { + super(isSampleStdDev ? BuiltinFunctionName.STDDEV_SAMP.getName() : BuiltinFunctionName.STDDEV_POP.getName(), arguments, returnType); + this.isSampleStdDev = isSampleStdDev; + } + + @Override + public StdDevAggregator.StdDevState create() { + return new StdDevAggregator.StdDevState(isSampleStdDev); } - public void evaluate(ExprValue value) { - values.add(value.doubleValue()); + @Override + protected StdDevAggregator.StdDevState iterate(ExprValue value, StdDevAggregator.StdDevState state) { + state.evaluate(value); + return state; } @Override - public ExprValue result() { - return values.size() == 0 - ? ExprNullValue.of() - : doubleValue(standardDeviation.evaluate(values.stream().mapToDouble(d -> d).toArray())); + public String toString() { + return StringUtils.format("%s(%s)", isSampleStdDev ? "stddev_samp" : "stddev_pop", format(getArguments())); + } + + protected static class StdDevState implements AggregationState { + + private final StandardDeviation standardDeviation; + + private final List values = new ArrayList<>(); + + public StdDevState(boolean isSampleStdDev) { + this.standardDeviation = new StandardDeviation(isSampleStdDev); + } + + public void evaluate(ExprValue value) { + values.add(value.doubleValue()); + } + + @Override + public ExprValue result() { + return values.size() == 0 + ? ExprNullValue.of() + : doubleValue(standardDeviation.evaluate(values.stream().mapToDouble(d -> d).toArray())); + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java index f5b042034a..7bb572fd56 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.aggregation; import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; @@ -33,68 +32,67 @@ */ public class SumAggregator extends Aggregator { - public SumAggregator(List arguments, ExprCoreType returnType) { - super(BuiltinFunctionName.SUM.getName(), arguments, returnType); - } - - @Override - public SumState create() { - return new SumState(returnType); - } - - @Override - protected SumState iterate(ExprValue value, SumState state) { - state.isEmptyCollection = false; - state.add(value); - return state; - } - - @Override - public String toString() { - return String.format(Locale.ROOT, "sum(%s)", format(getArguments())); - } + public SumAggregator(List arguments, ExprCoreType returnType) { + super(BuiltinFunctionName.SUM.getName(), arguments, returnType); + } - /** - * Sum State. - */ - protected static class SumState implements AggregationState { + @Override + public SumState create() { + return new SumState(returnType); + } - private final ExprCoreType type; - private ExprValue sumResult; - private boolean isEmptyCollection; + @Override + protected SumState iterate(ExprValue value, SumState state) { + state.isEmptyCollection = false; + state.add(value); + return state; + } - SumState(ExprCoreType type) { - this.type = type; - sumResult = ExprValueUtils.integerValue(0); - isEmptyCollection = true; + @Override + public String toString() { + return String.format(Locale.ROOT, "sum(%s)", format(getArguments())); } /** - * Add value to current sumResult. + * Sum State. */ - public void add(ExprValue value) { - switch (type) { - case INTEGER: - sumResult = integerValue(getIntegerValue(sumResult) + getIntegerValue(value)); - break; - case LONG: - sumResult = longValue(getLongValue(sumResult) + getLongValue(value)); - break; - case FLOAT: - sumResult = floatValue(getFloatValue(sumResult) + getFloatValue(value)); - break; - case DOUBLE: - sumResult = doubleValue(getDoubleValue(sumResult) + getDoubleValue(value)); - break; - default: - throw new ExpressionEvaluationException( - String.format("unexpected type [%s] in sum aggregation", type)); - } - } + protected static class SumState implements AggregationState { - @Override - public ExprValue result() { - return isEmptyCollection ? ExprNullValue.of() : sumResult; + private final ExprCoreType type; + private ExprValue sumResult; + private boolean isEmptyCollection; + + SumState(ExprCoreType type) { + this.type = type; + sumResult = ExprValueUtils.integerValue(0); + isEmptyCollection = true; + } + + /** + * Add value to current sumResult. + */ + public void add(ExprValue value) { + switch (type) { + case INTEGER: + sumResult = integerValue(getIntegerValue(sumResult) + getIntegerValue(value)); + break; + case LONG: + sumResult = longValue(getLongValue(sumResult) + getLongValue(value)); + break; + case FLOAT: + sumResult = floatValue(getFloatValue(sumResult) + getFloatValue(value)); + break; + case DOUBLE: + sumResult = doubleValue(getDoubleValue(sumResult) + getDoubleValue(value)); + break; + default: + throw new ExpressionEvaluationException(String.format("unexpected type [%s] in sum aggregation", type)); + } + } + + @Override + public ExprValue result() { + return isEmptyCollection ? ExprNullValue.of() : sumResult; + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/TakeAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/TakeAggregator.java index cff08bb098..c64658caab 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/TakeAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/TakeAggregator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.aggregation; import static org.opensearch.sql.utils.ExpressionUtils.format; @@ -23,53 +22,53 @@ */ public class TakeAggregator extends Aggregator { - public TakeAggregator(List arguments, ExprCoreType returnType) { - super(BuiltinFunctionName.TAKE.getName(), arguments, returnType); - } + public TakeAggregator(List arguments, ExprCoreType returnType) { + super(BuiltinFunctionName.TAKE.getName(), arguments, returnType); + } - @Override - public TakeState create() { - return new TakeState(getArguments().get(1).valueOf().integerValue()); - } + @Override + public TakeState create() { + return new TakeState(getArguments().get(1).valueOf().integerValue()); + } - @Override - protected TakeState iterate(ExprValue value, TakeState state) { - state.take(value); - return state; - } + @Override + protected TakeState iterate(ExprValue value, TakeState state) { + state.take(value); + return state; + } - @Override - public String toString() { - return String.format(Locale.ROOT, "take(%s)", format(getArguments())); - } + @Override + public String toString() { + return String.format(Locale.ROOT, "take(%s)", format(getArguments())); + } - /** - * Take State. - */ - protected static class TakeState implements AggregationState { - protected int index; - protected int size; - protected List hits; + /** + * Take State. + */ + protected static class TakeState implements AggregationState { + protected int index; + protected int size; + protected List hits; - TakeState(int size) { - if (size <= 0) { - throw new IllegalArgumentException("size must be greater than 0"); - } - this.index = 0; - this.size = size; - this.hits = new ArrayList<>(); - } + TakeState(int size) { + if (size <= 0) { + throw new IllegalArgumentException("size must be greater than 0"); + } + this.index = 0; + this.size = size; + this.hits = new ArrayList<>(); + } - public void take(ExprValue value) { - if (index < size) { - hits.add(value); - } - index++; - } + public void take(ExprValue value) { + if (index < size) { + hits.add(value); + } + index++; + } - @Override - public ExprValue result() { - return new ExprCollectionValue(hits); + @Override + public ExprValue result() { + return new ExprCollectionValue(hits); + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java index bd9f0948f6..23e7e6dd38 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java @@ -31,79 +31,68 @@ */ public class VarianceAggregator extends Aggregator { - private final boolean isSampleVariance; - - /** - * Build Population Variance {@link VarianceAggregator}. - */ - public static Aggregator variancePopulation(List arguments, - ExprCoreType returnType) { - return new VarianceAggregator(false, arguments, returnType); - } - - /** - * Build Sample Variance {@link VarianceAggregator}. - */ - public static Aggregator varianceSample(List arguments, - ExprCoreType returnType) { - return new VarianceAggregator(true, arguments, returnType); - } - - /** - * VarianceAggregator constructor. - * - * @param isSampleVariance true for sample variance aggregator, false for population variance - * aggregator. - * @param arguments aggregator arguments. - * @param returnType aggregator return types. - */ - public VarianceAggregator( - Boolean isSampleVariance, List arguments, ExprCoreType returnType) { - super( - isSampleVariance - ? BuiltinFunctionName.VARSAMP.getName() - : BuiltinFunctionName.VARPOP.getName(), - arguments, - returnType); - this.isSampleVariance = isSampleVariance; - } - - @Override - public VarianceState create() { - return new VarianceState(isSampleVariance); - } - - @Override - protected VarianceState iterate(ExprValue value, VarianceState state) { - state.evaluate(value); - return state; - } - - @Override - public String toString() { - return StringUtils.format( - "%s(%s)", isSampleVariance ? "var_samp" : "var_pop", format(getArguments())); - } - - protected static class VarianceState implements AggregationState { - - private final Variance variance; - - private final List values = new ArrayList<>(); - - public VarianceState(boolean isSampleVariance) { - this.variance = new Variance(isSampleVariance); + private final boolean isSampleVariance; + + /** + * Build Population Variance {@link VarianceAggregator}. + */ + public static Aggregator variancePopulation(List arguments, ExprCoreType returnType) { + return new VarianceAggregator(false, arguments, returnType); + } + + /** + * Build Sample Variance {@link VarianceAggregator}. + */ + public static Aggregator varianceSample(List arguments, ExprCoreType returnType) { + return new VarianceAggregator(true, arguments, returnType); + } + + /** + * VarianceAggregator constructor. + * + * @param isSampleVariance true for sample variance aggregator, false for population variance + * aggregator. + * @param arguments aggregator arguments. + * @param returnType aggregator return types. + */ + public VarianceAggregator(Boolean isSampleVariance, List arguments, ExprCoreType returnType) { + super(isSampleVariance ? BuiltinFunctionName.VARSAMP.getName() : BuiltinFunctionName.VARPOP.getName(), arguments, returnType); + this.isSampleVariance = isSampleVariance; + } + + @Override + public VarianceState create() { + return new VarianceState(isSampleVariance); } - public void evaluate(ExprValue value) { - values.add(value.doubleValue()); + @Override + protected VarianceState iterate(ExprValue value, VarianceState state) { + state.evaluate(value); + return state; } @Override - public ExprValue result() { - return values.size() == 0 - ? ExprNullValue.of() - : doubleValue(variance.evaluate(values.stream().mapToDouble(d -> d).toArray())); + public String toString() { + return StringUtils.format("%s(%s)", isSampleVariance ? "var_samp" : "var_pop", format(getArguments())); + } + + protected static class VarianceState implements AggregationState { + + private final Variance variance; + + private final List values = new ArrayList<>(); + + public VarianceState(boolean isSampleVariance) { + this.variance = new Variance(isSampleVariance); + } + + public void evaluate(ExprValue value) { + values.add(value.doubleValue()); + } + + @Override + public ExprValue result() { + return values.size() == 0 ? ExprNullValue.of() : doubleValue(variance.evaluate(values.stream().mapToDouble(d -> d).toArray())); + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/conditional/cases/CaseClause.java b/core/src/main/java/org/opensearch/sql/expression/conditional/cases/CaseClause.java index ad7860a6dc..f0c77af4f8 100644 --- a/core/src/main/java/org/opensearch/sql/expression/conditional/cases/CaseClause.java +++ b/core/src/main/java/org/opensearch/sql/expression/conditional/cases/CaseClause.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.conditional.cases; import static org.opensearch.sql.data.type.ExprCoreType.UNDEFINED; @@ -32,75 +31,72 @@ @ToString public class CaseClause extends FunctionExpression { - /** - * List of WHEN clauses. - */ - private final List whenClauses; - - /** - * Default result if none of WHEN conditions match. - */ - private final Expression defaultResult; - - /** - * Initialize case clause. - */ - public CaseClause(List whenClauses, Expression defaultResult) { - super(FunctionName.of("case"), concatArgs(whenClauses, defaultResult)); - this.whenClauses = whenClauses; - this.defaultResult = defaultResult; - } - - @Override - public ExprValue valueOf(Environment valueEnv) { - for (WhenClause when : whenClauses) { - if (when.isTrue(valueEnv)) { - return when.valueOf(valueEnv); - } + /** + * List of WHEN clauses. + */ + private final List whenClauses; + + /** + * Default result if none of WHEN conditions match. + */ + private final Expression defaultResult; + + /** + * Initialize case clause. + */ + public CaseClause(List whenClauses, Expression defaultResult) { + super(FunctionName.of("case"), concatArgs(whenClauses, defaultResult)); + this.whenClauses = whenClauses; + this.defaultResult = defaultResult; + } + + @Override + public ExprValue valueOf(Environment valueEnv) { + for (WhenClause when : whenClauses) { + if (when.isTrue(valueEnv)) { + return when.valueOf(valueEnv); + } + } + return (defaultResult == null) ? ExprNullValue.of() : defaultResult.valueOf(valueEnv); + } + + @Override + public ExprType type() { + List types = allResultTypes(); + + // Return undefined if all WHEN/ELSE return NULL + return types.isEmpty() ? UNDEFINED : types.get(0); } - return (defaultResult == null) ? ExprNullValue.of() : defaultResult.valueOf(valueEnv); - } - - @Override - public ExprType type() { - List types = allResultTypes(); - - // Return undefined if all WHEN/ELSE return NULL - return types.isEmpty() ? UNDEFINED : types.get(0); - } - - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitCase(this, context); - } - - /** - * Get types of each result in WHEN clause and ELSE clause. - * Exclude UNKNOWN type from NULL literal which means NULL in THEN or ELSE clause - * is not included in result. - * @return all result types. Use list so caller can generate friendly error message. - */ - public List allResultTypes() { - List types = whenClauses.stream() - .map(WhenClause::type) - .collect(Collectors.toList()); - if (defaultResult != null) { - types.add(defaultResult.type()); + + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitCase(this, context); } - types.removeIf(type -> (type == UNDEFINED)); - return types; - } + /** + * Get types of each result in WHEN clause and ELSE clause. + * Exclude UNKNOWN type from NULL literal which means NULL in THEN or ELSE clause + * is not included in result. + * @return all result types. Use list so caller can generate friendly error message. + */ + public List allResultTypes() { + List types = whenClauses.stream().map(WhenClause::type).collect(Collectors.toList()); + if (defaultResult != null) { + types.add(defaultResult.type()); + } + + types.removeIf(type -> (type == UNDEFINED)); + return types; + } - private static List concatArgs(List whenClauses, - Expression defaultResult) { - ImmutableList.Builder args = ImmutableList.builder(); - whenClauses.forEach(args::add); + private static List concatArgs(List whenClauses, Expression defaultResult) { + ImmutableList.Builder args = ImmutableList.builder(); + whenClauses.forEach(args::add); - if (defaultResult != null) { - args.add(defaultResult); + if (defaultResult != null) { + args.add(defaultResult); + } + return args.build(); } - return args.build(); - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/conditional/cases/WhenClause.java b/core/src/main/java/org/opensearch/sql/expression/conditional/cases/WhenClause.java index fd2eeab983..411feae2b3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/conditional/cases/WhenClause.java +++ b/core/src/main/java/org/opensearch/sql/expression/conditional/cases/WhenClause.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.conditional.cases; import com.google.common.collect.ImmutableList; @@ -26,51 +25,51 @@ @ToString public class WhenClause extends FunctionExpression { - /** - * Condition that must be a predicate. - */ - private final Expression condition; + /** + * Condition that must be a predicate. + */ + private final Expression condition; - /** - * Result to return if condition is evaluated to true. - */ - private final Expression result; + /** + * Result to return if condition is evaluated to true. + */ + private final Expression result; - /** - * Initialize when clause. - */ - public WhenClause(Expression condition, Expression result) { - super(FunctionName.of("when"), ImmutableList.of(condition, result)); - this.condition = condition; - this.result = result; - } + /** + * Initialize when clause. + */ + public WhenClause(Expression condition, Expression result) { + super(FunctionName.of("when"), ImmutableList.of(condition, result)); + this.condition = condition; + this.result = result; + } - /** - * Evaluate when condition. - * @param valueEnv value env - * @return is condition satisfied - */ - public boolean isTrue(Environment valueEnv) { - ExprValue result = condition.valueOf(valueEnv); - if (result.isMissing() || result.isNull()) { - return false; + /** + * Evaluate when condition. + * @param valueEnv value env + * @return is condition satisfied + */ + public boolean isTrue(Environment valueEnv) { + ExprValue result = condition.valueOf(valueEnv); + if (result.isMissing() || result.isNull()) { + return false; + } + return result.booleanValue(); } - return result.booleanValue(); - } - @Override - public ExprValue valueOf(Environment valueEnv) { - return result.valueOf(valueEnv); - } + @Override + public ExprValue valueOf(Environment valueEnv) { + return result.valueOf(valueEnv); + } - @Override - public ExprType type() { - return result.type(); - } + @Override + public ExprType type() { + return result.type(); + } - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitWhen(this, context); - } + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitWhen(this, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/datetime/CalendarLookup.java b/core/src/main/java/org/opensearch/sql/expression/datetime/CalendarLookup.java index c5b6343991..07b1388a4a 100644 --- a/core/src/main/java/org/opensearch/sql/expression/datetime/CalendarLookup.java +++ b/core/src/main/java/org/opensearch/sql/expression/datetime/CalendarLookup.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.datetime; import com.google.common.collect.ImmutableList; @@ -16,68 +15,65 @@ @AllArgsConstructor class CalendarLookup { - /** - * Get a calendar for the specific mode. - * @param mode Mode to get calendar for. - * @param date Date to get calendar for. - */ - private static Calendar getCalendar(int mode, LocalDate date) { - if ((mode < 0) || (mode > 7)) { - throw new SemanticCheckException( - String.format("mode:%s is invalid, please use mode value between 0-7", mode)); - } - int day = (mode % 2 == 0) ? Calendar.SUNDAY : Calendar.MONDAY; - if (ImmutableList.of(1, 3).contains(mode)) { - return getCalendar(day, 5, date); - } else if (ImmutableList.of(4, 6).contains(mode)) { - return getCalendar(day, 4, date); - } else { - return getCalendar(day, 7, date); + /** + * Get a calendar for the specific mode. + * @param mode Mode to get calendar for. + * @param date Date to get calendar for. + */ + private static Calendar getCalendar(int mode, LocalDate date) { + if ((mode < 0) || (mode > 7)) { + throw new SemanticCheckException(String.format("mode:%s is invalid, please use mode value between 0-7", mode)); + } + int day = (mode % 2 == 0) ? Calendar.SUNDAY : Calendar.MONDAY; + if (ImmutableList.of(1, 3).contains(mode)) { + return getCalendar(day, 5, date); + } else if (ImmutableList.of(4, 6).contains(mode)) { + return getCalendar(day, 4, date); + } else { + return getCalendar(day, 7, date); + } } - } - /** - * Set first day of week, minimal days in first week and date in calendar. - * @param firstDayOfWeek the given first day of the week. - * @param minimalDaysInWeek the given minimal days required in the first week of the year. - * @param date the given date. - */ - private static Calendar getCalendar(int firstDayOfWeek, int minimalDaysInWeek, LocalDate date) { - Calendar calendar = Calendar.getInstance(); - calendar.setFirstDayOfWeek(firstDayOfWeek); - calendar.setMinimalDaysInFirstWeek(minimalDaysInWeek); - calendar.set(date.getYear(), date.getMonthValue() - 1, date.getDayOfMonth()); - return calendar; - } + /** + * Set first day of week, minimal days in first week and date in calendar. + * @param firstDayOfWeek the given first day of the week. + * @param minimalDaysInWeek the given minimal days required in the first week of the year. + * @param date the given date. + */ + private static Calendar getCalendar(int firstDayOfWeek, int minimalDaysInWeek, LocalDate date) { + Calendar calendar = Calendar.getInstance(); + calendar.setFirstDayOfWeek(firstDayOfWeek); + calendar.setMinimalDaysInFirstWeek(minimalDaysInWeek); + calendar.set(date.getYear(), date.getMonthValue() - 1, date.getDayOfMonth()); + return calendar; + } - /** - * Returns week number for date according to mode. - * @param mode Integer for mode. Valid mode values are 0 to 7. - * @param date LocalDate for date. - */ - static int getWeekNumber(int mode, LocalDate date) { - Calendar calendar = getCalendar(mode, date); - int weekNumber = calendar.get(Calendar.WEEK_OF_YEAR); - if ((weekNumber > 51) - && (calendar.get(Calendar.DAY_OF_MONTH) < 7) - && Arrays.asList(0, 1, 4, 5).contains(mode)) { - weekNumber = 0; + /** + * Returns week number for date according to mode. + * @param mode Integer for mode. Valid mode values are 0 to 7. + * @param date LocalDate for date. + */ + static int getWeekNumber(int mode, LocalDate date) { + Calendar calendar = getCalendar(mode, date); + int weekNumber = calendar.get(Calendar.WEEK_OF_YEAR); + if ((weekNumber > 51) && (calendar.get(Calendar.DAY_OF_MONTH) < 7) && Arrays.asList(0, 1, 4, 5).contains(mode)) { + weekNumber = 0; + } + return weekNumber; } - return weekNumber; - } - /** - * Returns year for date according to mode. - * @param mode Integer for mode. Valid mode values are 0 to 7. - * @param date LocalDate for date. - */ - static int getYearNumber(int mode, LocalDate date) { - Calendar calendar = getCalendar(mode, date); - int weekNumber = getWeekNumber(mode, date); - int yearNumber = calendar.get(Calendar.YEAR); - if ((weekNumber > 51) && (calendar.get(Calendar.DAY_OF_MONTH) < 7)) { - yearNumber--; + /** + * Returns year for date according to mode. + * @param mode Integer for mode. Valid mode values are 0 to 7. + * @param date LocalDate for date. + */ + static int getYearNumber(int mode, LocalDate date) { + Calendar calendar = getCalendar(mode, date); + int weekNumber = getWeekNumber(mode, date); + int yearNumber = calendar.get(Calendar.YEAR); + if ((weekNumber > 51) && (calendar.get(Calendar.DAY_OF_MONTH) < 7)) { + yearNumber--; + } + return yearNumber; } - return yearNumber; - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFormatterUtil.java b/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFormatterUtil.java index 55bfa67f3f..65dcd4e5c3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFormatterUtil.java +++ b/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFormatterUtil.java @@ -32,316 +32,300 @@ * Java SimpleDateTime format. */ class DateTimeFormatterUtil { - private static final int SUFFIX_SPECIAL_START_TH = 11; - private static final int SUFFIX_SPECIAL_END_TH = 13; - private static final String SUFFIX_SPECIAL_TH = "th"; - - private static final String NANO_SEC_FORMAT = "'%06d'"; - - private static final Map SUFFIX_CONVERTER = - ImmutableMap.builder() - .put(1, "st").put(2, "nd").put(3, "rd").build(); - - // The following have special cases that need handling outside of the format options provided - // by the DateTimeFormatter class. - interface DateTimeFormatHandler { - String getFormat(LocalDateTime date); - } - - private static final Map DATE_HANDLERS = - ImmutableMap.builder() - .put("%a", (date) -> "EEE") // %a => EEE - Abbreviated weekday name (Sun..Sat) - .put("%b", (date) -> "LLL") // %b => LLL - Abbreviated month name (Jan..Dec) - .put("%c", (date) -> "MM") // %c => MM - Month, numeric (0..12) - .put("%d", (date) -> "dd") // %d => dd - Day of the month, numeric (00..31) - .put("%e", (date) -> "d") // %e => d - Day of the month, numeric (0..31) - .put("%H", (date) -> "HH") // %H => HH - (00..23) - .put("%h", (date) -> "hh") // %h => hh - (01..12) - .put("%I", (date) -> "hh") // %I => hh - (01..12) - .put("%i", (date) -> "mm") // %i => mm - Minutes, numeric (00..59) - .put("%j", (date) -> "DDD") // %j => DDD - (001..366) - .put("%k", (date) -> "H") // %k => H - (0..23) - .put("%l", (date) -> "h") // %l => h - (1..12) - .put("%p", (date) -> "a") // %p => a - AM or PM - .put("%M", (date) -> "LLLL") // %M => LLLL - Month name (January..December) - .put("%m", (date) -> "MM") // %m => MM - Month, numeric (00..12) - .put("%r", (date) -> "hh:mm:ss a") // %r => hh:mm:ss a - hh:mm:ss followed by AM or PM - .put("%S", (date) -> "ss") // %S => ss - Seconds (00..59) - .put("%s", (date) -> "ss") // %s => ss - Seconds (00..59) - .put("%T", (date) -> "HH:mm:ss") // %T => HH:mm:ss - .put("%W", (date) -> "EEEE") // %W => EEEE - Weekday name (Sunday..Saturday) - .put("%Y", (date) -> "yyyy") // %Y => yyyy - Year, numeric, 4 digits - .put("%y", (date) -> "yy") // %y => yy - Year, numeric, 2 digits - // The following are not directly supported by DateTimeFormatter. - .put("%D", (date) -> // %w - Day of month with English suffix - String.format("'%d%s'", date.getDayOfMonth(), getSuffix(date.getDayOfMonth()))) - .put("%f", (date) -> // %f - Microseconds - String.format(NANO_SEC_FORMAT, (date.getNano() / 1000))) - .put("%w", (date) -> // %w - Day of week (0 indexed) - String.format("'%d'", date.getDayOfWeek().getValue())) - .put("%U", (date) -> // %U Week where Sunday is the first day - WEEK() mode 0 - String.format("'%d'", CalendarLookup.getWeekNumber(0, date.toLocalDate()))) - .put("%u", (date) -> // %u Week where Monday is the first day - WEEK() mode 1 - String.format("'%d'", CalendarLookup.getWeekNumber(1, date.toLocalDate()))) - .put("%V", (date) -> // %V Week where Sunday is the first day - WEEK() mode 2 used with %X - String.format("'%d'", CalendarLookup.getWeekNumber(2, date.toLocalDate()))) - .put("%v", (date) -> // %v Week where Monday is the first day - WEEK() mode 3 used with %x - String.format("'%d'", CalendarLookup.getWeekNumber(3, date.toLocalDate()))) - .put("%X", (date) -> // %X Year for week where Sunday is the first day, 4 digits used with %V - String.format("'%d'", CalendarLookup.getYearNumber(2, date.toLocalDate()))) - .put("%x", (date) -> // %x Year for week where Monday is the first day, 4 digits used with %v - String.format("'%d'", CalendarLookup.getYearNumber(3, date.toLocalDate()))) - .build(); - - //Handlers for the time_format function. - //Some format specifiers return 0 or null to align with MySQL. - //https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_time-format - private static final Map TIME_HANDLERS = - ImmutableMap.builder() - .put("%a", (date) -> null) - .put("%b", (date) -> null) - .put("%c", (date) -> "0") - .put("%d", (date) -> "00") - .put("%e", (date) -> "0") - .put("%H", (date) -> "HH") // %H => HH - (00..23) - .put("%h", (date) -> "hh") // %h => hh - (01..12) - .put("%I", (date) -> "hh") // %I => hh - (01..12) - .put("%i", (date) -> "mm") // %i => mm - Minutes, numeric (00..59) - .put("%j", (date) -> null) - .put("%k", (date) -> "H") // %k => H - (0..23) - .put("%l", (date) -> "h") // %l => h - (1..12) - .put("%p", (date) -> "a") // %p => a - AM or PM - .put("%M", (date) -> null) - .put("%m", (date) -> "00") - .put("%r", (date) -> "hh:mm:ss a") // %r => hh:mm:ss a - hh:mm:ss followed by AM or PM - .put("%S", (date) -> "ss") // %S => ss - Seconds (00..59) - .put("%s", (date) -> "ss") // %s => ss - Seconds (00..59) - .put("%T", (date) -> "HH:mm:ss") // %T => HH:mm:ss - .put("%W", (date) -> null) - .put("%Y", (date) -> "0000") - .put("%y", (date) -> "00") - .put("%D", (date) -> null) - .put("%f", (date) -> // %f - Microseconds - String.format(NANO_SEC_FORMAT, (date.getNano() / 1000))) - .put("%w", (date) -> null) - .put("%U", (date) -> null) - .put("%u", (date) -> null) - .put("%V", (date) -> null) - .put("%v", (date) -> null) - .put("%X", (date) -> null) - .put("%x", (date) -> null) - .build(); - - private static final Map STR_TO_DATE_FORMATS = - ImmutableMap.builder() - .put("%a", "EEE") // %a => EEE - Abbreviated weekday name (Sun..Sat) - .put("%b", "LLL") // %b => LLL - Abbreviated month name (Jan..Dec) - .put("%c", "M") // %c => MM - Month, numeric (0..12) - .put("%d", "d") // %d => dd - Day of the month, numeric (00..31) - .put("%e", "d") // %e => d - Day of the month, numeric (0..31) - .put("%H", "H") // %H => HH - (00..23) - .put("%h", "H") // %h => hh - (01..12) - .put("%I", "h") // %I => hh - (01..12) - .put("%i", "m") // %i => mm - Minutes, numeric (00..59) - .put("%j", "DDD") // %j => DDD - (001..366) - .put("%k", "H") // %k => H - (0..23) - .put("%l", "h") // %l => h - (1..12) - .put("%p", "a") // %p => a - AM or PM - .put("%M", "LLLL") // %M => LLLL - Month name (January..December) - .put("%m", "M") // %m => MM - Month, numeric (00..12) - .put("%r", "hh:mm:ss a") // %r => hh:mm:ss a - hh:mm:ss followed by AM or PM - .put("%S", "s") // %S => ss - Seconds (00..59) - .put("%s", "s") // %s => ss - Seconds (00..59) - .put("%T", "HH:mm:ss") // %T => HH:mm:ss - .put("%W", "EEEE") // %W => EEEE - Weekday name (Sunday..Saturday) - .put("%Y", "u") // %Y => yyyy - Year, numeric, 4 digits - .put("%y", "u") // %y => yy - Year, numeric, 2 digits - .put("%f", "n") // %f => n - Nanoseconds - //The following have been implemented but cannot be aligned with - // MySQL due to the limitations of the DatetimeFormatter - .put("%D", "d") // %w - Day of month with English suffix - .put("%w", "e") // %w - Day of week (0 indexed) - .put("%U", "w") // %U Week where Sunday is the first day - WEEK() mode 0 - .put("%u", "w") // %u Week where Monday is the first day - WEEK() mode 1 - .put("%V", "w") // %V Week where Sunday is the first day - WEEK() mode 2 - .put("%v", "w") // %v Week where Monday is the first day - WEEK() mode 3 - .put("%X", "u") // %X Year for week where Sunday is the first day - .put("%x", "u") // %x Year for week where Monday is the first day - .build(); - - private static final Pattern pattern = Pattern.compile("%."); - private static final Pattern CHARACTERS_WITH_NO_MOD_LITERAL_BEHIND_PATTERN - = Pattern.compile("(? SUFFIX_CONVERTER = ImmutableMap.builder() + .put(1, "st") + .put(2, "nd") + .put(3, "rd") + .build(); + + // The following have special cases that need handling outside of the format options provided + // by the DateTimeFormatter class. + interface DateTimeFormatHandler { + String getFormat(LocalDateTime date); } - m.appendTail(cleanFormat); - - return cleanFormat; - } - - /** - * Helper function to format a DATETIME according to a provided handler and matcher. - * @param formatExpr ExprValue containing the format expression - * @param handler Map of character patterns to their associated datetime format - * @param datetime The datetime argument being formatted - * @return A formatted string expression - */ - static ExprValue getFormattedString(ExprValue formatExpr, - Map handler, - LocalDateTime datetime) { - StringBuffer cleanFormat = getCleanFormat(formatExpr); - - final Matcher matcher = pattern.matcher(cleanFormat.toString()); - final StringBuffer format = new StringBuffer(); - try { - while (matcher.find()) { - matcher.appendReplacement(format, - handler.getOrDefault(matcher.group(), (d) -> - String.format("'%s'", matcher.group().replaceFirst(MOD_LITERAL, ""))) - .getFormat(datetime)); - } - } catch (Exception e) { - return ExprNullValue.of(); - } - matcher.appendTail(format); - - // English Locale matches SQL requirements. - // 'AM'/'PM' instead of 'a.m.'/'p.m.' - // 'Sat' instead of 'Sat.' etc - return new ExprStringValue(datetime.format( - DateTimeFormatter.ofPattern(format.toString(), Locale.ENGLISH))); - } - - /** - * Format the date using the date format String. - * @param dateExpr the date ExprValue of Date/Datetime/Timestamp/String type. - * @param formatExpr the format ExprValue of String type. - * @return Date formatted using format and returned as a String. - */ - static ExprValue getFormattedDate(ExprValue dateExpr, ExprValue formatExpr) { - final LocalDateTime date = dateExpr.datetimeValue(); - return getFormattedString(formatExpr, DATE_HANDLERS, date); - } - - static ExprValue getFormattedDateOfToday(ExprValue formatExpr, ExprValue time, Clock current) { - final LocalDateTime date = LocalDateTime.of(LocalDate.now(current), time.timeValue()); - - return getFormattedString(formatExpr, DATE_HANDLERS, date); - } - - /** - * Format the date using the date format String. - * @param timeExpr the date ExprValue of Date/Datetime/Timestamp/String type. - * @param formatExpr the format ExprValue of String type. - * @return Date formatted using format and returned as a String. - */ - static ExprValue getFormattedTime(ExprValue timeExpr, ExprValue formatExpr) { - //Initializes DateTime with LocalDate.now(). This is safe because the date is ignored. - //The time_format function will only return 0 or null for invalid string format specifiers. - final LocalDateTime time = LocalDateTime.of(LocalDate.now(), timeExpr.timeValue()); - - return getFormattedString(formatExpr, TIME_HANDLERS, time); - } - - private static boolean canGetDate(TemporalAccessor ta) { - return (ta.isSupported(ChronoField.YEAR) - && ta.isSupported(ChronoField.MONTH_OF_YEAR) - && ta.isSupported(ChronoField.DAY_OF_MONTH)); - } - - private static boolean canGetTime(TemporalAccessor ta) { - return (ta.isSupported(ChronoField.HOUR_OF_DAY) - && ta.isSupported(ChronoField.MINUTE_OF_HOUR) - && ta.isSupported(ChronoField.SECOND_OF_MINUTE)); - } - - static ExprValue parseStringWithDateOrTime(FunctionProperties fp, - ExprValue datetimeStringExpr, - ExprValue formatExpr) { - - //Replace patterns with % for Java DateTimeFormatter - StringBuffer cleanFormat = getCleanFormat(formatExpr); - final Matcher matcher = pattern.matcher(cleanFormat.toString()); - final StringBuffer format = new StringBuffer(); - - while (matcher.find()) { - matcher.appendReplacement(format, - STR_TO_DATE_FORMATS.getOrDefault(matcher.group(), - String.format("'%s'", matcher.group().replaceFirst(MOD_LITERAL, "")))); + + private static final Map DATE_HANDLERS = ImmutableMap.builder() + .put("%a", (date) -> "EEE") // %a => EEE - Abbreviated weekday name (Sun..Sat) + .put("%b", (date) -> "LLL") // %b => LLL - Abbreviated month name (Jan..Dec) + .put("%c", (date) -> "MM") // %c => MM - Month, numeric (0..12) + .put("%d", (date) -> "dd") // %d => dd - Day of the month, numeric (00..31) + .put("%e", (date) -> "d") // %e => d - Day of the month, numeric (0..31) + .put("%H", (date) -> "HH") // %H => HH - (00..23) + .put("%h", (date) -> "hh") // %h => hh - (01..12) + .put("%I", (date) -> "hh") // %I => hh - (01..12) + .put("%i", (date) -> "mm") // %i => mm - Minutes, numeric (00..59) + .put("%j", (date) -> "DDD") // %j => DDD - (001..366) + .put("%k", (date) -> "H") // %k => H - (0..23) + .put("%l", (date) -> "h") // %l => h - (1..12) + .put("%p", (date) -> "a") // %p => a - AM or PM + .put("%M", (date) -> "LLLL") // %M => LLLL - Month name (January..December) + .put("%m", (date) -> "MM") // %m => MM - Month, numeric (00..12) + .put("%r", (date) -> "hh:mm:ss a") // %r => hh:mm:ss a - hh:mm:ss followed by AM or PM + .put("%S", (date) -> "ss") // %S => ss - Seconds (00..59) + .put("%s", (date) -> "ss") // %s => ss - Seconds (00..59) + .put("%T", (date) -> "HH:mm:ss") // %T => HH:mm:ss + .put("%W", (date) -> "EEEE") // %W => EEEE - Weekday name (Sunday..Saturday) + .put("%Y", (date) -> "yyyy") // %Y => yyyy - Year, numeric, 4 digits + .put("%y", (date) -> "yy") // %y => yy - Year, numeric, 2 digits + // The following are not directly supported by DateTimeFormatter. + .put("%D", (date) -> // %w - Day of month with English suffix + String.format("'%d%s'", date.getDayOfMonth(), getSuffix(date.getDayOfMonth()))) + .put("%f", (date) -> // %f - Microseconds + String.format(NANO_SEC_FORMAT, (date.getNano() / 1000))) + .put("%w", (date) -> // %w - Day of week (0 indexed) + String.format("'%d'", date.getDayOfWeek().getValue())) + .put("%U", (date) -> // %U Week where Sunday is the first day - WEEK() mode 0 + String.format("'%d'", CalendarLookup.getWeekNumber(0, date.toLocalDate()))) + .put("%u", (date) -> // %u Week where Monday is the first day - WEEK() mode 1 + String.format("'%d'", CalendarLookup.getWeekNumber(1, date.toLocalDate()))) + .put("%V", (date) -> // %V Week where Sunday is the first day - WEEK() mode 2 used with %X + String.format("'%d'", CalendarLookup.getWeekNumber(2, date.toLocalDate()))) + .put("%v", (date) -> // %v Week where Monday is the first day - WEEK() mode 3 used with %x + String.format("'%d'", CalendarLookup.getWeekNumber(3, date.toLocalDate()))) + .put("%X", (date) -> // %X Year for week where Sunday is the first day, 4 digits used with %V + String.format("'%d'", CalendarLookup.getYearNumber(2, date.toLocalDate()))) + .put("%x", (date) -> // %x Year for week where Monday is the first day, 4 digits used with %v + String.format("'%d'", CalendarLookup.getYearNumber(3, date.toLocalDate()))) + .build(); + + // Handlers for the time_format function. + // Some format specifiers return 0 or null to align with MySQL. + // https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_time-format + private static final Map TIME_HANDLERS = ImmutableMap.builder() + .put("%a", (date) -> null) + .put("%b", (date) -> null) + .put("%c", (date) -> "0") + .put("%d", (date) -> "00") + .put("%e", (date) -> "0") + .put("%H", (date) -> "HH") // %H => HH - (00..23) + .put("%h", (date) -> "hh") // %h => hh - (01..12) + .put("%I", (date) -> "hh") // %I => hh - (01..12) + .put("%i", (date) -> "mm") // %i => mm - Minutes, numeric (00..59) + .put("%j", (date) -> null) + .put("%k", (date) -> "H") // %k => H - (0..23) + .put("%l", (date) -> "h") // %l => h - (1..12) + .put("%p", (date) -> "a") // %p => a - AM or PM + .put("%M", (date) -> null) + .put("%m", (date) -> "00") + .put("%r", (date) -> "hh:mm:ss a") // %r => hh:mm:ss a - hh:mm:ss followed by AM or PM + .put("%S", (date) -> "ss") // %S => ss - Seconds (00..59) + .put("%s", (date) -> "ss") // %s => ss - Seconds (00..59) + .put("%T", (date) -> "HH:mm:ss") // %T => HH:mm:ss + .put("%W", (date) -> null) + .put("%Y", (date) -> "0000") + .put("%y", (date) -> "00") + .put("%D", (date) -> null) + .put("%f", (date) -> // %f - Microseconds + String.format(NANO_SEC_FORMAT, (date.getNano() / 1000))) + .put("%w", (date) -> null) + .put("%U", (date) -> null) + .put("%u", (date) -> null) + .put("%V", (date) -> null) + .put("%v", (date) -> null) + .put("%X", (date) -> null) + .put("%x", (date) -> null) + .build(); + + private static final Map STR_TO_DATE_FORMATS = ImmutableMap.builder() + .put("%a", "EEE") // %a => EEE - Abbreviated weekday name (Sun..Sat) + .put("%b", "LLL") // %b => LLL - Abbreviated month name (Jan..Dec) + .put("%c", "M") // %c => MM - Month, numeric (0..12) + .put("%d", "d") // %d => dd - Day of the month, numeric (00..31) + .put("%e", "d") // %e => d - Day of the month, numeric (0..31) + .put("%H", "H") // %H => HH - (00..23) + .put("%h", "H") // %h => hh - (01..12) + .put("%I", "h") // %I => hh - (01..12) + .put("%i", "m") // %i => mm - Minutes, numeric (00..59) + .put("%j", "DDD") // %j => DDD - (001..366) + .put("%k", "H") // %k => H - (0..23) + .put("%l", "h") // %l => h - (1..12) + .put("%p", "a") // %p => a - AM or PM + .put("%M", "LLLL") // %M => LLLL - Month name (January..December) + .put("%m", "M") // %m => MM - Month, numeric (00..12) + .put("%r", "hh:mm:ss a") // %r => hh:mm:ss a - hh:mm:ss followed by AM or PM + .put("%S", "s") // %S => ss - Seconds (00..59) + .put("%s", "s") // %s => ss - Seconds (00..59) + .put("%T", "HH:mm:ss") // %T => HH:mm:ss + .put("%W", "EEEE") // %W => EEEE - Weekday name (Sunday..Saturday) + .put("%Y", "u") // %Y => yyyy - Year, numeric, 4 digits + .put("%y", "u") // %y => yy - Year, numeric, 2 digits + .put("%f", "n") // %f => n - Nanoseconds + // The following have been implemented but cannot be aligned with + // MySQL due to the limitations of the DatetimeFormatter + .put("%D", "d") // %w - Day of month with English suffix + .put("%w", "e") // %w - Day of week (0 indexed) + .put("%U", "w") // %U Week where Sunday is the first day - WEEK() mode 0 + .put("%u", "w") // %u Week where Monday is the first day - WEEK() mode 1 + .put("%V", "w") // %V Week where Sunday is the first day - WEEK() mode 2 + .put("%v", "w") // %v Week where Monday is the first day - WEEK() mode 3 + .put("%X", "u") // %X Year for week where Sunday is the first day + .put("%x", "u") // %x Year for week where Monday is the first day + .build(); + + private static final Pattern pattern = Pattern.compile("%."); + private static final Pattern CHARACTERS_WITH_NO_MOD_LITERAL_BEHIND_PATTERN = Pattern.compile("(? handler, LocalDateTime datetime) { + StringBuffer cleanFormat = getCleanFormat(formatExpr); + + final Matcher matcher = pattern.matcher(cleanFormat.toString()); + final StringBuffer format = new StringBuffer(); + try { + while (matcher.find()) { + matcher.appendReplacement( + format, + handler.getOrDefault(matcher.group(), (d) -> String.format("'%s'", matcher.group().replaceFirst(MOD_LITERAL, ""))) + .getFormat(datetime) + ); + } + } catch (Exception e) { + return ExprNullValue.of(); + } + matcher.appendTail(format); + + // English Locale matches SQL requirements. + // 'AM'/'PM' instead of 'a.m.'/'p.m.' + // 'Sat' instead of 'Sat.' etc + return new ExprStringValue(datetime.format(DateTimeFormatter.ofPattern(format.toString(), Locale.ENGLISH))); } - int year = taWithMissingFields.isSupported(ChronoField.YEAR) - ? taWithMissingFields.get(ChronoField.YEAR) : 2000; + /** + * Format the date using the date format String. + * @param dateExpr the date ExprValue of Date/Datetime/Timestamp/String type. + * @param formatExpr the format ExprValue of String type. + * @return Date formatted using format and returned as a String. + */ + static ExprValue getFormattedDate(ExprValue dateExpr, ExprValue formatExpr) { + final LocalDateTime date = dateExpr.datetimeValue(); + return getFormattedString(formatExpr, DATE_HANDLERS, date); + } - int month = taWithMissingFields.isSupported(ChronoField.MONTH_OF_YEAR) - ? taWithMissingFields.get(ChronoField.MONTH_OF_YEAR) : 1; + static ExprValue getFormattedDateOfToday(ExprValue formatExpr, ExprValue time, Clock current) { + final LocalDateTime date = LocalDateTime.of(LocalDate.now(current), time.timeValue()); - int day = taWithMissingFields.isSupported(ChronoField.DAY_OF_MONTH) - ? taWithMissingFields.get(ChronoField.DAY_OF_MONTH) : 1; + return getFormattedString(formatExpr, DATE_HANDLERS, date); + } - int hour = taWithMissingFields.isSupported(ChronoField.HOUR_OF_DAY) - ? taWithMissingFields.get(ChronoField.HOUR_OF_DAY) : 0; + /** + * Format the date using the date format String. + * @param timeExpr the date ExprValue of Date/Datetime/Timestamp/String type. + * @param formatExpr the format ExprValue of String type. + * @return Date formatted using format and returned as a String. + */ + static ExprValue getFormattedTime(ExprValue timeExpr, ExprValue formatExpr) { + // Initializes DateTime with LocalDate.now(). This is safe because the date is ignored. + // The time_format function will only return 0 or null for invalid string format specifiers. + final LocalDateTime time = LocalDateTime.of(LocalDate.now(), timeExpr.timeValue()); + + return getFormattedString(formatExpr, TIME_HANDLERS, time); + } - int minute = taWithMissingFields.isSupported(ChronoField.MINUTE_OF_HOUR) - ? taWithMissingFields.get(ChronoField.MINUTE_OF_HOUR) : 0; + private static boolean canGetDate(TemporalAccessor ta) { + return (ta.isSupported(ChronoField.YEAR) && ta.isSupported(ChronoField.MONTH_OF_YEAR) && ta.isSupported(ChronoField.DAY_OF_MONTH)); + } - int second = taWithMissingFields.isSupported(ChronoField.SECOND_OF_MINUTE) - ? taWithMissingFields.get(ChronoField.SECOND_OF_MINUTE) : 0; + private static boolean canGetTime(TemporalAccessor ta) { + return (ta.isSupported(ChronoField.HOUR_OF_DAY) + && ta.isSupported(ChronoField.MINUTE_OF_HOUR) + && ta.isSupported(ChronoField.SECOND_OF_MINUTE)); + } - //Fill returned datetime with current date if only Time information was parsed - LocalDateTime output; - if (!canGetDate(taWithMissingFields)) { - output = LocalDateTime.of( - LocalDate.now(fp.getQueryStartClock()), - LocalTime.of(hour, minute, second) - ); - } else { - output = LocalDateTime.of(year, month, day, hour, minute, second); + static ExprValue parseStringWithDateOrTime(FunctionProperties fp, ExprValue datetimeStringExpr, ExprValue formatExpr) { + + // Replace patterns with % for Java DateTimeFormatter + StringBuffer cleanFormat = getCleanFormat(formatExpr); + final Matcher matcher = pattern.matcher(cleanFormat.toString()); + final StringBuffer format = new StringBuffer(); + + while (matcher.find()) { + matcher.appendReplacement( + format, + STR_TO_DATE_FORMATS.getOrDefault(matcher.group(), String.format("'%s'", matcher.group().replaceFirst(MOD_LITERAL, ""))) + ); + } + matcher.appendTail(format); + + TemporalAccessor taWithMissingFields; + // Return NULL for invalid parse in string to align with MySQL + try { + // Get Temporal Accessor to initially parse string without default values + taWithMissingFields = new DateTimeFormatterBuilder().appendPattern(format.toString()) + .toFormatter() + .withResolverStyle(ResolverStyle.STRICT) + .parseUnresolved(datetimeStringExpr.stringValue(), new ParsePosition(0)); + if (taWithMissingFields == null) { + throw new DateTimeException("Input string could not be parsed properly."); + } + if (!canGetDate(taWithMissingFields) && !canGetTime(taWithMissingFields)) { + throw new DateTimeException("Not enough data to build a valid Date, Time, or Datetime."); + } + } catch (DateTimeException e) { + return ExprNullValue.of(); + } + + int year = taWithMissingFields.isSupported(ChronoField.YEAR) ? taWithMissingFields.get(ChronoField.YEAR) : 2000; + + int month = taWithMissingFields.isSupported(ChronoField.MONTH_OF_YEAR) ? taWithMissingFields.get(ChronoField.MONTH_OF_YEAR) : 1; + + int day = taWithMissingFields.isSupported(ChronoField.DAY_OF_MONTH) ? taWithMissingFields.get(ChronoField.DAY_OF_MONTH) : 1; + + int hour = taWithMissingFields.isSupported(ChronoField.HOUR_OF_DAY) ? taWithMissingFields.get(ChronoField.HOUR_OF_DAY) : 0; + + int minute = taWithMissingFields.isSupported(ChronoField.MINUTE_OF_HOUR) ? taWithMissingFields.get(ChronoField.MINUTE_OF_HOUR) : 0; + + int second = taWithMissingFields.isSupported(ChronoField.SECOND_OF_MINUTE) + ? taWithMissingFields.get(ChronoField.SECOND_OF_MINUTE) + : 0; + + // Fill returned datetime with current date if only Time information was parsed + LocalDateTime output; + if (!canGetDate(taWithMissingFields)) { + output = LocalDateTime.of(LocalDate.now(fp.getQueryStartClock()), LocalTime.of(hour, minute, second)); + } else { + output = LocalDateTime.of(year, month, day, hour, minute, second); + } + + return new ExprDatetimeValue(output); } - return new ExprDatetimeValue(output); - } - - /** - * Returns English suffix of incoming value. - * @param val Incoming value. - * @return English suffix as String (st, nd, rd, th) - */ - private static String getSuffix(int val) { - // The numbers 11, 12, and 13 do not follow general suffix rules. - if ((SUFFIX_SPECIAL_START_TH <= val) && (val <= SUFFIX_SPECIAL_END_TH)) { - return SUFFIX_SPECIAL_TH; + /** + * Returns English suffix of incoming value. + * @param val Incoming value. + * @return English suffix as String (st, nd, rd, th) + */ + private static String getSuffix(int val) { + // The numbers 11, 12, and 13 do not follow general suffix rules. + if ((SUFFIX_SPECIAL_START_TH <= val) && (val <= SUFFIX_SPECIAL_END_TH)) { + return SUFFIX_SPECIAL_TH; + } + return SUFFIX_CONVERTER.getOrDefault(val % 10, SUFFIX_SPECIAL_TH); } - return SUFFIX_CONVERTER.getOrDefault(val % 10, SUFFIX_SPECIAL_TH); - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java b/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java index cd5ef23d1c..56004b591a 100644 --- a/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java @@ -3,10 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.datetime; - import static java.time.temporal.ChronoUnit.DAYS; import static java.time.temporal.ChronoUnit.HOURS; import static java.time.temporal.ChronoUnit.MICROS; @@ -107,2160 +105,2155 @@ @UtilityClass @SuppressWarnings("unchecked") public class DateTimeFunction { - //The number of seconds per day - public static final long SECONDS_PER_DAY = 86400; - - // The number of days from year zero to year 1970. - private static final Long DAYS_0000_TO_1970 = (146097 * 5L) - (30L * 365L + 7L); - - // MySQL doesn't process any datetime/timestamp values which are greater than - // 32536771199.999999, or equivalent '3001-01-18 23:59:59.999999' UTC - private static final Double MYSQL_MAX_TIMESTAMP = 32536771200d; - - // Mode used for week/week_of_year function by default when no argument is provided - private static final ExprIntegerValue DEFAULT_WEEK_OF_YEAR_MODE = new ExprIntegerValue(0); - - // Map used to determine format output for the extract function - private static final Map extract_formats = - ImmutableMap.builder() - .put("MICROSECOND", "SSSSSS") - .put("SECOND", "ss") - .put("MINUTE", "mm") - .put("HOUR", "HH") - .put("DAY", "dd") - .put("WEEK", "w") - .put("MONTH", "MM") - .put("YEAR", "yyyy") - .put("SECOND_MICROSECOND", "ssSSSSSS") - .put("MINUTE_MICROSECOND", "mmssSSSSSS") - .put("MINUTE_SECOND", "mmss") - .put("HOUR_MICROSECOND", "HHmmssSSSSSS") - .put("HOUR_SECOND", "HHmmss") - .put("HOUR_MINUTE", "HHmm") - .put("DAY_MICROSECOND", "ddHHmmssSSSSSS") - .put("DAY_SECOND", "ddHHmmss") - .put("DAY_MINUTE", "ddHHmm") - .put("DAY_HOUR", "ddHH") - .put("YEAR_MONTH", "yyyyMM") - .put("QUARTER", "Q") - .build(); - - // Map used to determine format output for the get_format function - private static final Table formats = - ImmutableTable.builder() - .put("date", "usa", "%m.%d.%Y") - .put("date", "jis", "%Y-%m-%d") - .put("date", "iso", "%Y-%m-%d") - .put("date", "eur", "%d.%m.%Y") - .put("date", "internal", "%Y%m%d") - .put("datetime", "usa", "%Y-%m-%d %H.%i.%s") - .put("datetime", "jis", "%Y-%m-%d %H:%i:%s") - .put("datetime", "iso", "%Y-%m-%d %H:%i:%s") - .put("datetime", "eur", "%Y-%m-%d %H.%i.%s") - .put("datetime", "internal", "%Y%m%d%H%i%s") - .put("time", "usa", "%h:%i:%s %p") - .put("time", "jis", "%H:%i:%s") - .put("time", "iso", "%H:%i:%s") - .put("time", "eur", "%H.%i.%s") - .put("time", "internal", "%H%i%s") - .put("timestamp", "usa", "%Y-%m-%d %H.%i.%s") - .put("timestamp", "jis", "%Y-%m-%d %H:%i:%s") - .put("timestamp", "iso", "%Y-%m-%d %H:%i:%s") - .put("timestamp", "eur", "%Y-%m-%d %H.%i.%s") - .put("timestamp", "internal", "%Y%m%d%H%i%s") - .build(); - - /** - * Register Date and Time Functions. - * - * @param repository {@link BuiltinFunctionRepository}. - */ - public void register(BuiltinFunctionRepository repository) { - repository.register(adddate()); - repository.register(addtime()); - repository.register(convert_tz()); - repository.register(curtime()); - repository.register(curdate()); - repository.register(current_date()); - repository.register(current_time()); - repository.register(current_timestamp()); - repository.register(date()); - repository.register(datediff()); - repository.register(datetime()); - repository.register(date_add()); - repository.register(date_format()); - repository.register(date_sub()); - repository.register(day()); - repository.register(dayName()); - repository.register(dayOfMonth(BuiltinFunctionName.DAYOFMONTH)); - repository.register(dayOfMonth(BuiltinFunctionName.DAY_OF_MONTH)); - repository.register(dayOfWeek(BuiltinFunctionName.DAYOFWEEK.getName())); - repository.register(dayOfWeek(BuiltinFunctionName.DAY_OF_WEEK.getName())); - repository.register(dayOfYear(BuiltinFunctionName.DAYOFYEAR)); - repository.register(dayOfYear(BuiltinFunctionName.DAY_OF_YEAR)); - repository.register(extract()); - repository.register(from_days()); - repository.register(from_unixtime()); - repository.register(get_format()); - repository.register(hour(BuiltinFunctionName.HOUR)); - repository.register(hour(BuiltinFunctionName.HOUR_OF_DAY)); - repository.register(last_day()); - repository.register(localtime()); - repository.register(localtimestamp()); - repository.register(makedate()); - repository.register(maketime()); - repository.register(microsecond()); - repository.register(minute(BuiltinFunctionName.MINUTE)); - repository.register(minute_of_day()); - repository.register(minute(BuiltinFunctionName.MINUTE_OF_HOUR)); - repository.register(month(BuiltinFunctionName.MONTH)); - repository.register(month(BuiltinFunctionName.MONTH_OF_YEAR)); - repository.register(monthName()); - repository.register(now()); - repository.register(period_add()); - repository.register(period_diff()); - repository.register(quarter()); - repository.register(sec_to_time()); - repository.register(second(BuiltinFunctionName.SECOND)); - repository.register(second(BuiltinFunctionName.SECOND_OF_MINUTE)); - repository.register(subdate()); - repository.register(subtime()); - repository.register(str_to_date()); - repository.register(sysdate()); - repository.register(time()); - repository.register(time_format()); - repository.register(time_to_sec()); - repository.register(timediff()); - repository.register(timestamp()); - repository.register(timestampadd()); - repository.register(timestampdiff()); - repository.register(to_days()); - repository.register(to_seconds()); - repository.register(unix_timestamp()); - repository.register(utc_date()); - repository.register(utc_time()); - repository.register(utc_timestamp()); - repository.register(week(BuiltinFunctionName.WEEK)); - repository.register(week(BuiltinFunctionName.WEEKOFYEAR)); - repository.register(week(BuiltinFunctionName.WEEK_OF_YEAR)); - repository.register(weekday()); - repository.register(year()); - repository.register(yearweek()); - } - - /** - * NOW() returns a constant time that indicates the time at which the statement began to execute. - * `fsp` argument support is removed until refactoring to avoid bug where `now()`, `now(x)` and - * `now(y) return different values. - */ - private FunctionResolver now(FunctionName functionName) { - return define(functionName, - implWithProperties( - functionProperties -> new ExprDatetimeValue( - formatNow(functionProperties.getQueryStartClock())), DATETIME) - ); - } - - private FunctionResolver now() { - return now(BuiltinFunctionName.NOW.getName()); - } - - private FunctionResolver current_timestamp() { - return now(BuiltinFunctionName.CURRENT_TIMESTAMP.getName()); - } - - private FunctionResolver localtimestamp() { - return now(BuiltinFunctionName.LOCALTIMESTAMP.getName()); - } - - private FunctionResolver localtime() { - return now(BuiltinFunctionName.LOCALTIME.getName()); - } - - /** - * SYSDATE() returns the time at which it executes. - */ - private FunctionResolver sysdate() { - return define(BuiltinFunctionName.SYSDATE.getName(), - implWithProperties(functionProperties - -> new ExprDatetimeValue(formatNow(Clock.systemDefaultZone())), DATETIME), - FunctionDSL.implWithProperties((functionProperties, v) -> new ExprDatetimeValue( - formatNow(Clock.systemDefaultZone(), v.integerValue())), DATETIME, INTEGER) - ); - } - - /** - * Synonym for @see `now`. - */ - private FunctionResolver curtime(FunctionName functionName) { - return define(functionName, - implWithProperties(functionProperties -> new ExprTimeValue( - formatNow(functionProperties.getQueryStartClock()).toLocalTime()), TIME)); - } - - private FunctionResolver curtime() { - return curtime(BuiltinFunctionName.CURTIME.getName()); - } - - private FunctionResolver current_time() { - return curtime(BuiltinFunctionName.CURRENT_TIME.getName()); - } - - private FunctionResolver curdate(FunctionName functionName) { - return define(functionName, - implWithProperties(functionProperties -> new ExprDateValue( - formatNow(functionProperties.getQueryStartClock()).toLocalDate()), DATE)); - } - - private FunctionResolver curdate() { - return curdate(BuiltinFunctionName.CURDATE.getName()); - } - - private FunctionResolver current_date() { - return curdate(BuiltinFunctionName.CURRENT_DATE.getName()); - } - - /** - * A common signature for `date_add` and `date_sub`. - * Specify a start date and add/subtract a temporal amount to/from the date. - * The return type depends on the date type and the interval unit. Detailed supported signatures: - * (DATE/DATETIME/TIMESTAMP/TIME, INTERVAL) -> DATETIME - * MySQL has these signatures too - * (DATE, INTERVAL) -> DATE // when interval has no time part - * (TIME, INTERVAL) -> TIME // when interval has no date part - * (STRING, INTERVAL) -> STRING // when argument has date or datetime string, - * // result has date or datetime depending on interval type - */ - private Stream> get_date_add_date_sub_signatures( - SerializableTriFunction function) { - return Stream.of( - implWithProperties(nullMissingHandlingWithProperties(function), DATETIME, DATE, INTERVAL), - implWithProperties(nullMissingHandlingWithProperties(function), - DATETIME, DATETIME, INTERVAL), - implWithProperties(nullMissingHandlingWithProperties(function), - DATETIME, TIMESTAMP, INTERVAL), - implWithProperties(nullMissingHandlingWithProperties(function), DATETIME, TIME, INTERVAL) - ); - } - - /** - * A common signature for `adddate` and `subdate`. - * Adds/subtracts an integer number of days to/from the first argument. - * (DATE, LONG) -> DATE - * (TIME/DATETIME/TIMESTAMP, LONG) -> DATETIME - */ - private Stream> get_adddate_subdate_signatures( - SerializableTriFunction function) { - return Stream.of( - implWithProperties(nullMissingHandlingWithProperties(function), DATE, DATE, LONG), - implWithProperties(nullMissingHandlingWithProperties(function), DATETIME, DATETIME, LONG), - implWithProperties(nullMissingHandlingWithProperties(function), DATETIME, TIMESTAMP, LONG), - implWithProperties(nullMissingHandlingWithProperties(function), DATETIME, TIME, LONG) - ); - } - - private DefaultFunctionResolver adddate() { - return define(BuiltinFunctionName.ADDDATE.getName(), - (SerializableFunction>[]) - (Stream.concat( - get_date_add_date_sub_signatures(DateTimeFunction::exprAddDateInterval), - get_adddate_subdate_signatures(DateTimeFunction::exprAddDateDays)) - .toArray(SerializableFunction[]::new))); - } - - /** - * Adds expr2 to expr1 and returns the result. - * (TIME, TIME/DATE/DATETIME/TIMESTAMP) -> TIME - * (DATE/DATETIME/TIMESTAMP, TIME/DATE/DATETIME/TIMESTAMP) -> DATETIME - * TODO: MySQL has these signatures too - * (STRING, STRING/TIME) -> STRING // second arg - string with time only - * (x, STRING) -> NULL // second arg - string with timestamp - * (x, STRING/DATE) -> x // second arg - string with date only - */ - private DefaultFunctionResolver addtime() { - return define(BuiltinFunctionName.ADDTIME.getName(), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - TIME, TIME, TIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - TIME, TIME, DATE), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - TIME, TIME, DATETIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - TIME, TIME, TIMESTAMP), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - DATETIME, DATETIME, TIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - DATETIME, DATETIME, DATE), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - DATETIME, DATETIME, DATETIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - DATETIME, DATETIME, TIMESTAMP), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - DATETIME, DATE, TIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - DATETIME, DATE, DATE), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - DATETIME, DATE, DATETIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - DATETIME, DATE, TIMESTAMP), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - DATETIME, TIMESTAMP, TIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - DATETIME, TIMESTAMP, DATE), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - DATETIME, TIMESTAMP, DATETIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), - DATETIME, TIMESTAMP, TIMESTAMP) - ); - } - - /** - * Converts date/time from a specified timezone to another specified timezone. - * The supported signatures: - * (DATETIME, STRING, STRING) -> DATETIME - * (STRING, STRING, STRING) -> DATETIME - */ - private DefaultFunctionResolver convert_tz() { - return define(BuiltinFunctionName.CONVERT_TZ.getName(), - impl(nullMissingHandling(DateTimeFunction::exprConvertTZ), - DATETIME, DATETIME, STRING, STRING), - impl(nullMissingHandling(DateTimeFunction::exprConvertTZ), - DATETIME, STRING, STRING, STRING) - ); - } - - /** - * Extracts the date part of a date and time value. - * Also to construct a date type. The supported signatures: - * STRING/DATE/DATETIME/TIMESTAMP -> DATE - */ - private DefaultFunctionResolver date() { - return define(BuiltinFunctionName.DATE.getName(), - impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, STRING), - impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, DATE), - impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, TIMESTAMP)); - } - - /* - * Calculates the difference of date part of given values. - * (DATE/DATETIME/TIMESTAMP/TIME, DATE/DATETIME/TIMESTAMP/TIME) -> LONG - */ - private DefaultFunctionResolver datediff() { - return define(BuiltinFunctionName.DATEDIFF.getName(), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, DATE, DATE), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, DATETIME, DATE), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, DATE, DATETIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, DATETIME, DATETIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, DATE, TIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, TIME, DATE), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, TIME, TIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, TIMESTAMP, DATE), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, DATE, TIMESTAMP), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, TIMESTAMP, TIMESTAMP), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, TIMESTAMP, TIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, TIME, TIMESTAMP), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, TIMESTAMP, DATETIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, DATETIME, TIMESTAMP), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, TIME, DATETIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), - LONG, DATETIME, TIME)); - } - - /** - * Specify a datetime with time zone field and a time zone to convert to. - * Returns a local date time. - * (STRING, STRING) -> DATETIME - * (STRING) -> DATETIME - */ - private FunctionResolver datetime() { - return define(BuiltinFunctionName.DATETIME.getName(), - impl(nullMissingHandling(DateTimeFunction::exprDateTime), - DATETIME, STRING, STRING), - impl(nullMissingHandling(DateTimeFunction::exprDateTimeNoTimezone), - DATETIME, STRING) - ); - } - - private DefaultFunctionResolver date_add() { - return define(BuiltinFunctionName.DATE_ADD.getName(), - (SerializableFunction>[]) - get_date_add_date_sub_signatures(DateTimeFunction::exprAddDateInterval) - .toArray(SerializableFunction[]::new)); - } - - private DefaultFunctionResolver date_sub() { - return define(BuiltinFunctionName.DATE_SUB.getName(), - (SerializableFunction>[]) - get_date_add_date_sub_signatures(DateTimeFunction::exprSubDateInterval) - .toArray(SerializableFunction[]::new)); - } - - /** - * DAY(STRING/DATE/DATETIME/TIMESTAMP). return the day of the month (1-31). - */ - private DefaultFunctionResolver day() { - return define(BuiltinFunctionName.DAY.getName(), - impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, STRING) - ); - } - - /** - * DAYNAME(STRING/DATE/DATETIME/TIMESTAMP). - * return the name of the weekday for date, including Monday, Tuesday, Wednesday, - * Thursday, Friday, Saturday and Sunday. - */ - private DefaultFunctionResolver dayName() { - return define(BuiltinFunctionName.DAYNAME.getName(), - impl(nullMissingHandling(DateTimeFunction::exprDayName), STRING, DATE), - impl(nullMissingHandling(DateTimeFunction::exprDayName), STRING, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprDayName), STRING, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprDayName), STRING, STRING) - ); - } - - /** - * DAYOFMONTH(STRING/DATE/DATETIME/TIMESTAMP). return the day of the month (1-31). - */ - private DefaultFunctionResolver dayOfMonth(BuiltinFunctionName name) { - return define(name.getName(), - implWithProperties(nullMissingHandlingWithProperties( - (functionProperties, arg) -> DateTimeFunction.dayOfMonthToday( - functionProperties.getQueryStartClock())), INTEGER, TIME), - impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, STRING), - impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, TIMESTAMP) - ); - } - - /** - * DAYOFWEEK(STRING/DATE/DATETIME/TIME/TIMESTAMP). - * return the weekday index for date (1 = Sunday, 2 = Monday, ..., 7 = Saturday). - */ - private DefaultFunctionResolver dayOfWeek(FunctionName name) { - return define(name, - implWithProperties(nullMissingHandlingWithProperties( - (functionProperties, arg) -> DateTimeFunction.dayOfWeekToday( - functionProperties.getQueryStartClock())), INTEGER, TIME), - impl(nullMissingHandling(DateTimeFunction::exprDayOfWeek), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprDayOfWeek), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprDayOfWeek), INTEGER, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprDayOfWeek), INTEGER, STRING) - ); - } - - /** - * DAYOFYEAR(STRING/DATE/DATETIME/TIMESTAMP). - * return the day of the year for date (1-366). - */ - private DefaultFunctionResolver dayOfYear(BuiltinFunctionName dayOfYear) { - return define(dayOfYear.getName(), - implWithProperties(nullMissingHandlingWithProperties((functionProperties, arg) - -> DateTimeFunction.dayOfYearToday( - functionProperties.getQueryStartClock())), INTEGER, TIME), - impl(nullMissingHandling(DateTimeFunction::exprDayOfYear), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprDayOfYear), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprDayOfYear), INTEGER, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprDayOfYear), INTEGER, STRING) - ); - } - - private DefaultFunctionResolver extract() { - return define(BuiltinFunctionName.EXTRACT.getName(), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprExtractForTime), - LONG, STRING, TIME), - impl(nullMissingHandling(DateTimeFunction::exprExtract), LONG, STRING, DATE), - impl(nullMissingHandling(DateTimeFunction::exprExtract), LONG, STRING, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprExtract), LONG, STRING, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprExtract), LONG, STRING, STRING) + // The number of seconds per day + public static final long SECONDS_PER_DAY = 86400; + + // The number of days from year zero to year 1970. + private static final Long DAYS_0000_TO_1970 = (146097 * 5L) - (30L * 365L + 7L); + + // MySQL doesn't process any datetime/timestamp values which are greater than + // 32536771199.999999, or equivalent '3001-01-18 23:59:59.999999' UTC + private static final Double MYSQL_MAX_TIMESTAMP = 32536771200d; + + // Mode used for week/week_of_year function by default when no argument is provided + private static final ExprIntegerValue DEFAULT_WEEK_OF_YEAR_MODE = new ExprIntegerValue(0); + + // Map used to determine format output for the extract function + private static final Map extract_formats = ImmutableMap.builder() + .put("MICROSECOND", "SSSSSS") + .put("SECOND", "ss") + .put("MINUTE", "mm") + .put("HOUR", "HH") + .put("DAY", "dd") + .put("WEEK", "w") + .put("MONTH", "MM") + .put("YEAR", "yyyy") + .put("SECOND_MICROSECOND", "ssSSSSSS") + .put("MINUTE_MICROSECOND", "mmssSSSSSS") + .put("MINUTE_SECOND", "mmss") + .put("HOUR_MICROSECOND", "HHmmssSSSSSS") + .put("HOUR_SECOND", "HHmmss") + .put("HOUR_MINUTE", "HHmm") + .put("DAY_MICROSECOND", "ddHHmmssSSSSSS") + .put("DAY_SECOND", "ddHHmmss") + .put("DAY_MINUTE", "ddHHmm") + .put("DAY_HOUR", "ddHH") + .put("YEAR_MONTH", "yyyyMM") + .put("QUARTER", "Q") + .build(); + + // Map used to determine format output for the get_format function + private static final Table formats = ImmutableTable.builder() + .put("date", "usa", "%m.%d.%Y") + .put("date", "jis", "%Y-%m-%d") + .put("date", "iso", "%Y-%m-%d") + .put("date", "eur", "%d.%m.%Y") + .put("date", "internal", "%Y%m%d") + .put("datetime", "usa", "%Y-%m-%d %H.%i.%s") + .put("datetime", "jis", "%Y-%m-%d %H:%i:%s") + .put("datetime", "iso", "%Y-%m-%d %H:%i:%s") + .put("datetime", "eur", "%Y-%m-%d %H.%i.%s") + .put("datetime", "internal", "%Y%m%d%H%i%s") + .put("time", "usa", "%h:%i:%s %p") + .put("time", "jis", "%H:%i:%s") + .put("time", "iso", "%H:%i:%s") + .put("time", "eur", "%H.%i.%s") + .put("time", "internal", "%H%i%s") + .put("timestamp", "usa", "%Y-%m-%d %H.%i.%s") + .put("timestamp", "jis", "%Y-%m-%d %H:%i:%s") + .put("timestamp", "iso", "%Y-%m-%d %H:%i:%s") + .put("timestamp", "eur", "%Y-%m-%d %H.%i.%s") + .put("timestamp", "internal", "%Y%m%d%H%i%s") + .build(); + + /** + * Register Date and Time Functions. + * + * @param repository {@link BuiltinFunctionRepository}. + */ + public void register(BuiltinFunctionRepository repository) { + repository.register(adddate()); + repository.register(addtime()); + repository.register(convert_tz()); + repository.register(curtime()); + repository.register(curdate()); + repository.register(current_date()); + repository.register(current_time()); + repository.register(current_timestamp()); + repository.register(date()); + repository.register(datediff()); + repository.register(datetime()); + repository.register(date_add()); + repository.register(date_format()); + repository.register(date_sub()); + repository.register(day()); + repository.register(dayName()); + repository.register(dayOfMonth(BuiltinFunctionName.DAYOFMONTH)); + repository.register(dayOfMonth(BuiltinFunctionName.DAY_OF_MONTH)); + repository.register(dayOfWeek(BuiltinFunctionName.DAYOFWEEK.getName())); + repository.register(dayOfWeek(BuiltinFunctionName.DAY_OF_WEEK.getName())); + repository.register(dayOfYear(BuiltinFunctionName.DAYOFYEAR)); + repository.register(dayOfYear(BuiltinFunctionName.DAY_OF_YEAR)); + repository.register(extract()); + repository.register(from_days()); + repository.register(from_unixtime()); + repository.register(get_format()); + repository.register(hour(BuiltinFunctionName.HOUR)); + repository.register(hour(BuiltinFunctionName.HOUR_OF_DAY)); + repository.register(last_day()); + repository.register(localtime()); + repository.register(localtimestamp()); + repository.register(makedate()); + repository.register(maketime()); + repository.register(microsecond()); + repository.register(minute(BuiltinFunctionName.MINUTE)); + repository.register(minute_of_day()); + repository.register(minute(BuiltinFunctionName.MINUTE_OF_HOUR)); + repository.register(month(BuiltinFunctionName.MONTH)); + repository.register(month(BuiltinFunctionName.MONTH_OF_YEAR)); + repository.register(monthName()); + repository.register(now()); + repository.register(period_add()); + repository.register(period_diff()); + repository.register(quarter()); + repository.register(sec_to_time()); + repository.register(second(BuiltinFunctionName.SECOND)); + repository.register(second(BuiltinFunctionName.SECOND_OF_MINUTE)); + repository.register(subdate()); + repository.register(subtime()); + repository.register(str_to_date()); + repository.register(sysdate()); + repository.register(time()); + repository.register(time_format()); + repository.register(time_to_sec()); + repository.register(timediff()); + repository.register(timestamp()); + repository.register(timestampadd()); + repository.register(timestampdiff()); + repository.register(to_days()); + repository.register(to_seconds()); + repository.register(unix_timestamp()); + repository.register(utc_date()); + repository.register(utc_time()); + repository.register(utc_timestamp()); + repository.register(week(BuiltinFunctionName.WEEK)); + repository.register(week(BuiltinFunctionName.WEEKOFYEAR)); + repository.register(week(BuiltinFunctionName.WEEK_OF_YEAR)); + repository.register(weekday()); + repository.register(year()); + repository.register(yearweek()); + } + + /** + * NOW() returns a constant time that indicates the time at which the statement began to execute. + * `fsp` argument support is removed until refactoring to avoid bug where `now()`, `now(x)` and + * `now(y) return different values. + */ + private FunctionResolver now(FunctionName functionName) { + return define( + functionName, + implWithProperties(functionProperties -> new ExprDatetimeValue(formatNow(functionProperties.getQueryStartClock())), DATETIME) ); - } - - /** - * FROM_DAYS(LONG). return the date value given the day number N. - */ - private DefaultFunctionResolver from_days() { - return define(BuiltinFunctionName.FROM_DAYS.getName(), - impl(nullMissingHandling(DateTimeFunction::exprFromDays), DATE, LONG)); - } - - private FunctionResolver from_unixtime() { - return define(BuiltinFunctionName.FROM_UNIXTIME.getName(), - impl(nullMissingHandling(DateTimeFunction::exprFromUnixTime), DATETIME, DOUBLE), - impl(nullMissingHandling(DateTimeFunction::exprFromUnixTimeFormat), - STRING, DOUBLE, STRING)); - } - - private DefaultFunctionResolver get_format() { - return define(BuiltinFunctionName.GET_FORMAT.getName(), - impl(nullMissingHandling(DateTimeFunction::exprGetFormat), STRING, STRING, STRING) - ); - } - - /** - * HOUR(STRING/TIME/DATETIME/DATE/TIMESTAMP). return the hour value for time. - */ - private DefaultFunctionResolver hour(BuiltinFunctionName name) { - return define(name.getName(), - impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, STRING), - impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, TIME), - impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, TIMESTAMP) - ); - } - - private DefaultFunctionResolver last_day() { - return define(BuiltinFunctionName.LAST_DAY.getName(), - impl(nullMissingHandling(DateTimeFunction::exprLastDay), DATE, STRING), - implWithProperties(nullMissingHandlingWithProperties((functionProperties, arg) - -> DateTimeFunction.exprLastDayToday( - functionProperties.getQueryStartClock())), DATE, TIME), - impl(nullMissingHandling(DateTimeFunction::exprLastDay), DATE, DATE), - impl(nullMissingHandling(DateTimeFunction::exprLastDay), DATE, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprLastDay), DATE, TIMESTAMP) - ); - } - - private FunctionResolver makedate() { - return define(BuiltinFunctionName.MAKEDATE.getName(), - impl(nullMissingHandling(DateTimeFunction::exprMakeDate), DATE, DOUBLE, DOUBLE)); - } - - private FunctionResolver maketime() { - return define(BuiltinFunctionName.MAKETIME.getName(), - impl(nullMissingHandling(DateTimeFunction::exprMakeTime), TIME, DOUBLE, DOUBLE, DOUBLE)); - } - - /** - * MICROSECOND(STRING/TIME/DATETIME/TIMESTAMP). return the microsecond value for time. - */ - private DefaultFunctionResolver microsecond() { - return define(BuiltinFunctionName.MICROSECOND.getName(), - impl(nullMissingHandling(DateTimeFunction::exprMicrosecond), INTEGER, STRING), - impl(nullMissingHandling(DateTimeFunction::exprMicrosecond), INTEGER, TIME), - impl(nullMissingHandling(DateTimeFunction::exprMicrosecond), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprMicrosecond), INTEGER, TIMESTAMP) - ); - } - - /** - * MINUTE(STRING/TIME/DATETIME/TIMESTAMP). return the minute value for time. - */ - private DefaultFunctionResolver minute(BuiltinFunctionName name) { - return define(name.getName(), - impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, STRING), - impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, TIME), - impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, TIMESTAMP) - ); - } - - /** - * MINUTE(STRING/TIME/DATETIME/TIMESTAMP). return the minute value for time. - */ - private DefaultFunctionResolver minute_of_day() { - return define(BuiltinFunctionName.MINUTE_OF_DAY.getName(), - impl(nullMissingHandling(DateTimeFunction::exprMinuteOfDay), INTEGER, STRING), - impl(nullMissingHandling(DateTimeFunction::exprMinuteOfDay), INTEGER, TIME), - impl(nullMissingHandling(DateTimeFunction::exprMinuteOfDay), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprMinuteOfDay), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprMinuteOfDay), INTEGER, TIMESTAMP) - ); - } - - /** - * MONTH(STRING/DATE/DATETIME/TIMESTAMP). return the month for date (1-12). - */ - private DefaultFunctionResolver month(BuiltinFunctionName month) { - return define(month.getName(), - implWithProperties(nullMissingHandlingWithProperties((functionProperties, arg) - -> DateTimeFunction.monthOfYearToday( - functionProperties.getQueryStartClock())), INTEGER, TIME), - impl(nullMissingHandling(DateTimeFunction::exprMonth), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprMonth), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprMonth), INTEGER, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprMonth), INTEGER, STRING) - ); - } - - /** - * MONTHNAME(STRING/DATE/DATETIME/TIMESTAMP). return the full name of the month for date. - */ - private DefaultFunctionResolver monthName() { - return define(BuiltinFunctionName.MONTHNAME.getName(), - impl(nullMissingHandling(DateTimeFunction::exprMonthName), STRING, DATE), - impl(nullMissingHandling(DateTimeFunction::exprMonthName), STRING, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprMonthName), STRING, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprMonthName), STRING, STRING) - ); - } - - /** - * Add N months to period P (in the format YYMM or YYYYMM). Returns a value in the format YYYYMM. - * (INTEGER, INTEGER) -> INTEGER - */ - private DefaultFunctionResolver period_add() { - return define(BuiltinFunctionName.PERIOD_ADD.getName(), - impl(nullMissingHandling(DateTimeFunction::exprPeriodAdd), INTEGER, INTEGER, INTEGER) - ); - } - - /** - * Returns the number of months between periods P1 and P2. - * P1 and P2 should be in the format YYMM or YYYYMM. - * (INTEGER, INTEGER) -> INTEGER - */ - private DefaultFunctionResolver period_diff() { - return define(BuiltinFunctionName.PERIOD_DIFF.getName(), - impl(nullMissingHandling(DateTimeFunction::exprPeriodDiff), INTEGER, INTEGER, INTEGER) - ); - } - - /** - * QUARTER(STRING/DATE/DATETIME/TIMESTAMP). return the month for date (1-4). - */ - private DefaultFunctionResolver quarter() { - return define(BuiltinFunctionName.QUARTER.getName(), - impl(nullMissingHandling(DateTimeFunction::exprQuarter), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprQuarter), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprQuarter), INTEGER, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprQuarter), INTEGER, STRING) - ); - } - - private DefaultFunctionResolver sec_to_time() { - return define(BuiltinFunctionName.SEC_TO_TIME.getName(), - impl((nullMissingHandling(DateTimeFunction::exprSecToTime)), TIME, INTEGER), - impl((nullMissingHandling(DateTimeFunction::exprSecToTime)), TIME, LONG), - impl((nullMissingHandling(DateTimeFunction::exprSecToTimeWithNanos)), TIME, DOUBLE), - impl((nullMissingHandling(DateTimeFunction::exprSecToTimeWithNanos)), TIME, FLOAT) - ); - } - - /** - * SECOND(STRING/TIME/DATETIME/TIMESTAMP). return the second value for time. - */ - private DefaultFunctionResolver second(BuiltinFunctionName name) { - return define(name.getName(), - impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, STRING), - impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, TIME), - impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, TIMESTAMP) - ); - } - - private DefaultFunctionResolver subdate() { - return define(BuiltinFunctionName.SUBDATE.getName(), - (SerializableFunction>[]) - (Stream.concat( - get_date_add_date_sub_signatures(DateTimeFunction::exprSubDateInterval), - get_adddate_subdate_signatures(DateTimeFunction::exprSubDateDays)) - .toArray(SerializableFunction[]::new))); - } - - /** - * Subtracts expr2 from expr1 and returns the result. - * (TIME, TIME/DATE/DATETIME/TIMESTAMP) -> TIME - * (DATE/DATETIME/TIMESTAMP, TIME/DATE/DATETIME/TIMESTAMP) -> DATETIME - * TODO: MySQL has these signatures too - * (STRING, STRING/TIME) -> STRING // second arg - string with time only - * (x, STRING) -> NULL // second arg - string with timestamp - * (x, STRING/DATE) -> x // second arg - string with date only - */ - private DefaultFunctionResolver subtime() { - return define(BuiltinFunctionName.SUBTIME.getName(), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - TIME, TIME, TIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - TIME, TIME, DATE), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - TIME, TIME, DATETIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - TIME, TIME, TIMESTAMP), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - DATETIME, DATETIME, TIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - DATETIME, DATETIME, DATE), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - DATETIME, DATETIME, DATETIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - DATETIME, DATETIME, TIMESTAMP), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - DATETIME, DATE, TIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - DATETIME, DATE, DATE), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - DATETIME, DATE, DATETIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - DATETIME, DATE, TIMESTAMP), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - DATETIME, TIMESTAMP, TIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - DATETIME, TIMESTAMP, DATE), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - DATETIME, TIMESTAMP, DATETIME), - implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), - DATETIME, TIMESTAMP, TIMESTAMP) - ); - } - - /** - * Extracts a date, time, or datetime from the given string. - * It accomplishes this using another string which specifies the input format. - */ - private DefaultFunctionResolver str_to_date() { - return define(BuiltinFunctionName.STR_TO_DATE.getName(), - implWithProperties( - nullMissingHandlingWithProperties((functionProperties, arg, format) - -> DateTimeFunction.exprStrToDate(functionProperties, arg, format)), - DATETIME, STRING, STRING)); - } - - /** - * Extracts the time part of a date and time value. - * Also to construct a time type. The supported signatures: - * STRING/DATE/DATETIME/TIME/TIMESTAMP -> TIME - */ - private DefaultFunctionResolver time() { - return define(BuiltinFunctionName.TIME.getName(), - impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, STRING), - impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, DATE), - impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, TIME), - impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, TIMESTAMP)); - } - - /** - * Returns different between two times as a time. - * (TIME, TIME) -> TIME - * MySQL has these signatures too - * (DATE, DATE) -> TIME // result is > 24 hours - * (DATETIME, DATETIME) -> TIME // result is > 24 hours - * (TIMESTAMP, TIMESTAMP) -> TIME // result is > 24 hours - * (x, x) -> NULL // when args have different types - * (STRING, STRING) -> TIME // argument strings contain same types only - * (STRING, STRING) -> NULL // argument strings are different types - */ - private DefaultFunctionResolver timediff() { - return define(BuiltinFunctionName.TIMEDIFF.getName(), - impl(nullMissingHandling(DateTimeFunction::exprTimeDiff), TIME, TIME, TIME)); - } - - /** - * TIME_TO_SEC(STRING/TIME/DATETIME/TIMESTAMP). return the time argument, converted to seconds. - */ - private DefaultFunctionResolver time_to_sec() { - return define(BuiltinFunctionName.TIME_TO_SEC.getName(), - impl(nullMissingHandling(DateTimeFunction::exprTimeToSec), LONG, STRING), - impl(nullMissingHandling(DateTimeFunction::exprTimeToSec), LONG, TIME), - impl(nullMissingHandling(DateTimeFunction::exprTimeToSec), LONG, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprTimeToSec), LONG, DATETIME) - ); - } - - /** - * Extracts the timestamp of a date and time value. - * Input strings may contain a timestamp only in format 'yyyy-MM-dd HH:mm:ss[.SSSSSSSSS]' - * STRING/DATE/TIME/DATETIME/TIMESTAMP -> TIMESTAMP - * STRING/DATE/TIME/DATETIME/TIMESTAMP, STRING/DATE/TIME/DATETIME/TIMESTAMP -> TIMESTAMP - * All types are converted to TIMESTAMP actually before the function call - it is responsibility - * of the automatic cast mechanism defined in `ExprCoreType` and performed by `TypeCastOperator`. - */ - private DefaultFunctionResolver timestamp() { - return define(BuiltinFunctionName.TIMESTAMP.getName(), - impl(nullMissingHandling(v -> v), TIMESTAMP, TIMESTAMP), - // We can use FunctionProperties.None, because it is not used. It is required to convert - // TIME to other datetime types, but arguments there are already converted. - impl(nullMissingHandling((v1, v2) -> exprAddTime(FunctionProperties.None, v1, v2)), - TIMESTAMP, TIMESTAMP, TIMESTAMP)); - } - - /** - * Adds an interval of time to the provided DATE/DATETIME/TIME/TIMESTAMP/STRING argument. - * The interval of time added is determined by the given first and second arguments. - * The first argument is an interval type, and must be one of the tokens below... - * [MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR] - * The second argument is the amount of the interval type to be added. - * The third argument is the DATE/DATETIME/TIME/TIMESTAMP/STRING to add to. - * @return The DATETIME representing the summed DATE/DATETIME/TIME/TIMESTAMP and interval. - */ - private DefaultFunctionResolver timestampadd() { - return define(BuiltinFunctionName.TIMESTAMPADD.getName(), - impl(nullMissingHandling(DateTimeFunction::exprTimestampAdd), - DATETIME, STRING, INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprTimestampAdd), - DATETIME, STRING, INTEGER, TIMESTAMP), - implWithProperties( - nullMissingHandlingWithProperties( - (functionProperties, part, amount, time) -> exprTimestampAddForTimeType( - functionProperties.getQueryStartClock(), - part, - amount, - time)), - DATETIME, STRING, INTEGER, TIME)); - } - - /** - * Finds the difference between provided DATE/DATETIME/TIME/TIMESTAMP/STRING arguments. - * The first argument is an interval type, and must be one of the tokens below... - * [MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR] - * The second argument the DATE/DATETIME/TIME/TIMESTAMP/STRING representing the start time. - * The third argument is the DATE/DATETIME/TIME/TIMESTAMP/STRING representing the end time. - * @return A LONG representing the difference between arguments, using the given interval type. - */ - private DefaultFunctionResolver timestampdiff() { - return define(BuiltinFunctionName.TIMESTAMPDIFF.getName(), - impl(nullMissingHandling(DateTimeFunction::exprTimestampDiff), - DATETIME, STRING, DATETIME, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprTimestampDiff), - DATETIME, STRING, DATETIME, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprTimestampDiff), - DATETIME, STRING, TIMESTAMP, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprTimestampDiff), - DATETIME, STRING, TIMESTAMP, TIMESTAMP), - implWithProperties( - nullMissingHandlingWithProperties( - (functionProperties, part, startTime, endTime) -> exprTimestampDiffForTimeType( - functionProperties, - part, - startTime, - endTime)), - DATETIME, STRING, TIME, TIME) - ); - } - - /** - * TO_DAYS(STRING/DATE/DATETIME/TIMESTAMP). return the day number of the given date. - */ - private DefaultFunctionResolver to_days() { - return define(BuiltinFunctionName.TO_DAYS.getName(), - impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, STRING), - impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, DATE), - impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, DATETIME)); - } - - /** - * TO_SECONDS(TIMESTAMP/LONG). return the seconds number of the given date. - * Arguments of type STRING/TIMESTAMP/LONG are also accepted. - * STRING/TIMESTAMP/LONG arguments are automatically cast to TIMESTAMP. - */ - private DefaultFunctionResolver to_seconds() { - return define(BuiltinFunctionName.TO_SECONDS.getName(), - impl(nullMissingHandling(DateTimeFunction::exprToSeconds), LONG, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprToSecondsForIntType), LONG, LONG)); - } - - private FunctionResolver unix_timestamp() { - return define(BuiltinFunctionName.UNIX_TIMESTAMP.getName(), - implWithProperties(functionProperties - -> DateTimeFunction.unixTimeStamp(functionProperties.getQueryStartClock()), LONG), - impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, DATE), - impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, DATETIME), - impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, DOUBLE) - ); - } - - /** - * UTC_DATE(). return the current UTC Date in format yyyy-MM-dd - */ - private DefaultFunctionResolver utc_date() { - return define(BuiltinFunctionName.UTC_DATE.getName(), - implWithProperties(functionProperties - -> exprUtcDate(functionProperties), DATE)); - } - - /** - * UTC_TIME(). return the current UTC Time in format HH:mm:ss - */ - private DefaultFunctionResolver utc_time() { - return define(BuiltinFunctionName.UTC_TIME.getName(), - implWithProperties(functionProperties - -> exprUtcTime(functionProperties), TIME)); - } - - /** - * UTC_TIMESTAMP(). return the current UTC TimeStamp in format yyyy-MM-dd HH:mm:ss - */ - private DefaultFunctionResolver utc_timestamp() { - return define(BuiltinFunctionName.UTC_TIMESTAMP.getName(), - implWithProperties(functionProperties - -> exprUtcTimeStamp(functionProperties), DATETIME)); - } - - /** - * WEEK(DATE[,mode]). return the week number for date. - */ - private DefaultFunctionResolver week(BuiltinFunctionName week) { - return define(week.getName(), - implWithProperties(nullMissingHandlingWithProperties((functionProperties, arg) - -> DateTimeFunction.weekOfYearToday( - DEFAULT_WEEK_OF_YEAR_MODE, - functionProperties.getQueryStartClock())), INTEGER, TIME), - impl(nullMissingHandling(DateTimeFunction::exprWeekWithoutMode), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprWeekWithoutMode), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprWeekWithoutMode), INTEGER, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprWeekWithoutMode), INTEGER, STRING), - implWithProperties(nullMissingHandlingWithProperties((functionProperties, time, modeArg) - -> DateTimeFunction.weekOfYearToday( - modeArg, - functionProperties.getQueryStartClock())), INTEGER, TIME, INTEGER), - impl(nullMissingHandling(DateTimeFunction::exprWeek), INTEGER, DATE, INTEGER), - impl(nullMissingHandling(DateTimeFunction::exprWeek), INTEGER, DATETIME, INTEGER), - impl(nullMissingHandling(DateTimeFunction::exprWeek), INTEGER, TIMESTAMP, INTEGER), - impl(nullMissingHandling(DateTimeFunction::exprWeek), INTEGER, STRING, INTEGER) - ); - } - - private DefaultFunctionResolver weekday() { - return define(BuiltinFunctionName.WEEKDAY.getName(), - implWithProperties(nullMissingHandlingWithProperties( - (functionProperties, arg) -> new ExprIntegerValue( - formatNow(functionProperties.getQueryStartClock()).getDayOfWeek().getValue() - 1)), - INTEGER, TIME), - impl(nullMissingHandling(DateTimeFunction::exprWeekday), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprWeekday), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprWeekday), INTEGER, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprWeekday), INTEGER, STRING) - ); - } - - /** - * YEAR(STRING/DATE/DATETIME/TIMESTAMP). return the year for date (1000-9999). - */ - private DefaultFunctionResolver year() { - return define(BuiltinFunctionName.YEAR.getName(), - impl(nullMissingHandling(DateTimeFunction::exprYear), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprYear), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprYear), INTEGER, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprYear), INTEGER, STRING) - ); - } - - /** - * YEARWEEK(DATE[,mode]). return the week number for date. - */ - private DefaultFunctionResolver yearweek() { - return define(BuiltinFunctionName.YEARWEEK.getName(), - implWithProperties(nullMissingHandlingWithProperties((functionProperties, arg) - -> yearweekToday( - DEFAULT_WEEK_OF_YEAR_MODE, - functionProperties.getQueryStartClock())), INTEGER, TIME), - impl(nullMissingHandling(DateTimeFunction::exprYearweekWithoutMode), INTEGER, DATE), - impl(nullMissingHandling(DateTimeFunction::exprYearweekWithoutMode), INTEGER, DATETIME), - impl(nullMissingHandling(DateTimeFunction::exprYearweekWithoutMode), INTEGER, TIMESTAMP), - impl(nullMissingHandling(DateTimeFunction::exprYearweekWithoutMode), INTEGER, STRING), - implWithProperties(nullMissingHandlingWithProperties((functionProperties, time, modeArg) - -> yearweekToday( - modeArg, - functionProperties.getQueryStartClock())), INTEGER, TIME, INTEGER), - impl(nullMissingHandling(DateTimeFunction::exprYearweek), INTEGER, DATE, INTEGER), - impl(nullMissingHandling(DateTimeFunction::exprYearweek), INTEGER, DATETIME, INTEGER), - impl(nullMissingHandling(DateTimeFunction::exprYearweek), INTEGER, TIMESTAMP, INTEGER), - impl(nullMissingHandling(DateTimeFunction::exprYearweek), INTEGER, STRING, INTEGER) - ); - } - - /** - * Formats date according to format specifier. First argument is date, second is format. - * Detailed supported signatures: - * (STRING, STRING) -> STRING - * (DATE, STRING) -> STRING - * (DATETIME, STRING) -> STRING - * (TIME, STRING) -> STRING - * (TIMESTAMP, STRING) -> STRING - */ - private DefaultFunctionResolver date_format() { - return define(BuiltinFunctionName.DATE_FORMAT.getName(), - impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedDate), - STRING, STRING, STRING), - impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedDate), - STRING, DATE, STRING), - impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedDate), - STRING, DATETIME, STRING), - implWithProperties( - nullMissingHandlingWithProperties( - (functionProperties, time, formatString) - -> DateTimeFormatterUtil.getFormattedDateOfToday( - formatString, time, functionProperties.getQueryStartClock())), - STRING, TIME, STRING), - impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedDate), - STRING, TIMESTAMP, STRING) - ); - } - - - private ExprValue dayOfMonthToday(Clock clock) { - return new ExprIntegerValue(LocalDateTime.now(clock).getDayOfMonth()); - } - - private ExprValue dayOfYearToday(Clock clock) { - return new ExprIntegerValue(LocalDateTime.now(clock).getDayOfYear()); - } - - private ExprValue weekOfYearToday(ExprValue mode, Clock clock) { - return new ExprIntegerValue( - CalendarLookup.getWeekNumber(mode.integerValue(), LocalDateTime.now(clock).toLocalDate())); - } - - /** - * Day of Week implementation for ExprValue when passing in an arguemt of type TIME. - * - * @param clock Current clock taken from function properties - * @return ExprValue. - */ - private ExprValue dayOfWeekToday(Clock clock) { - return new ExprIntegerValue((formatNow(clock).getDayOfWeek().getValue() % 7) + 1); - } - - /** - * DATE_ADD function implementation for ExprValue. - * - * @param functionProperties An FunctionProperties object. - * @param datetime ExprValue of Date/Time/Datetime/Timestamp type. - * @param interval ExprValue of Interval type, the temporal amount to add. - * @return Datetime resulted from `interval` added to `datetime`. - */ - private ExprValue exprAddDateInterval(FunctionProperties functionProperties, - ExprValue datetime, ExprValue interval) { - return exprDateApplyInterval(functionProperties, datetime, interval.intervalValue(), true); - } - - /** - * Adds or subtracts `interval` to/from `datetime`. - * - * @param functionProperties An FunctionProperties object. - * @param datetime A Date/Time/Datetime/Timestamp value to change. - * @param interval An Interval to isAdd or subtract. - * @param isAdd A flag: true to isAdd, false to subtract. - * @return Datetime calculated. - */ - private ExprValue exprDateApplyInterval(FunctionProperties functionProperties, - ExprValue datetime, - TemporalAmount interval, - Boolean isAdd) { - var dt = extractDateTime(datetime, functionProperties); - return new ExprDatetimeValue(isAdd ? dt.plus(interval) : dt.minus(interval)); - } - - /** - * Formats date according to format specifier. First argument is time, second is format. - * Detailed supported signatures: - * (STRING, STRING) -> STRING - * (DATE, STRING) -> STRING - * (DATETIME, STRING) -> STRING - * (TIME, STRING) -> STRING - * (TIMESTAMP, STRING) -> STRING - */ - private DefaultFunctionResolver time_format() { - return define(BuiltinFunctionName.TIME_FORMAT.getName(), - impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedTime), - STRING, STRING, STRING), - impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedTime), - STRING, DATE, STRING), - impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedTime), - STRING, DATETIME, STRING), - impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedTime), - STRING, TIME, STRING), - impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedTime), - STRING, TIMESTAMP, STRING) - ); - } - - /** - * ADDDATE function implementation for ExprValue. - * - * @param functionProperties An FunctionProperties object. - * @param datetime ExprValue of Time/Date/Datetime/Timestamp type. - * @param days ExprValue of Long type, representing the number of days to add. - * @return Date/Datetime resulted from days added to `datetime`. - */ - private ExprValue exprAddDateDays(FunctionProperties functionProperties, - ExprValue datetime, ExprValue days) { - return exprDateApplyDays(functionProperties, datetime, days.longValue(), true); - } - - /** - * Adds or subtracts `days` to/from `datetime`. - * - * @param functionProperties An FunctionProperties object. - * @param datetime A Date/Time/Datetime/Timestamp value to change. - * @param days A days amount to add or subtract. - * @param isAdd A flag: true to add, false to subtract. - * @return Datetime calculated. - */ - private ExprValue exprDateApplyDays(FunctionProperties functionProperties, - ExprValue datetime, Long days, Boolean isAdd) { - if (datetime.type() == DATE) { - return new ExprDateValue(isAdd ? datetime.dateValue().plusDays(days) - : datetime.dateValue().minusDays(days)); - } - var dt = extractDateTime(datetime, functionProperties); - return new ExprDatetimeValue(isAdd ? dt.plusDays(days) : dt.minusDays(days)); - } - - /** - * Adds or subtracts time to/from date and returns the result. - * - * @param functionProperties A FunctionProperties object. - * @param temporal A Date/Time/Datetime/Timestamp value to change. - * @param temporalDelta A Date/Time/Datetime/Timestamp object to add/subtract time from. - * @param isAdd A flag: true to add, false to subtract. - * @return A value calculated. - */ - private ExprValue exprApplyTime(FunctionProperties functionProperties, - ExprValue temporal, ExprValue temporalDelta, Boolean isAdd) { - var interval = Duration.between(LocalTime.MIN, temporalDelta.timeValue()); - var result = isAdd - ? extractDateTime(temporal, functionProperties).plus(interval) - : extractDateTime(temporal, functionProperties).minus(interval); - return temporal.type() == TIME - ? new ExprTimeValue(result.toLocalTime()) - : new ExprDatetimeValue(result); - } - - /** - * Adds time to date and returns the result. - * - * @param functionProperties A FunctionProperties object. - * @param temporal A Date/Time/Datetime/Timestamp value to change. - * @param temporalDelta A Date/Time/Datetime/Timestamp object to add time from. - * @return A value calculated. - */ - private ExprValue exprAddTime(FunctionProperties functionProperties, - ExprValue temporal, ExprValue temporalDelta) { - return exprApplyTime(functionProperties, temporal, temporalDelta, true); - } - - /** - * CONVERT_TZ function implementation for ExprValue. - * Returns null for time zones outside of +13:00 and -12:00. - * - * @param startingDateTime ExprValue of DateTime that is being converted from - * @param fromTz ExprValue of time zone, representing the time to convert from. - * @param toTz ExprValue of time zone, representing the time to convert to. - * @return DateTime that has been converted to the to_tz timezone. - */ - private ExprValue exprConvertTZ(ExprValue startingDateTime, ExprValue fromTz, ExprValue toTz) { - if (startingDateTime.type() == ExprCoreType.STRING) { - startingDateTime = exprDateTimeNoTimezone(startingDateTime); - } - try { - ZoneId convertedFromTz = ZoneId.of(fromTz.stringValue()); - ZoneId convertedToTz = ZoneId.of(toTz.stringValue()); - - // isValidMySqlTimeZoneId checks if the timezone is within the range accepted by - // MySQL standard. - if (!DateTimeUtils.isValidMySqlTimeZoneId(convertedFromTz) - || !DateTimeUtils.isValidMySqlTimeZoneId(convertedToTz)) { - return ExprNullValue.of(); - } - ZonedDateTime zonedDateTime = - startingDateTime.datetimeValue().atZone(convertedFromTz); - return new ExprDatetimeValue( - zonedDateTime.withZoneSameInstant(convertedToTz).toLocalDateTime()); - - // Catches exception for invalid timezones. - // ex. "+0:00" is an invalid timezone and would result in this exception being thrown. - } catch (ExpressionEvaluationException | DateTimeException e) { - return ExprNullValue.of(); - } - } - - /** - * Date implementation for ExprValue. - * - * @param exprValue ExprValue of Date type or String type. - * @return ExprValue. - */ - private ExprValue exprDate(ExprValue exprValue) { - if (exprValue instanceof ExprStringValue) { - return new ExprDateValue(exprValue.stringValue()); - } else { - return new ExprDateValue(exprValue.dateValue()); - } - } - - /** - * Calculate the value in days from one date to the other. - * Only the date parts of the values are used in the calculation. - * - * @param first The first value. - * @param second The second value. - * @return The diff. - */ - private ExprValue exprDateDiff(FunctionProperties functionProperties, - ExprValue first, ExprValue second) { - // java inverses the value, so we have to swap 1 and 2 - return new ExprLongValue(DAYS.between( - extractDate(second, functionProperties), - extractDate(first, functionProperties))); - } - - /** - * DateTime implementation for ExprValue. - * - * @param dateTime ExprValue of String type. - * @param timeZone ExprValue of String type (or null). - * @return ExprValue of date type. - */ - private ExprValue exprDateTime(ExprValue dateTime, ExprValue timeZone) { - String defaultTimeZone = TimeZone.getDefault().getID(); - - try { - LocalDateTime ldtFormatted = - LocalDateTime.parse(dateTime.stringValue(), DATE_TIME_FORMATTER_STRICT_WITH_TZ); - if (timeZone.isNull()) { - return new ExprDatetimeValue(ldtFormatted); - } - - // Used if datetime field is invalid format. - } catch (DateTimeParseException e) { - return ExprNullValue.of(); - } - - ExprValue convertTZResult; - ExprDatetimeValue ldt; - String toTz; - - try { - ZonedDateTime zdtWithZoneOffset = - ZonedDateTime.parse(dateTime.stringValue(), DATE_TIME_FORMATTER_STRICT_WITH_TZ); - ZoneId fromTZ = zdtWithZoneOffset.getZone(); - - ldt = new ExprDatetimeValue(zdtWithZoneOffset.toLocalDateTime()); - toTz = String.valueOf(fromTZ); - } catch (DateTimeParseException e) { - ldt = new ExprDatetimeValue(dateTime.stringValue()); - toTz = defaultTimeZone; - } - convertTZResult = exprConvertTZ( - ldt, - new ExprStringValue(toTz), - timeZone); - - return convertTZResult; - } - - /** - * DateTime implementation for ExprValue without a timezone to convert to. - * - * @param dateTime ExprValue of String type. - * @return ExprValue of date type. - */ - private ExprValue exprDateTimeNoTimezone(ExprValue dateTime) { - return exprDateTime(dateTime, ExprNullValue.of()); - } - - /** - * Name of the Weekday implementation for ExprValue. - * - * @param date ExprValue of Date/String type. - * @return ExprValue. - */ - private ExprValue exprDayName(ExprValue date) { - return new ExprStringValue( - date.dateValue().getDayOfWeek().getDisplayName(TextStyle.FULL, Locale.getDefault())); - } - - /** - * Day of Month implementation for ExprValue. - * - * @param date ExprValue of Date/Datetime/String/Time/Timestamp type. - * @return ExprValue. - */ - private ExprValue exprDayOfMonth(ExprValue date) { - return new ExprIntegerValue(date.dateValue().getDayOfMonth()); - } - - /** - * Day of Week implementation for ExprValue. - * - * @param date ExprValue of Date/Datetime/String/Timstamp type. - * @return ExprValue. - */ - private ExprValue exprDayOfWeek(ExprValue date) { - return new ExprIntegerValue((date.dateValue().getDayOfWeek().getValue() % 7) + 1); - } - - /** - * Day of Year implementation for ExprValue. - * - * @param date ExprValue of Date/String type. - * @return ExprValue. - */ - private ExprValue exprDayOfYear(ExprValue date) { - return new ExprIntegerValue(date.dateValue().getDayOfYear()); - } - - /** - * Obtains a formatted long value for a specified part and datetime for the 'extract' function. - * - * @param part is an ExprValue which comes from a defined list of accepted values. - * @param datetime the date to be formatted as an ExprValue. - * @return is a LONG formatted according to the input arguments. - */ - public ExprLongValue formatExtractFunction(ExprValue part, ExprValue datetime) { - String partName = part.stringValue().toUpperCase(); - LocalDateTime arg = datetime.datetimeValue(); - String text = arg.format(DateTimeFormatter.ofPattern( - extract_formats.get(partName), Locale.ENGLISH)); - - return new ExprLongValue(Long.parseLong(text)); - } - - /** - * Implements extract function. Returns a LONG formatted according to the 'part' argument. - * - * @param part Literal that determines the format of the outputted LONG. - * @param datetime The date/datetime to be formatted. - * @return A LONG - */ - private ExprValue exprExtract(ExprValue part, ExprValue datetime) { - return formatExtractFunction(part, datetime); - } - - /** - * Implements extract function. Returns a LONG formatted according to the 'part' argument. - * - * @param part Literal that determines the format of the outputted LONG. - * @param time The time to be formatted. - * @return A LONG - */ - private ExprValue exprExtractForTime(FunctionProperties functionProperties, - ExprValue part, - ExprValue time) { - return formatExtractFunction( - part, - new ExprDatetimeValue(extractDateTime(time, functionProperties))); - } - - /** - * From_days implementation for ExprValue. - * - * @param exprValue Day number N. - * @return ExprValue. - */ - private ExprValue exprFromDays(ExprValue exprValue) { - return new ExprDateValue(LocalDate.ofEpochDay(exprValue.longValue() - DAYS_0000_TO_1970)); - } - - private ExprValue exprFromUnixTime(ExprValue time) { - if (0 > time.doubleValue()) { - return ExprNullValue.of(); - } - // According to MySQL documentation: - // effective maximum is 32536771199.999999, which returns '3001-01-18 23:59:59.999999' UTC. - // Regardless of platform or version, a greater value for first argument than the effective - // maximum returns 0. - if (MYSQL_MAX_TIMESTAMP <= time.doubleValue()) { - return ExprNullValue.of(); - } - return new ExprDatetimeValue(exprFromUnixTimeImpl(time)); - } - - private LocalDateTime exprFromUnixTimeImpl(ExprValue time) { - return LocalDateTime.ofInstant( - Instant.ofEpochSecond((long)Math.floor(time.doubleValue())), - UTC_ZONE_ID) - .withNano((int)((time.doubleValue() % 1) * 1E9)); - } - - private ExprValue exprFromUnixTimeFormat(ExprValue time, ExprValue format) { - var value = exprFromUnixTime(time); - if (value.equals(ExprNullValue.of())) { - return ExprNullValue.of(); - } - return DateTimeFormatterUtil.getFormattedDate(value, format); - } - - /** - * get_format implementation for ExprValue. - * - * @param type ExprValue of the type. - * @param format ExprValue of Time/String type - * @return ExprValue.. - */ - private ExprValue exprGetFormat(ExprValue type, ExprValue format) { - if (formats.contains(type.stringValue().toLowerCase(), format.stringValue().toLowerCase())) { - return new ExprStringValue(formats.get( - type.stringValue().toLowerCase(), - format.stringValue().toLowerCase())); - } - - return ExprNullValue.of(); - } - - /** - * Hour implementation for ExprValue. - * - * @param time ExprValue of Time/String type. - * @return ExprValue. - */ - private ExprValue exprHour(ExprValue time) { - return new ExprIntegerValue( - HOURS.between(LocalTime.MIN, time.timeValue())); - } - - /** - * Helper function to retrieve the last day of a month based on a LocalDate argument. - * - * @param today a LocalDate. - * @return a LocalDate associated with the last day of the month for the given input. - */ - private LocalDate getLastDay(LocalDate today) { - return LocalDate.of( - today.getYear(), - today.getMonth(), - today.getMonth().length(today.isLeapYear())); - } - - /** - * Returns a DATE for the last day of the month of a given argument. - * - * @param datetime A DATE/DATETIME/TIMESTAMP/STRING ExprValue. - * @return An DATE value corresponding to the last day of the month of the given argument. - */ - private ExprValue exprLastDay(ExprValue datetime) { - return new ExprDateValue(getLastDay(datetime.dateValue())); - } - - /** - * Returns a DATE for the last day of the current month. - * - * @param clock The clock for the query start time from functionProperties. - * @return An DATE value corresponding to the last day of the month of the given argument. - */ - private ExprValue exprLastDayToday(Clock clock) { - return new ExprDateValue(getLastDay(formatNow(clock).toLocalDate())); - } - - /** - * Following MySQL, function receives arguments of type double and rounds them before use. - * Furthermore: - * - zero year interpreted as 2000 - * - negative year is not accepted - * - @dayOfYear should be greater than 1 - * - if @dayOfYear is greater than 365/366, calculation goes to the next year(s) - * - * @param yearExpr year - * @param dayOfYearExp day of the @year, starting from 1 - * @return Date - ExprDateValue object with LocalDate - */ - private ExprValue exprMakeDate(ExprValue yearExpr, ExprValue dayOfYearExp) { - var year = Math.round(yearExpr.doubleValue()); - var dayOfYear = Math.round(dayOfYearExp.doubleValue()); - // We need to do this to comply with MySQL - if (0 >= dayOfYear || 0 > year) { - return ExprNullValue.of(); - } - if (0 == year) { - year = 2000; - } - return new ExprDateValue(LocalDate.ofYearDay((int)year, 1).plusDays(dayOfYear - 1)); - } - - /** - * Following MySQL, function receives arguments of type double. @hour and @minute are rounded, - * while @second used as is, including fraction part. - * @param hourExpr hour - * @param minuteExpr minute - * @param secondExpr second - * @return Time - ExprTimeValue object with LocalTime - */ - private ExprValue exprMakeTime(ExprValue hourExpr, ExprValue minuteExpr, ExprValue secondExpr) { - var hour = Math.round(hourExpr.doubleValue()); - var minute = Math.round(minuteExpr.doubleValue()); - var second = secondExpr.doubleValue(); - if (0 > hour || 0 > minute || 0 > second) { - return ExprNullValue.of(); - } - return new ExprTimeValue(LocalTime.parse(String.format("%02d:%02d:%012.9f", - hour, minute, second), DateTimeFormatter.ISO_TIME)); - } - - /** - * Microsecond implementation for ExprValue. - * - * @param time ExprValue of Time/String type. - * @return ExprValue. - */ - private ExprValue exprMicrosecond(ExprValue time) { - return new ExprIntegerValue( - TimeUnit.MICROSECONDS.convert(time.timeValue().getNano(), TimeUnit.NANOSECONDS)); - } - - /** - * Minute implementation for ExprValue. - * - * @param time ExprValue of Time/String type. - * @return ExprValue. - */ - private ExprValue exprMinute(ExprValue time) { - return new ExprIntegerValue( - (MINUTES.between(LocalTime.MIN, time.timeValue()) % 60)); - } - - /** - * Minute_of_day implementation for ExprValue. - * - * @param time ExprValue of Time/String type. - * @return ExprValue. - */ - private ExprValue exprMinuteOfDay(ExprValue time) { - return new ExprIntegerValue( - MINUTES.between(LocalTime.MIN, time.timeValue())); - } - - /** - * Month for date implementation for ExprValue. - * - * @param date ExprValue of Date/String type. - * @return ExprValue. - */ - private ExprValue exprMonth(ExprValue date) { - return new ExprIntegerValue(date.dateValue().getMonthValue()); - } - - /** - * Name of the Month implementation for ExprValue. - * - * @param date ExprValue of Date/String type. - * @return ExprValue. - */ - private ExprValue exprMonthName(ExprValue date) { - return new ExprStringValue( - date.dateValue().getMonth().getDisplayName(TextStyle.FULL, Locale.getDefault())); - } - - private LocalDate parseDatePeriod(Integer period) { - var input = period.toString(); - // MySQL undocumented: if year is not specified or has 1 digit - 2000/200x is assumed - if (input.length() <= 5) { - input = String.format("200%05d", period); - } - try { - return LocalDate.parse(input, DATE_FORMATTER_SHORT_YEAR); - } catch (DateTimeParseException ignored) { - // nothing to do, try another format - } - try { - return LocalDate.parse(input, DATE_FORMATTER_LONG_YEAR); - } catch (DateTimeParseException ignored) { - return null; - } - } - - /** - * Adds N months to period P (in the format YYMM or YYYYMM). - * Returns a value in the format YYYYMM. - * - * @param period Period in the format YYMM or YYYYMM. - * @param months Amount of months to add. - * @return ExprIntegerValue. - */ - private ExprValue exprPeriodAdd(ExprValue period, ExprValue months) { - // We should add a day to make string parsable and remove it afterwards - var input = period.integerValue() * 100 + 1; // adds 01 to end of the string - var parsedDate = parseDatePeriod(input); - if (parsedDate == null) { - return ExprNullValue.of(); - } - var res = DATE_FORMATTER_LONG_YEAR.format(parsedDate.plusMonths(months.integerValue())); - return new ExprIntegerValue(Integer.parseInt( - res.substring(0, res.length() - 2))); // Remove the day part, .eg. 20070101 -> 200701 - } - - /** - * Returns the number of months between periods P1 and P2. - * P1 and P2 should be in the format YYMM or YYYYMM. - * - * @param period1 Period in the format YYMM or YYYYMM. - * @param period2 Period in the format YYMM or YYYYMM. - * @return ExprIntegerValue. - */ - private ExprValue exprPeriodDiff(ExprValue period1, ExprValue period2) { - var parsedDate1 = parseDatePeriod(period1.integerValue() * 100 + 1); - var parsedDate2 = parseDatePeriod(period2.integerValue() * 100 + 1); - if (parsedDate1 == null || parsedDate2 == null) { - return ExprNullValue.of(); - } - return new ExprIntegerValue(MONTHS.between(parsedDate2, parsedDate1)); - } - - /** - * Quarter for date implementation for ExprValue. - * - * @param date ExprValue of Date/String type. - * @return ExprValue. - */ - private ExprValue exprQuarter(ExprValue date) { - int month = date.dateValue().getMonthValue(); - return new ExprIntegerValue((month / 3) + ((month % 3) == 0 ? 0 : 1)); - } - - /** - * Returns TIME value of sec_to_time function for an INTEGER or LONG arguments. - * @param totalSeconds The total number of seconds - * @return A TIME value - */ - private ExprValue exprSecToTime(ExprValue totalSeconds) { - return new ExprTimeValue(LocalTime.MIN.plus(Duration.ofSeconds(totalSeconds.longValue()))); - } - - /** - * Helper function which obtains the decimal portion of the seconds value passed in. - * Uses BigDecimal to prevent issues with math on floating point numbers. - * Return is formatted to be used with Duration.ofSeconds(); - * - * @param seconds and ExprDoubleValue or ExprFloatValue for the seconds - * @return A LONG representing the nanoseconds portion - */ - private long formatNanos(ExprValue seconds) { - //Convert ExprValue to BigDecimal - BigDecimal formattedNanos = BigDecimal.valueOf(seconds.doubleValue()); - //Extract only the nanosecond part - formattedNanos = formattedNanos.subtract(BigDecimal.valueOf(formattedNanos.intValue())); - - return formattedNanos.scaleByPowerOfTen(9).longValue(); - } - - /** - * Returns TIME value of sec_to_time function for FLOAT or DOUBLE arguments. - * @param totalSeconds The total number of seconds - * @return A TIME value - */ - private ExprValue exprSecToTimeWithNanos(ExprValue totalSeconds) { - long nanos = formatNanos(totalSeconds); - - return new ExprTimeValue( - LocalTime.MIN.plus(Duration.ofSeconds(totalSeconds.longValue(), nanos))); - } - - /** - * Second implementation for ExprValue. - * - * @param time ExprValue of Time/String type. - * @return ExprValue. - */ - private ExprValue exprSecond(ExprValue time) { - return new ExprIntegerValue( - (SECONDS.between(LocalTime.MIN, time.timeValue()) % 60)); - } - - /** - * SUBDATE function implementation for ExprValue. - * - * @param functionProperties An FunctionProperties object. - * @param date ExprValue of Time/Date/Datetime/Timestamp type. - * @param days ExprValue of Long type, representing the number of days to subtract. - * @return Date/Datetime resulted from days subtracted to date. - */ - private ExprValue exprSubDateDays(FunctionProperties functionProperties, - ExprValue date, ExprValue days) { - return exprDateApplyDays(functionProperties, date, days.longValue(), false); - } - - /** - * DATE_SUB function implementation for ExprValue. - * - * @param functionProperties An FunctionProperties object. - * @param datetime ExprValue of Time/Date/Datetime/Timestamp type. - * @param expr ExprValue of Interval type, the temporal amount to subtract. - * @return Datetime resulted from expr subtracted to `datetime`. - */ - private ExprValue exprSubDateInterval(FunctionProperties functionProperties, - ExprValue datetime, ExprValue expr) { - return exprDateApplyInterval(functionProperties, datetime, expr.intervalValue(), false); - } - - /** - * Subtracts expr2 from expr1 and returns the result. - * - * @param temporal A Date/Time/Datetime/Timestamp value to change. - * @param temporalDelta A Date/Time/Datetime/Timestamp to subtract time from. - * @return A value calculated. - */ - private ExprValue exprSubTime(FunctionProperties functionProperties, - ExprValue temporal, ExprValue temporalDelta) { - return exprApplyTime(functionProperties, temporal, temporalDelta, false); - } - - private ExprValue exprStrToDate(FunctionProperties fp, - ExprValue dateTimeExpr, - ExprValue formatStringExp) { - return DateTimeFormatterUtil.parseStringWithDateOrTime(fp, dateTimeExpr, formatStringExp); - } - - /** - * Time implementation for ExprValue. - * - * @param exprValue ExprValue of Time type or String. - * @return ExprValue. - */ - private ExprValue exprTime(ExprValue exprValue) { - if (exprValue instanceof ExprStringValue) { - return new ExprTimeValue(exprValue.stringValue()); - } else { - return new ExprTimeValue(exprValue.timeValue()); - } - } - - /** - * Calculate the time difference between two times. - * - * @param first The first value. - * @param second The second value. - * @return The diff. - */ - private ExprValue exprTimeDiff(ExprValue first, ExprValue second) { - // java inverses the value, so we have to swap 1 and 2 - return new ExprTimeValue(LocalTime.MIN.plus( - Duration.between(second.timeValue(), first.timeValue()))); - } - - /** - * Time To Sec implementation for ExprValue. - * - * @param time ExprValue of Time/String type. - * @return ExprValue. - */ - private ExprValue exprTimeToSec(ExprValue time) { - return new ExprLongValue(time.timeValue().toSecondOfDay()); - } - - private ExprValue exprTimestampAdd(ExprValue partExpr, - ExprValue amountExpr, - ExprValue datetimeExpr) { - String part = partExpr.stringValue(); - int amount = amountExpr.integerValue(); - LocalDateTime datetime = datetimeExpr.datetimeValue(); - ChronoUnit temporalUnit; - - switch (part) { - case "MICROSECOND": - temporalUnit = MICROS; - break; - case "SECOND": - temporalUnit = SECONDS; - break; - case "MINUTE": - temporalUnit = MINUTES; - break; - case "HOUR": - temporalUnit = HOURS; - break; - case "DAY": - temporalUnit = DAYS; - break; - case "WEEK": - temporalUnit = WEEKS; - break; - case "MONTH": - temporalUnit = MONTHS; - break; - case "QUARTER": - temporalUnit = MONTHS; - amount *= 3; - break; - case "YEAR": - temporalUnit = YEARS; - break; - default: - return ExprNullValue.of(); } - return new ExprDatetimeValue(datetime.plus(amount, temporalUnit)); - } - - private ExprValue exprTimestampAddForTimeType(Clock clock, - ExprValue partExpr, - ExprValue amountExpr, - ExprValue timeExpr) { - LocalDateTime datetime = LocalDateTime.of( - formatNow(clock).toLocalDate(), - timeExpr.timeValue()); - return exprTimestampAdd(partExpr, amountExpr, new ExprDatetimeValue(datetime)); - } - - private ExprValue getTimeDifference(String part, LocalDateTime startTime, LocalDateTime endTime) { - long returnVal; - switch (part) { - case "MICROSECOND": - returnVal = MICROS.between(startTime, endTime); - break; - case "SECOND": - returnVal = SECONDS.between(startTime, endTime); - break; - case "MINUTE": - returnVal = MINUTES.between(startTime, endTime); - break; - case "HOUR": - returnVal = HOURS.between(startTime, endTime); - break; - case "DAY": - returnVal = DAYS.between(startTime, endTime); - break; - case "WEEK": - returnVal = WEEKS.between(startTime, endTime); - break; - case "MONTH": - returnVal = MONTHS.between(startTime, endTime); - break; - case "QUARTER": - returnVal = MONTHS.between(startTime, endTime) / 3; - break; - case "YEAR": - returnVal = YEARS.between(startTime, endTime); - break; - default: - return ExprNullValue.of(); + + private FunctionResolver now() { + return now(BuiltinFunctionName.NOW.getName()); + } + + private FunctionResolver current_timestamp() { + return now(BuiltinFunctionName.CURRENT_TIMESTAMP.getName()); + } + + private FunctionResolver localtimestamp() { + return now(BuiltinFunctionName.LOCALTIMESTAMP.getName()); + } + + private FunctionResolver localtime() { + return now(BuiltinFunctionName.LOCALTIME.getName()); + } + + /** + * SYSDATE() returns the time at which it executes. + */ + private FunctionResolver sysdate() { + return define( + BuiltinFunctionName.SYSDATE.getName(), + implWithProperties(functionProperties -> new ExprDatetimeValue(formatNow(Clock.systemDefaultZone())), DATETIME), + FunctionDSL.implWithProperties( + (functionProperties, v) -> new ExprDatetimeValue(formatNow(Clock.systemDefaultZone(), v.integerValue())), + DATETIME, + INTEGER + ) + ); + } + + /** + * Synonym for @see `now`. + */ + private FunctionResolver curtime(FunctionName functionName) { + return define( + functionName, + implWithProperties( + functionProperties -> new ExprTimeValue(formatNow(functionProperties.getQueryStartClock()).toLocalTime()), + TIME + ) + ); + } + + private FunctionResolver curtime() { + return curtime(BuiltinFunctionName.CURTIME.getName()); + } + + private FunctionResolver current_time() { + return curtime(BuiltinFunctionName.CURRENT_TIME.getName()); + } + + private FunctionResolver curdate(FunctionName functionName) { + return define( + functionName, + implWithProperties( + functionProperties -> new ExprDateValue(formatNow(functionProperties.getQueryStartClock()).toLocalDate()), + DATE + ) + ); + } + + private FunctionResolver curdate() { + return curdate(BuiltinFunctionName.CURDATE.getName()); + } + + private FunctionResolver current_date() { + return curdate(BuiltinFunctionName.CURRENT_DATE.getName()); + } + + /** + * A common signature for `date_add` and `date_sub`. + * Specify a start date and add/subtract a temporal amount to/from the date. + * The return type depends on the date type and the interval unit. Detailed supported signatures: + * (DATE/DATETIME/TIMESTAMP/TIME, INTERVAL) -> DATETIME + * MySQL has these signatures too + * (DATE, INTERVAL) -> DATE // when interval has no time part + * (TIME, INTERVAL) -> TIME // when interval has no date part + * (STRING, INTERVAL) -> STRING // when argument has date or datetime string, + * // result has date or datetime depending on interval type + */ + private Stream> get_date_add_date_sub_signatures( + SerializableTriFunction function + ) { + return Stream.of( + implWithProperties(nullMissingHandlingWithProperties(function), DATETIME, DATE, INTERVAL), + implWithProperties(nullMissingHandlingWithProperties(function), DATETIME, DATETIME, INTERVAL), + implWithProperties(nullMissingHandlingWithProperties(function), DATETIME, TIMESTAMP, INTERVAL), + implWithProperties(nullMissingHandlingWithProperties(function), DATETIME, TIME, INTERVAL) + ); + } + + /** + * A common signature for `adddate` and `subdate`. + * Adds/subtracts an integer number of days to/from the first argument. + * (DATE, LONG) -> DATE + * (TIME/DATETIME/TIMESTAMP, LONG) -> DATETIME + */ + private Stream> get_adddate_subdate_signatures( + SerializableTriFunction function + ) { + return Stream.of( + implWithProperties(nullMissingHandlingWithProperties(function), DATE, DATE, LONG), + implWithProperties(nullMissingHandlingWithProperties(function), DATETIME, DATETIME, LONG), + implWithProperties(nullMissingHandlingWithProperties(function), DATETIME, TIMESTAMP, LONG), + implWithProperties(nullMissingHandlingWithProperties(function), DATETIME, TIME, LONG) + ); + } + + private DefaultFunctionResolver adddate() { + return define( + BuiltinFunctionName.ADDDATE.getName(), + (SerializableFunction>[]) (Stream.concat( + get_date_add_date_sub_signatures(DateTimeFunction::exprAddDateInterval), + get_adddate_subdate_signatures(DateTimeFunction::exprAddDateDays) + ).toArray(SerializableFunction[]::new)) + ); + } + + /** + * Adds expr2 to expr1 and returns the result. + * (TIME, TIME/DATE/DATETIME/TIMESTAMP) -> TIME + * (DATE/DATETIME/TIMESTAMP, TIME/DATE/DATETIME/TIMESTAMP) -> DATETIME + * TODO: MySQL has these signatures too + * (STRING, STRING/TIME) -> STRING // second arg - string with time only + * (x, STRING) -> NULL // second arg - string with timestamp + * (x, STRING/DATE) -> x // second arg - string with date only + */ + private DefaultFunctionResolver addtime() { + return define( + BuiltinFunctionName.ADDTIME.getName(), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), TIME, TIME, TIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), TIME, TIME, DATE), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), TIME, TIME, DATETIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), TIME, TIME, TIMESTAMP), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), DATETIME, DATETIME, TIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), DATETIME, DATETIME, DATE), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), DATETIME, DATETIME, DATETIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), DATETIME, DATETIME, TIMESTAMP), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), DATETIME, DATE, TIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), DATETIME, DATE, DATE), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), DATETIME, DATE, DATETIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), DATETIME, DATE, TIMESTAMP), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), DATETIME, TIMESTAMP, TIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), DATETIME, TIMESTAMP, DATE), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), DATETIME, TIMESTAMP, DATETIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprAddTime), DATETIME, TIMESTAMP, TIMESTAMP) + ); + } + + /** + * Converts date/time from a specified timezone to another specified timezone. + * The supported signatures: + * (DATETIME, STRING, STRING) -> DATETIME + * (STRING, STRING, STRING) -> DATETIME + */ + private DefaultFunctionResolver convert_tz() { + return define( + BuiltinFunctionName.CONVERT_TZ.getName(), + impl(nullMissingHandling(DateTimeFunction::exprConvertTZ), DATETIME, DATETIME, STRING, STRING), + impl(nullMissingHandling(DateTimeFunction::exprConvertTZ), DATETIME, STRING, STRING, STRING) + ); + } + + /** + * Extracts the date part of a date and time value. + * Also to construct a date type. The supported signatures: + * STRING/DATE/DATETIME/TIMESTAMP -> DATE + */ + private DefaultFunctionResolver date() { + return define( + BuiltinFunctionName.DATE.getName(), + impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, STRING), + impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, DATE), + impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, TIMESTAMP) + ); + } + + /* + * Calculates the difference of date part of given values. + * (DATE/DATETIME/TIMESTAMP/TIME, DATE/DATETIME/TIMESTAMP/TIME) -> LONG + */ + private DefaultFunctionResolver datediff() { + return define( + BuiltinFunctionName.DATEDIFF.getName(), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, DATE, DATE), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, DATETIME, DATE), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, DATE, DATETIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, DATETIME, DATETIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, DATE, TIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, TIME, DATE), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, TIME, TIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, TIMESTAMP, DATE), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, DATE, TIMESTAMP), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, TIMESTAMP, TIMESTAMP), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, TIMESTAMP, TIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, TIME, TIMESTAMP), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, TIMESTAMP, DATETIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, DATETIME, TIMESTAMP), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, TIME, DATETIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprDateDiff), LONG, DATETIME, TIME) + ); + } + + /** + * Specify a datetime with time zone field and a time zone to convert to. + * Returns a local date time. + * (STRING, STRING) -> DATETIME + * (STRING) -> DATETIME + */ + private FunctionResolver datetime() { + return define( + BuiltinFunctionName.DATETIME.getName(), + impl(nullMissingHandling(DateTimeFunction::exprDateTime), DATETIME, STRING, STRING), + impl(nullMissingHandling(DateTimeFunction::exprDateTimeNoTimezone), DATETIME, STRING) + ); + } + + private DefaultFunctionResolver date_add() { + return define( + BuiltinFunctionName.DATE_ADD.getName(), + (SerializableFunction>[]) get_date_add_date_sub_signatures( + DateTimeFunction::exprAddDateInterval + ).toArray(SerializableFunction[]::new) + ); + } + + private DefaultFunctionResolver date_sub() { + return define( + BuiltinFunctionName.DATE_SUB.getName(), + (SerializableFunction>[]) get_date_add_date_sub_signatures( + DateTimeFunction::exprSubDateInterval + ).toArray(SerializableFunction[]::new) + ); + } + + /** + * DAY(STRING/DATE/DATETIME/TIMESTAMP). return the day of the month (1-31). + */ + private DefaultFunctionResolver day() { + return define( + BuiltinFunctionName.DAY.getName(), + impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, STRING) + ); + } + + /** + * DAYNAME(STRING/DATE/DATETIME/TIMESTAMP). + * return the name of the weekday for date, including Monday, Tuesday, Wednesday, + * Thursday, Friday, Saturday and Sunday. + */ + private DefaultFunctionResolver dayName() { + return define( + BuiltinFunctionName.DAYNAME.getName(), + impl(nullMissingHandling(DateTimeFunction::exprDayName), STRING, DATE), + impl(nullMissingHandling(DateTimeFunction::exprDayName), STRING, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprDayName), STRING, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprDayName), STRING, STRING) + ); + } + + /** + * DAYOFMONTH(STRING/DATE/DATETIME/TIMESTAMP). return the day of the month (1-31). + */ + private DefaultFunctionResolver dayOfMonth(BuiltinFunctionName name) { + return define( + name.getName(), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, arg) -> DateTimeFunction.dayOfMonthToday(functionProperties.getQueryStartClock()) + ), + INTEGER, + TIME + ), + impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, STRING), + impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, TIMESTAMP) + ); + } + + /** + * DAYOFWEEK(STRING/DATE/DATETIME/TIME/TIMESTAMP). + * return the weekday index for date (1 = Sunday, 2 = Monday, ..., 7 = Saturday). + */ + private DefaultFunctionResolver dayOfWeek(FunctionName name) { + return define( + name, + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, arg) -> DateTimeFunction.dayOfWeekToday(functionProperties.getQueryStartClock()) + ), + INTEGER, + TIME + ), + impl(nullMissingHandling(DateTimeFunction::exprDayOfWeek), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprDayOfWeek), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprDayOfWeek), INTEGER, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprDayOfWeek), INTEGER, STRING) + ); + } + + /** + * DAYOFYEAR(STRING/DATE/DATETIME/TIMESTAMP). + * return the day of the year for date (1-366). + */ + private DefaultFunctionResolver dayOfYear(BuiltinFunctionName dayOfYear) { + return define( + dayOfYear.getName(), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, arg) -> DateTimeFunction.dayOfYearToday(functionProperties.getQueryStartClock()) + ), + INTEGER, + TIME + ), + impl(nullMissingHandling(DateTimeFunction::exprDayOfYear), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprDayOfYear), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprDayOfYear), INTEGER, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprDayOfYear), INTEGER, STRING) + ); + } + + private DefaultFunctionResolver extract() { + return define( + BuiltinFunctionName.EXTRACT.getName(), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprExtractForTime), LONG, STRING, TIME), + impl(nullMissingHandling(DateTimeFunction::exprExtract), LONG, STRING, DATE), + impl(nullMissingHandling(DateTimeFunction::exprExtract), LONG, STRING, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprExtract), LONG, STRING, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprExtract), LONG, STRING, STRING) + ); + } + + /** + * FROM_DAYS(LONG). return the date value given the day number N. + */ + private DefaultFunctionResolver from_days() { + return define(BuiltinFunctionName.FROM_DAYS.getName(), impl(nullMissingHandling(DateTimeFunction::exprFromDays), DATE, LONG)); + } + + private FunctionResolver from_unixtime() { + return define( + BuiltinFunctionName.FROM_UNIXTIME.getName(), + impl(nullMissingHandling(DateTimeFunction::exprFromUnixTime), DATETIME, DOUBLE), + impl(nullMissingHandling(DateTimeFunction::exprFromUnixTimeFormat), STRING, DOUBLE, STRING) + ); + } + + private DefaultFunctionResolver get_format() { + return define( + BuiltinFunctionName.GET_FORMAT.getName(), + impl(nullMissingHandling(DateTimeFunction::exprGetFormat), STRING, STRING, STRING) + ); + } + + /** + * HOUR(STRING/TIME/DATETIME/DATE/TIMESTAMP). return the hour value for time. + */ + private DefaultFunctionResolver hour(BuiltinFunctionName name) { + return define( + name.getName(), + impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, STRING), + impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, TIME), + impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, TIMESTAMP) + ); + } + + private DefaultFunctionResolver last_day() { + return define( + BuiltinFunctionName.LAST_DAY.getName(), + impl(nullMissingHandling(DateTimeFunction::exprLastDay), DATE, STRING), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, arg) -> DateTimeFunction.exprLastDayToday(functionProperties.getQueryStartClock()) + ), + DATE, + TIME + ), + impl(nullMissingHandling(DateTimeFunction::exprLastDay), DATE, DATE), + impl(nullMissingHandling(DateTimeFunction::exprLastDay), DATE, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprLastDay), DATE, TIMESTAMP) + ); + } + + private FunctionResolver makedate() { + return define( + BuiltinFunctionName.MAKEDATE.getName(), + impl(nullMissingHandling(DateTimeFunction::exprMakeDate), DATE, DOUBLE, DOUBLE) + ); + } + + private FunctionResolver maketime() { + return define( + BuiltinFunctionName.MAKETIME.getName(), + impl(nullMissingHandling(DateTimeFunction::exprMakeTime), TIME, DOUBLE, DOUBLE, DOUBLE) + ); + } + + /** + * MICROSECOND(STRING/TIME/DATETIME/TIMESTAMP). return the microsecond value for time. + */ + private DefaultFunctionResolver microsecond() { + return define( + BuiltinFunctionName.MICROSECOND.getName(), + impl(nullMissingHandling(DateTimeFunction::exprMicrosecond), INTEGER, STRING), + impl(nullMissingHandling(DateTimeFunction::exprMicrosecond), INTEGER, TIME), + impl(nullMissingHandling(DateTimeFunction::exprMicrosecond), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprMicrosecond), INTEGER, TIMESTAMP) + ); + } + + /** + * MINUTE(STRING/TIME/DATETIME/TIMESTAMP). return the minute value for time. + */ + private DefaultFunctionResolver minute(BuiltinFunctionName name) { + return define( + name.getName(), + impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, STRING), + impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, TIME), + impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, TIMESTAMP) + ); + } + + /** + * MINUTE(STRING/TIME/DATETIME/TIMESTAMP). return the minute value for time. + */ + private DefaultFunctionResolver minute_of_day() { + return define( + BuiltinFunctionName.MINUTE_OF_DAY.getName(), + impl(nullMissingHandling(DateTimeFunction::exprMinuteOfDay), INTEGER, STRING), + impl(nullMissingHandling(DateTimeFunction::exprMinuteOfDay), INTEGER, TIME), + impl(nullMissingHandling(DateTimeFunction::exprMinuteOfDay), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprMinuteOfDay), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprMinuteOfDay), INTEGER, TIMESTAMP) + ); + } + + /** + * MONTH(STRING/DATE/DATETIME/TIMESTAMP). return the month for date (1-12). + */ + private DefaultFunctionResolver month(BuiltinFunctionName month) { + return define( + month.getName(), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, arg) -> DateTimeFunction.monthOfYearToday(functionProperties.getQueryStartClock()) + ), + INTEGER, + TIME + ), + impl(nullMissingHandling(DateTimeFunction::exprMonth), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprMonth), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprMonth), INTEGER, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprMonth), INTEGER, STRING) + ); + } + + /** + * MONTHNAME(STRING/DATE/DATETIME/TIMESTAMP). return the full name of the month for date. + */ + private DefaultFunctionResolver monthName() { + return define( + BuiltinFunctionName.MONTHNAME.getName(), + impl(nullMissingHandling(DateTimeFunction::exprMonthName), STRING, DATE), + impl(nullMissingHandling(DateTimeFunction::exprMonthName), STRING, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprMonthName), STRING, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprMonthName), STRING, STRING) + ); + } + + /** + * Add N months to period P (in the format YYMM or YYYYMM). Returns a value in the format YYYYMM. + * (INTEGER, INTEGER) -> INTEGER + */ + private DefaultFunctionResolver period_add() { + return define( + BuiltinFunctionName.PERIOD_ADD.getName(), + impl(nullMissingHandling(DateTimeFunction::exprPeriodAdd), INTEGER, INTEGER, INTEGER) + ); + } + + /** + * Returns the number of months between periods P1 and P2. + * P1 and P2 should be in the format YYMM or YYYYMM. + * (INTEGER, INTEGER) -> INTEGER + */ + private DefaultFunctionResolver period_diff() { + return define( + BuiltinFunctionName.PERIOD_DIFF.getName(), + impl(nullMissingHandling(DateTimeFunction::exprPeriodDiff), INTEGER, INTEGER, INTEGER) + ); + } + + /** + * QUARTER(STRING/DATE/DATETIME/TIMESTAMP). return the month for date (1-4). + */ + private DefaultFunctionResolver quarter() { + return define( + BuiltinFunctionName.QUARTER.getName(), + impl(nullMissingHandling(DateTimeFunction::exprQuarter), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprQuarter), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprQuarter), INTEGER, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprQuarter), INTEGER, STRING) + ); + } + + private DefaultFunctionResolver sec_to_time() { + return define( + BuiltinFunctionName.SEC_TO_TIME.getName(), + impl((nullMissingHandling(DateTimeFunction::exprSecToTime)), TIME, INTEGER), + impl((nullMissingHandling(DateTimeFunction::exprSecToTime)), TIME, LONG), + impl((nullMissingHandling(DateTimeFunction::exprSecToTimeWithNanos)), TIME, DOUBLE), + impl((nullMissingHandling(DateTimeFunction::exprSecToTimeWithNanos)), TIME, FLOAT) + ); + } + + /** + * SECOND(STRING/TIME/DATETIME/TIMESTAMP). return the second value for time. + */ + private DefaultFunctionResolver second(BuiltinFunctionName name) { + return define( + name.getName(), + impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, STRING), + impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, TIME), + impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, TIMESTAMP) + ); + } + + private DefaultFunctionResolver subdate() { + return define( + BuiltinFunctionName.SUBDATE.getName(), + (SerializableFunction>[]) (Stream.concat( + get_date_add_date_sub_signatures(DateTimeFunction::exprSubDateInterval), + get_adddate_subdate_signatures(DateTimeFunction::exprSubDateDays) + ).toArray(SerializableFunction[]::new)) + ); + } + + /** + * Subtracts expr2 from expr1 and returns the result. + * (TIME, TIME/DATE/DATETIME/TIMESTAMP) -> TIME + * (DATE/DATETIME/TIMESTAMP, TIME/DATE/DATETIME/TIMESTAMP) -> DATETIME + * TODO: MySQL has these signatures too + * (STRING, STRING/TIME) -> STRING // second arg - string with time only + * (x, STRING) -> NULL // second arg - string with timestamp + * (x, STRING/DATE) -> x // second arg - string with date only + */ + private DefaultFunctionResolver subtime() { + return define( + BuiltinFunctionName.SUBTIME.getName(), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), TIME, TIME, TIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), TIME, TIME, DATE), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), TIME, TIME, DATETIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), TIME, TIME, TIMESTAMP), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), DATETIME, DATETIME, TIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), DATETIME, DATETIME, DATE), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), DATETIME, DATETIME, DATETIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), DATETIME, DATETIME, TIMESTAMP), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), DATETIME, DATE, TIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), DATETIME, DATE, DATE), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), DATETIME, DATE, DATETIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), DATETIME, DATE, TIMESTAMP), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), DATETIME, TIMESTAMP, TIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), DATETIME, TIMESTAMP, DATE), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), DATETIME, TIMESTAMP, DATETIME), + implWithProperties(nullMissingHandlingWithProperties(DateTimeFunction::exprSubTime), DATETIME, TIMESTAMP, TIMESTAMP) + ); + } + + /** + * Extracts a date, time, or datetime from the given string. + * It accomplishes this using another string which specifies the input format. + */ + private DefaultFunctionResolver str_to_date() { + return define( + BuiltinFunctionName.STR_TO_DATE.getName(), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, arg, format) -> DateTimeFunction.exprStrToDate(functionProperties, arg, format) + ), + DATETIME, + STRING, + STRING + ) + ); + } + + /** + * Extracts the time part of a date and time value. + * Also to construct a time type. The supported signatures: + * STRING/DATE/DATETIME/TIME/TIMESTAMP -> TIME + */ + private DefaultFunctionResolver time() { + return define( + BuiltinFunctionName.TIME.getName(), + impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, STRING), + impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, DATE), + impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, TIME), + impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, TIMESTAMP) + ); + } + + /** + * Returns different between two times as a time. + * (TIME, TIME) -> TIME + * MySQL has these signatures too + * (DATE, DATE) -> TIME // result is > 24 hours + * (DATETIME, DATETIME) -> TIME // result is > 24 hours + * (TIMESTAMP, TIMESTAMP) -> TIME // result is > 24 hours + * (x, x) -> NULL // when args have different types + * (STRING, STRING) -> TIME // argument strings contain same types only + * (STRING, STRING) -> NULL // argument strings are different types + */ + private DefaultFunctionResolver timediff() { + return define(BuiltinFunctionName.TIMEDIFF.getName(), impl(nullMissingHandling(DateTimeFunction::exprTimeDiff), TIME, TIME, TIME)); + } + + /** + * TIME_TO_SEC(STRING/TIME/DATETIME/TIMESTAMP). return the time argument, converted to seconds. + */ + private DefaultFunctionResolver time_to_sec() { + return define( + BuiltinFunctionName.TIME_TO_SEC.getName(), + impl(nullMissingHandling(DateTimeFunction::exprTimeToSec), LONG, STRING), + impl(nullMissingHandling(DateTimeFunction::exprTimeToSec), LONG, TIME), + impl(nullMissingHandling(DateTimeFunction::exprTimeToSec), LONG, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprTimeToSec), LONG, DATETIME) + ); + } + + /** + * Extracts the timestamp of a date and time value. + * Input strings may contain a timestamp only in format 'yyyy-MM-dd HH:mm:ss[.SSSSSSSSS]' + * STRING/DATE/TIME/DATETIME/TIMESTAMP -> TIMESTAMP + * STRING/DATE/TIME/DATETIME/TIMESTAMP, STRING/DATE/TIME/DATETIME/TIMESTAMP -> TIMESTAMP + * All types are converted to TIMESTAMP actually before the function call - it is responsibility + * of the automatic cast mechanism defined in `ExprCoreType` and performed by `TypeCastOperator`. + */ + private DefaultFunctionResolver timestamp() { + return define( + BuiltinFunctionName.TIMESTAMP.getName(), + impl(nullMissingHandling(v -> v), TIMESTAMP, TIMESTAMP), + // We can use FunctionProperties.None, because it is not used. It is required to convert + // TIME to other datetime types, but arguments there are already converted. + impl(nullMissingHandling((v1, v2) -> exprAddTime(FunctionProperties.None, v1, v2)), TIMESTAMP, TIMESTAMP, TIMESTAMP) + ); + } + + /** + * Adds an interval of time to the provided DATE/DATETIME/TIME/TIMESTAMP/STRING argument. + * The interval of time added is determined by the given first and second arguments. + * The first argument is an interval type, and must be one of the tokens below... + * [MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR] + * The second argument is the amount of the interval type to be added. + * The third argument is the DATE/DATETIME/TIME/TIMESTAMP/STRING to add to. + * @return The DATETIME representing the summed DATE/DATETIME/TIME/TIMESTAMP and interval. + */ + private DefaultFunctionResolver timestampadd() { + return define( + BuiltinFunctionName.TIMESTAMPADD.getName(), + impl(nullMissingHandling(DateTimeFunction::exprTimestampAdd), DATETIME, STRING, INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprTimestampAdd), DATETIME, STRING, INTEGER, TIMESTAMP), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, part, amount, time) -> exprTimestampAddForTimeType( + functionProperties.getQueryStartClock(), + part, + amount, + time + ) + ), + DATETIME, + STRING, + INTEGER, + TIME + ) + ); + } + + /** + * Finds the difference between provided DATE/DATETIME/TIME/TIMESTAMP/STRING arguments. + * The first argument is an interval type, and must be one of the tokens below... + * [MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR] + * The second argument the DATE/DATETIME/TIME/TIMESTAMP/STRING representing the start time. + * The third argument is the DATE/DATETIME/TIME/TIMESTAMP/STRING representing the end time. + * @return A LONG representing the difference between arguments, using the given interval type. + */ + private DefaultFunctionResolver timestampdiff() { + return define( + BuiltinFunctionName.TIMESTAMPDIFF.getName(), + impl(nullMissingHandling(DateTimeFunction::exprTimestampDiff), DATETIME, STRING, DATETIME, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprTimestampDiff), DATETIME, STRING, DATETIME, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprTimestampDiff), DATETIME, STRING, TIMESTAMP, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprTimestampDiff), DATETIME, STRING, TIMESTAMP, TIMESTAMP), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, part, startTime, endTime) -> exprTimestampDiffForTimeType( + functionProperties, + part, + startTime, + endTime + ) + ), + DATETIME, + STRING, + TIME, + TIME + ) + ); + } + + /** + * TO_DAYS(STRING/DATE/DATETIME/TIMESTAMP). return the day number of the given date. + */ + private DefaultFunctionResolver to_days() { + return define( + BuiltinFunctionName.TO_DAYS.getName(), + impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, STRING), + impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, DATE), + impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, DATETIME) + ); + } + + /** + * TO_SECONDS(TIMESTAMP/LONG). return the seconds number of the given date. + * Arguments of type STRING/TIMESTAMP/LONG are also accepted. + * STRING/TIMESTAMP/LONG arguments are automatically cast to TIMESTAMP. + */ + private DefaultFunctionResolver to_seconds() { + return define( + BuiltinFunctionName.TO_SECONDS.getName(), + impl(nullMissingHandling(DateTimeFunction::exprToSeconds), LONG, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprToSecondsForIntType), LONG, LONG) + ); + } + + private FunctionResolver unix_timestamp() { + return define( + BuiltinFunctionName.UNIX_TIMESTAMP.getName(), + implWithProperties(functionProperties -> DateTimeFunction.unixTimeStamp(functionProperties.getQueryStartClock()), LONG), + impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, DATE), + impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, DATETIME), + impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, DOUBLE) + ); + } + + /** + * UTC_DATE(). return the current UTC Date in format yyyy-MM-dd + */ + private DefaultFunctionResolver utc_date() { + return define( + BuiltinFunctionName.UTC_DATE.getName(), + implWithProperties(functionProperties -> exprUtcDate(functionProperties), DATE) + ); + } + + /** + * UTC_TIME(). return the current UTC Time in format HH:mm:ss + */ + private DefaultFunctionResolver utc_time() { + return define( + BuiltinFunctionName.UTC_TIME.getName(), + implWithProperties(functionProperties -> exprUtcTime(functionProperties), TIME) + ); + } + + /** + * UTC_TIMESTAMP(). return the current UTC TimeStamp in format yyyy-MM-dd HH:mm:ss + */ + private DefaultFunctionResolver utc_timestamp() { + return define( + BuiltinFunctionName.UTC_TIMESTAMP.getName(), + implWithProperties(functionProperties -> exprUtcTimeStamp(functionProperties), DATETIME) + ); + } + + /** + * WEEK(DATE[,mode]). return the week number for date. + */ + private DefaultFunctionResolver week(BuiltinFunctionName week) { + return define( + week.getName(), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, arg) -> DateTimeFunction.weekOfYearToday( + DEFAULT_WEEK_OF_YEAR_MODE, + functionProperties.getQueryStartClock() + ) + ), + INTEGER, + TIME + ), + impl(nullMissingHandling(DateTimeFunction::exprWeekWithoutMode), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprWeekWithoutMode), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprWeekWithoutMode), INTEGER, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprWeekWithoutMode), INTEGER, STRING), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, time, modeArg) -> DateTimeFunction.weekOfYearToday( + modeArg, + functionProperties.getQueryStartClock() + ) + ), + INTEGER, + TIME, + INTEGER + ), + impl(nullMissingHandling(DateTimeFunction::exprWeek), INTEGER, DATE, INTEGER), + impl(nullMissingHandling(DateTimeFunction::exprWeek), INTEGER, DATETIME, INTEGER), + impl(nullMissingHandling(DateTimeFunction::exprWeek), INTEGER, TIMESTAMP, INTEGER), + impl(nullMissingHandling(DateTimeFunction::exprWeek), INTEGER, STRING, INTEGER) + ); + } + + private DefaultFunctionResolver weekday() { + return define( + BuiltinFunctionName.WEEKDAY.getName(), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, arg) -> new ExprIntegerValue( + formatNow(functionProperties.getQueryStartClock()).getDayOfWeek().getValue() - 1 + ) + ), + INTEGER, + TIME + ), + impl(nullMissingHandling(DateTimeFunction::exprWeekday), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprWeekday), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprWeekday), INTEGER, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprWeekday), INTEGER, STRING) + ); + } + + /** + * YEAR(STRING/DATE/DATETIME/TIMESTAMP). return the year for date (1000-9999). + */ + private DefaultFunctionResolver year() { + return define( + BuiltinFunctionName.YEAR.getName(), + impl(nullMissingHandling(DateTimeFunction::exprYear), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprYear), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprYear), INTEGER, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprYear), INTEGER, STRING) + ); + } + + /** + * YEARWEEK(DATE[,mode]). return the week number for date. + */ + private DefaultFunctionResolver yearweek() { + return define( + BuiltinFunctionName.YEARWEEK.getName(), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, arg) -> yearweekToday(DEFAULT_WEEK_OF_YEAR_MODE, functionProperties.getQueryStartClock()) + ), + INTEGER, + TIME + ), + impl(nullMissingHandling(DateTimeFunction::exprYearweekWithoutMode), INTEGER, DATE), + impl(nullMissingHandling(DateTimeFunction::exprYearweekWithoutMode), INTEGER, DATETIME), + impl(nullMissingHandling(DateTimeFunction::exprYearweekWithoutMode), INTEGER, TIMESTAMP), + impl(nullMissingHandling(DateTimeFunction::exprYearweekWithoutMode), INTEGER, STRING), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, time, modeArg) -> yearweekToday(modeArg, functionProperties.getQueryStartClock()) + ), + INTEGER, + TIME, + INTEGER + ), + impl(nullMissingHandling(DateTimeFunction::exprYearweek), INTEGER, DATE, INTEGER), + impl(nullMissingHandling(DateTimeFunction::exprYearweek), INTEGER, DATETIME, INTEGER), + impl(nullMissingHandling(DateTimeFunction::exprYearweek), INTEGER, TIMESTAMP, INTEGER), + impl(nullMissingHandling(DateTimeFunction::exprYearweek), INTEGER, STRING, INTEGER) + ); + } + + /** + * Formats date according to format specifier. First argument is date, second is format. + * Detailed supported signatures: + * (STRING, STRING) -> STRING + * (DATE, STRING) -> STRING + * (DATETIME, STRING) -> STRING + * (TIME, STRING) -> STRING + * (TIMESTAMP, STRING) -> STRING + */ + private DefaultFunctionResolver date_format() { + return define( + BuiltinFunctionName.DATE_FORMAT.getName(), + impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedDate), STRING, STRING, STRING), + impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedDate), STRING, DATE, STRING), + impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedDate), STRING, DATETIME, STRING), + implWithProperties( + nullMissingHandlingWithProperties( + (functionProperties, time, formatString) -> DateTimeFormatterUtil.getFormattedDateOfToday( + formatString, + time, + functionProperties.getQueryStartClock() + ) + ), + STRING, + TIME, + STRING + ), + impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedDate), STRING, TIMESTAMP, STRING) + ); + } + + private ExprValue dayOfMonthToday(Clock clock) { + return new ExprIntegerValue(LocalDateTime.now(clock).getDayOfMonth()); + } + + private ExprValue dayOfYearToday(Clock clock) { + return new ExprIntegerValue(LocalDateTime.now(clock).getDayOfYear()); + } + + private ExprValue weekOfYearToday(ExprValue mode, Clock clock) { + return new ExprIntegerValue(CalendarLookup.getWeekNumber(mode.integerValue(), LocalDateTime.now(clock).toLocalDate())); + } + + /** + * Day of Week implementation for ExprValue when passing in an arguemt of type TIME. + * + * @param clock Current clock taken from function properties + * @return ExprValue. + */ + private ExprValue dayOfWeekToday(Clock clock) { + return new ExprIntegerValue((formatNow(clock).getDayOfWeek().getValue() % 7) + 1); } - return new ExprLongValue(returnVal); - } - - private ExprValue exprTimestampDiff( - ExprValue partExpr, - ExprValue startTimeExpr, - ExprValue endTimeExpr) { - return getTimeDifference( - partExpr.stringValue(), - startTimeExpr.datetimeValue(), - endTimeExpr.datetimeValue()); - } - - private ExprValue exprTimestampDiffForTimeType(FunctionProperties fp, - ExprValue partExpr, - ExprValue startTimeExpr, - ExprValue endTimeExpr) { - return getTimeDifference( - partExpr.stringValue(), - extractDateTime(startTimeExpr, fp), - extractDateTime(endTimeExpr, fp)); - } - - /** - * UTC_DATE implementation for ExprValue. - * - * @param functionProperties FunctionProperties. - * @return ExprValue. - */ - private ExprValue exprUtcDate(FunctionProperties functionProperties) { - return new ExprDateValue(exprUtcTimeStamp(functionProperties).dateValue()); - } - - /** - * UTC_TIME implementation for ExprValue. - * - * @param functionProperties FunctionProperties. - * @return ExprValue. - */ - private ExprValue exprUtcTime(FunctionProperties functionProperties) { - return new ExprTimeValue(exprUtcTimeStamp(functionProperties).timeValue()); - } - - /** - * UTC_TIMESTAMP implementation for ExprValue. - * - * @param functionProperties FunctionProperties. - * @return ExprValue. - */ - private ExprValue exprUtcTimeStamp(FunctionProperties functionProperties) { - var zdt = ZonedDateTime.now(functionProperties.getQueryStartClock()) - .withZoneSameInstant(UTC_ZONE_ID); - return new ExprDatetimeValue(zdt.toLocalDateTime()); - } - - /** - * To_days implementation for ExprValue. - * - * @param date ExprValue of Date/String type. - * @return ExprValue. - */ - private ExprValue exprToDays(ExprValue date) { - return new ExprLongValue(date.dateValue().toEpochDay() + DAYS_0000_TO_1970); - } - - /** - * To_seconds implementation for ExprValue. - * - * @param date ExprValue of Date/Datetime/Timestamp/String type. - * @return ExprValue. - */ - private ExprValue exprToSeconds(ExprValue date) { - return new ExprLongValue( - date.datetimeValue().toEpochSecond(ZoneOffset.UTC) + DAYS_0000_TO_1970 * SECONDS_PER_DAY); - } - - /** - * Helper function to determine the correct formatter for date arguments passed in as integers. - * - * @param dateAsInt is an integer formatted as one of YYYYMMDD, YYMMDD, YMMDD, MMDD, MDD - * @return is a DateTimeFormatter that can parse the input. - */ - private DateTimeFormatter getFormatter(int dateAsInt) { - int length = String.format("%d", dateAsInt).length(); - - if (length > 8) { - throw new DateTimeException("Integer argument was out of range"); - } - - //Check below from YYYYMMDD - MMDD which format should be used - switch (length) { - //Check if dateAsInt is at least 8 digits long - case FULL_DATE_LENGTH: - return DATE_FORMATTER_LONG_YEAR; - - //Check if dateAsInt is at least 6 digits long - case SHORT_DATE_LENGTH: - return DATE_FORMATTER_SHORT_YEAR; - - //Check if dateAsInt is at least 5 digits long - case SINGLE_DIGIT_YEAR_DATE_LENGTH: - return DATE_FORMATTER_SINGLE_DIGIT_YEAR; - - //Check if dateAsInt is at least 4 digits long - case NO_YEAR_DATE_LENGTH: - return DATE_FORMATTER_NO_YEAR; - - //Check if dateAsInt is at least 3 digits long - case SINGLE_DIGIT_MONTH_DATE_LENGTH: - return DATE_FORMATTER_SINGLE_DIGIT_MONTH; - - default: - break; - } - - throw new DateTimeException("No Matching Format"); - } - - /** - * To_seconds implementation with an integer argument for ExprValue. - * - * @param dateExpr ExprValue of an Integer/Long formatted for a date (e.g., 950501 = 1995-05-01) - * @return ExprValue. - */ - private ExprValue exprToSecondsForIntType(ExprValue dateExpr) { - try { - //Attempt to parse integer argument as date - LocalDate date = LocalDate.parse(String.valueOf(dateExpr.integerValue()), - getFormatter(dateExpr.integerValue())); - - return new ExprLongValue(date.toEpochSecond(LocalTime.MIN, ZoneOffset.UTC) - + DAYS_0000_TO_1970 * SECONDS_PER_DAY); - - } catch (DateTimeException ignored) { - //Return null if parsing error - return ExprNullValue.of(); - } - } - - /** - * Week for date implementation for ExprValue. - * - * @param date ExprValue of Date/Datetime/Timestamp/String type. - * @param mode ExprValue of Integer type. - */ - private ExprValue exprWeek(ExprValue date, ExprValue mode) { - return new ExprIntegerValue( - CalendarLookup.getWeekNumber(mode.integerValue(), date.dateValue())); - } - - /** - * Weekday implementation for ExprValue. - * - * @param date ExprValue of Date/Datetime/String/Timstamp type. - * @return ExprValue. - */ - private ExprValue exprWeekday(ExprValue date) { - return new ExprIntegerValue(date.dateValue().getDayOfWeek().getValue() - 1); - } - - private ExprValue unixTimeStamp(Clock clock) { - return new ExprLongValue(Instant.now(clock).getEpochSecond()); - } - - private ExprValue unixTimeStampOf(ExprValue value) { - var res = unixTimeStampOfImpl(value); - if (res == null) { - return ExprNullValue.of(); - } - if (res < 0) { - // According to MySQL returns 0 if year < 1970, don't return negative values as java does. - return new ExprDoubleValue(0); - } - if (res >= MYSQL_MAX_TIMESTAMP) { - // Return 0 also for dates > '3001-01-19 03:14:07.999999' UTC (32536771199.999999 sec) - return new ExprDoubleValue(0); - } - return new ExprDoubleValue(res); - } - - private Double unixTimeStampOfImpl(ExprValue value) { - // Also, according to MySQL documentation: - // The date argument may be a DATE, DATETIME, or TIMESTAMP ... - switch ((ExprCoreType)value.type()) { - case DATE: return value.dateValue().toEpochSecond(LocalTime.MIN, ZoneOffset.UTC) + 0d; - case DATETIME: return value.datetimeValue().toEpochSecond(ZoneOffset.UTC) - + value.datetimeValue().getNano() / 1E9; - case TIMESTAMP: return value.timestampValue().getEpochSecond() - + value.timestampValue().getNano() / 1E9; - default: - // ... or a number in YYMMDD, YYMMDDhhmmss, YYYYMMDD, or YYYYMMDDhhmmss format. - // If the argument includes a time part, it may optionally include a fractional - // seconds part. - - var format = new DecimalFormat("0.#"); - format.setMinimumFractionDigits(0); - format.setMaximumFractionDigits(6); - String input = format.format(value.doubleValue()); - double fraction = 0; - if (input.contains(".")) { - // Keeping fraction second part and adding it to the result, don't parse it - // Because `toEpochSecond` returns only `long` - // input = 12345.6789 becomes input = 12345 and fraction = 0.6789 - fraction = value.doubleValue() - Math.round(Math.ceil(value.doubleValue())); - input = input.substring(0, input.indexOf('.')); + + /** + * DATE_ADD function implementation for ExprValue. + * + * @param functionProperties An FunctionProperties object. + * @param datetime ExprValue of Date/Time/Datetime/Timestamp type. + * @param interval ExprValue of Interval type, the temporal amount to add. + * @return Datetime resulted from `interval` added to `datetime`. + */ + private ExprValue exprAddDateInterval(FunctionProperties functionProperties, ExprValue datetime, ExprValue interval) { + return exprDateApplyInterval(functionProperties, datetime, interval.intervalValue(), true); + } + + /** + * Adds or subtracts `interval` to/from `datetime`. + * + * @param functionProperties An FunctionProperties object. + * @param datetime A Date/Time/Datetime/Timestamp value to change. + * @param interval An Interval to isAdd or subtract. + * @param isAdd A flag: true to isAdd, false to subtract. + * @return Datetime calculated. + */ + private ExprValue exprDateApplyInterval( + FunctionProperties functionProperties, + ExprValue datetime, + TemporalAmount interval, + Boolean isAdd + ) { + var dt = extractDateTime(datetime, functionProperties); + return new ExprDatetimeValue(isAdd ? dt.plus(interval) : dt.minus(interval)); + } + + /** + * Formats date according to format specifier. First argument is time, second is format. + * Detailed supported signatures: + * (STRING, STRING) -> STRING + * (DATE, STRING) -> STRING + * (DATETIME, STRING) -> STRING + * (TIME, STRING) -> STRING + * (TIMESTAMP, STRING) -> STRING + */ + private DefaultFunctionResolver time_format() { + return define( + BuiltinFunctionName.TIME_FORMAT.getName(), + impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedTime), STRING, STRING, STRING), + impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedTime), STRING, DATE, STRING), + impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedTime), STRING, DATETIME, STRING), + impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedTime), STRING, TIME, STRING), + impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedTime), STRING, TIMESTAMP, STRING) + ); + } + + /** + * ADDDATE function implementation for ExprValue. + * + * @param functionProperties An FunctionProperties object. + * @param datetime ExprValue of Time/Date/Datetime/Timestamp type. + * @param days ExprValue of Long type, representing the number of days to add. + * @return Date/Datetime resulted from days added to `datetime`. + */ + private ExprValue exprAddDateDays(FunctionProperties functionProperties, ExprValue datetime, ExprValue days) { + return exprDateApplyDays(functionProperties, datetime, days.longValue(), true); + } + + /** + * Adds or subtracts `days` to/from `datetime`. + * + * @param functionProperties An FunctionProperties object. + * @param datetime A Date/Time/Datetime/Timestamp value to change. + * @param days A days amount to add or subtract. + * @param isAdd A flag: true to add, false to subtract. + * @return Datetime calculated. + */ + private ExprValue exprDateApplyDays(FunctionProperties functionProperties, ExprValue datetime, Long days, Boolean isAdd) { + if (datetime.type() == DATE) { + return new ExprDateValue(isAdd ? datetime.dateValue().plusDays(days) : datetime.dateValue().minusDays(days)); + } + var dt = extractDateTime(datetime, functionProperties); + return new ExprDatetimeValue(isAdd ? dt.plusDays(days) : dt.minusDays(days)); + } + + /** + * Adds or subtracts time to/from date and returns the result. + * + * @param functionProperties A FunctionProperties object. + * @param temporal A Date/Time/Datetime/Timestamp value to change. + * @param temporalDelta A Date/Time/Datetime/Timestamp object to add/subtract time from. + * @param isAdd A flag: true to add, false to subtract. + * @return A value calculated. + */ + private ExprValue exprApplyTime(FunctionProperties functionProperties, ExprValue temporal, ExprValue temporalDelta, Boolean isAdd) { + var interval = Duration.between(LocalTime.MIN, temporalDelta.timeValue()); + var result = isAdd + ? extractDateTime(temporal, functionProperties).plus(interval) + : extractDateTime(temporal, functionProperties).minus(interval); + return temporal.type() == TIME ? new ExprTimeValue(result.toLocalTime()) : new ExprDatetimeValue(result); + } + + /** + * Adds time to date and returns the result. + * + * @param functionProperties A FunctionProperties object. + * @param temporal A Date/Time/Datetime/Timestamp value to change. + * @param temporalDelta A Date/Time/Datetime/Timestamp object to add time from. + * @return A value calculated. + */ + private ExprValue exprAddTime(FunctionProperties functionProperties, ExprValue temporal, ExprValue temporalDelta) { + return exprApplyTime(functionProperties, temporal, temporalDelta, true); + } + + /** + * CONVERT_TZ function implementation for ExprValue. + * Returns null for time zones outside of +13:00 and -12:00. + * + * @param startingDateTime ExprValue of DateTime that is being converted from + * @param fromTz ExprValue of time zone, representing the time to convert from. + * @param toTz ExprValue of time zone, representing the time to convert to. + * @return DateTime that has been converted to the to_tz timezone. + */ + private ExprValue exprConvertTZ(ExprValue startingDateTime, ExprValue fromTz, ExprValue toTz) { + if (startingDateTime.type() == ExprCoreType.STRING) { + startingDateTime = exprDateTimeNoTimezone(startingDateTime); } try { - var res = LocalDateTime.parse(input, DATE_TIME_FORMATTER_SHORT_YEAR); - return res.toEpochSecond(ZoneOffset.UTC) + fraction; - } catch (DateTimeParseException ignored) { - // nothing to do, try another format + ZoneId convertedFromTz = ZoneId.of(fromTz.stringValue()); + ZoneId convertedToTz = ZoneId.of(toTz.stringValue()); + + // isValidMySqlTimeZoneId checks if the timezone is within the range accepted by + // MySQL standard. + if (!DateTimeUtils.isValidMySqlTimeZoneId(convertedFromTz) || !DateTimeUtils.isValidMySqlTimeZoneId(convertedToTz)) { + return ExprNullValue.of(); + } + ZonedDateTime zonedDateTime = startingDateTime.datetimeValue().atZone(convertedFromTz); + return new ExprDatetimeValue(zonedDateTime.withZoneSameInstant(convertedToTz).toLocalDateTime()); + + // Catches exception for invalid timezones. + // ex. "+0:00" is an invalid timezone and would result in this exception being thrown. + } catch (ExpressionEvaluationException | DateTimeException e) { + return ExprNullValue.of(); + } + } + + /** + * Date implementation for ExprValue. + * + * @param exprValue ExprValue of Date type or String type. + * @return ExprValue. + */ + private ExprValue exprDate(ExprValue exprValue) { + if (exprValue instanceof ExprStringValue) { + return new ExprDateValue(exprValue.stringValue()); + } else { + return new ExprDateValue(exprValue.dateValue()); } + } + + /** + * Calculate the value in days from one date to the other. + * Only the date parts of the values are used in the calculation. + * + * @param first The first value. + * @param second The second value. + * @return The diff. + */ + private ExprValue exprDateDiff(FunctionProperties functionProperties, ExprValue first, ExprValue second) { + // java inverses the value, so we have to swap 1 and 2 + return new ExprLongValue(DAYS.between(extractDate(second, functionProperties), extractDate(first, functionProperties))); + } + + /** + * DateTime implementation for ExprValue. + * + * @param dateTime ExprValue of String type. + * @param timeZone ExprValue of String type (or null). + * @return ExprValue of date type. + */ + private ExprValue exprDateTime(ExprValue dateTime, ExprValue timeZone) { + String defaultTimeZone = TimeZone.getDefault().getID(); + try { - var res = LocalDateTime.parse(input, DATE_TIME_FORMATTER_LONG_YEAR); - return res.toEpochSecond(ZoneOffset.UTC) + fraction; - } catch (DateTimeParseException ignored) { - // nothing to do, try another format + LocalDateTime ldtFormatted = LocalDateTime.parse(dateTime.stringValue(), DATE_TIME_FORMATTER_STRICT_WITH_TZ); + if (timeZone.isNull()) { + return new ExprDatetimeValue(ldtFormatted); + } + + // Used if datetime field is invalid format. + } catch (DateTimeParseException e) { + return ExprNullValue.of(); + } + + ExprValue convertTZResult; + ExprDatetimeValue ldt; + String toTz; + + try { + ZonedDateTime zdtWithZoneOffset = ZonedDateTime.parse(dateTime.stringValue(), DATE_TIME_FORMATTER_STRICT_WITH_TZ); + ZoneId fromTZ = zdtWithZoneOffset.getZone(); + + ldt = new ExprDatetimeValue(zdtWithZoneOffset.toLocalDateTime()); + toTz = String.valueOf(fromTZ); + } catch (DateTimeParseException e) { + ldt = new ExprDatetimeValue(dateTime.stringValue()); + toTz = defaultTimeZone; + } + convertTZResult = exprConvertTZ(ldt, new ExprStringValue(toTz), timeZone); + + return convertTZResult; + } + + /** + * DateTime implementation for ExprValue without a timezone to convert to. + * + * @param dateTime ExprValue of String type. + * @return ExprValue of date type. + */ + private ExprValue exprDateTimeNoTimezone(ExprValue dateTime) { + return exprDateTime(dateTime, ExprNullValue.of()); + } + + /** + * Name of the Weekday implementation for ExprValue. + * + * @param date ExprValue of Date/String type. + * @return ExprValue. + */ + private ExprValue exprDayName(ExprValue date) { + return new ExprStringValue(date.dateValue().getDayOfWeek().getDisplayName(TextStyle.FULL, Locale.getDefault())); + } + + /** + * Day of Month implementation for ExprValue. + * + * @param date ExprValue of Date/Datetime/String/Time/Timestamp type. + * @return ExprValue. + */ + private ExprValue exprDayOfMonth(ExprValue date) { + return new ExprIntegerValue(date.dateValue().getDayOfMonth()); + } + + /** + * Day of Week implementation for ExprValue. + * + * @param date ExprValue of Date/Datetime/String/Timstamp type. + * @return ExprValue. + */ + private ExprValue exprDayOfWeek(ExprValue date) { + return new ExprIntegerValue((date.dateValue().getDayOfWeek().getValue() % 7) + 1); + } + + /** + * Day of Year implementation for ExprValue. + * + * @param date ExprValue of Date/String type. + * @return ExprValue. + */ + private ExprValue exprDayOfYear(ExprValue date) { + return new ExprIntegerValue(date.dateValue().getDayOfYear()); + } + + /** + * Obtains a formatted long value for a specified part and datetime for the 'extract' function. + * + * @param part is an ExprValue which comes from a defined list of accepted values. + * @param datetime the date to be formatted as an ExprValue. + * @return is a LONG formatted according to the input arguments. + */ + public ExprLongValue formatExtractFunction(ExprValue part, ExprValue datetime) { + String partName = part.stringValue().toUpperCase(); + LocalDateTime arg = datetime.datetimeValue(); + String text = arg.format(DateTimeFormatter.ofPattern(extract_formats.get(partName), Locale.ENGLISH)); + + return new ExprLongValue(Long.parseLong(text)); + } + + /** + * Implements extract function. Returns a LONG formatted according to the 'part' argument. + * + * @param part Literal that determines the format of the outputted LONG. + * @param datetime The date/datetime to be formatted. + * @return A LONG + */ + private ExprValue exprExtract(ExprValue part, ExprValue datetime) { + return formatExtractFunction(part, datetime); + } + + /** + * Implements extract function. Returns a LONG formatted according to the 'part' argument. + * + * @param part Literal that determines the format of the outputted LONG. + * @param time The time to be formatted. + * @return A LONG + */ + private ExprValue exprExtractForTime(FunctionProperties functionProperties, ExprValue part, ExprValue time) { + return formatExtractFunction(part, new ExprDatetimeValue(extractDateTime(time, functionProperties))); + } + + /** + * From_days implementation for ExprValue. + * + * @param exprValue Day number N. + * @return ExprValue. + */ + private ExprValue exprFromDays(ExprValue exprValue) { + return new ExprDateValue(LocalDate.ofEpochDay(exprValue.longValue() - DAYS_0000_TO_1970)); + } + + private ExprValue exprFromUnixTime(ExprValue time) { + if (0 > time.doubleValue()) { + return ExprNullValue.of(); + } + // According to MySQL documentation: + // effective maximum is 32536771199.999999, which returns '3001-01-18 23:59:59.999999' UTC. + // Regardless of platform or version, a greater value for first argument than the effective + // maximum returns 0. + if (MYSQL_MAX_TIMESTAMP <= time.doubleValue()) { + return ExprNullValue.of(); + } + return new ExprDatetimeValue(exprFromUnixTimeImpl(time)); + } + + private LocalDateTime exprFromUnixTimeImpl(ExprValue time) { + return LocalDateTime.ofInstant(Instant.ofEpochSecond((long) Math.floor(time.doubleValue())), UTC_ZONE_ID) + .withNano((int) ((time.doubleValue() % 1) * 1E9)); + } + + private ExprValue exprFromUnixTimeFormat(ExprValue time, ExprValue format) { + var value = exprFromUnixTime(time); + if (value.equals(ExprNullValue.of())) { + return ExprNullValue.of(); + } + return DateTimeFormatterUtil.getFormattedDate(value, format); + } + + /** + * get_format implementation for ExprValue. + * + * @param type ExprValue of the type. + * @param format ExprValue of Time/String type + * @return ExprValue.. + */ + private ExprValue exprGetFormat(ExprValue type, ExprValue format) { + if (formats.contains(type.stringValue().toLowerCase(), format.stringValue().toLowerCase())) { + return new ExprStringValue(formats.get(type.stringValue().toLowerCase(), format.stringValue().toLowerCase())); + } + + return ExprNullValue.of(); + } + + /** + * Hour implementation for ExprValue. + * + * @param time ExprValue of Time/String type. + * @return ExprValue. + */ + private ExprValue exprHour(ExprValue time) { + return new ExprIntegerValue(HOURS.between(LocalTime.MIN, time.timeValue())); + } + + /** + * Helper function to retrieve the last day of a month based on a LocalDate argument. + * + * @param today a LocalDate. + * @return a LocalDate associated with the last day of the month for the given input. + */ + private LocalDate getLastDay(LocalDate today) { + return LocalDate.of(today.getYear(), today.getMonth(), today.getMonth().length(today.isLeapYear())); + } + + /** + * Returns a DATE for the last day of the month of a given argument. + * + * @param datetime A DATE/DATETIME/TIMESTAMP/STRING ExprValue. + * @return An DATE value corresponding to the last day of the month of the given argument. + */ + private ExprValue exprLastDay(ExprValue datetime) { + return new ExprDateValue(getLastDay(datetime.dateValue())); + } + + /** + * Returns a DATE for the last day of the current month. + * + * @param clock The clock for the query start time from functionProperties. + * @return An DATE value corresponding to the last day of the month of the given argument. + */ + private ExprValue exprLastDayToday(Clock clock) { + return new ExprDateValue(getLastDay(formatNow(clock).toLocalDate())); + } + + /** + * Following MySQL, function receives arguments of type double and rounds them before use. + * Furthermore: + * - zero year interpreted as 2000 + * - negative year is not accepted + * - @dayOfYear should be greater than 1 + * - if @dayOfYear is greater than 365/366, calculation goes to the next year(s) + * + * @param yearExpr year + * @param dayOfYearExp day of the @year, starting from 1 + * @return Date - ExprDateValue object with LocalDate + */ + private ExprValue exprMakeDate(ExprValue yearExpr, ExprValue dayOfYearExp) { + var year = Math.round(yearExpr.doubleValue()); + var dayOfYear = Math.round(dayOfYearExp.doubleValue()); + // We need to do this to comply with MySQL + if (0 >= dayOfYear || 0 > year) { + return ExprNullValue.of(); + } + if (0 == year) { + year = 2000; + } + return new ExprDateValue(LocalDate.ofYearDay((int) year, 1).plusDays(dayOfYear - 1)); + } + + /** + * Following MySQL, function receives arguments of type double. @hour and @minute are rounded, + * while @second used as is, including fraction part. + * @param hourExpr hour + * @param minuteExpr minute + * @param secondExpr second + * @return Time - ExprTimeValue object with LocalTime + */ + private ExprValue exprMakeTime(ExprValue hourExpr, ExprValue minuteExpr, ExprValue secondExpr) { + var hour = Math.round(hourExpr.doubleValue()); + var minute = Math.round(minuteExpr.doubleValue()); + var second = secondExpr.doubleValue(); + if (0 > hour || 0 > minute || 0 > second) { + return ExprNullValue.of(); + } + return new ExprTimeValue(LocalTime.parse(String.format("%02d:%02d:%012.9f", hour, minute, second), DateTimeFormatter.ISO_TIME)); + } + + /** + * Microsecond implementation for ExprValue. + * + * @param time ExprValue of Time/String type. + * @return ExprValue. + */ + private ExprValue exprMicrosecond(ExprValue time) { + return new ExprIntegerValue(TimeUnit.MICROSECONDS.convert(time.timeValue().getNano(), TimeUnit.NANOSECONDS)); + } + + /** + * Minute implementation for ExprValue. + * + * @param time ExprValue of Time/String type. + * @return ExprValue. + */ + private ExprValue exprMinute(ExprValue time) { + return new ExprIntegerValue((MINUTES.between(LocalTime.MIN, time.timeValue()) % 60)); + } + + /** + * Minute_of_day implementation for ExprValue. + * + * @param time ExprValue of Time/String type. + * @return ExprValue. + */ + private ExprValue exprMinuteOfDay(ExprValue time) { + return new ExprIntegerValue(MINUTES.between(LocalTime.MIN, time.timeValue())); + } + + /** + * Month for date implementation for ExprValue. + * + * @param date ExprValue of Date/String type. + * @return ExprValue. + */ + private ExprValue exprMonth(ExprValue date) { + return new ExprIntegerValue(date.dateValue().getMonthValue()); + } + + /** + * Name of the Month implementation for ExprValue. + * + * @param date ExprValue of Date/String type. + * @return ExprValue. + */ + private ExprValue exprMonthName(ExprValue date) { + return new ExprStringValue(date.dateValue().getMonth().getDisplayName(TextStyle.FULL, Locale.getDefault())); + } + + private LocalDate parseDatePeriod(Integer period) { + var input = period.toString(); + // MySQL undocumented: if year is not specified or has 1 digit - 2000/200x is assumed + if (input.length() <= 5) { + input = String.format("200%05d", period); } try { - var res = LocalDate.parse(input, DATE_FORMATTER_SHORT_YEAR); - return res.toEpochSecond(LocalTime.MIN, ZoneOffset.UTC) + 0d; + return LocalDate.parse(input, DATE_FORMATTER_SHORT_YEAR); } catch (DateTimeParseException ignored) { - // nothing to do, try another format + // nothing to do, try another format } try { - var res = LocalDate.parse(input, DATE_FORMATTER_LONG_YEAR); - return res.toEpochSecond(LocalTime.MIN, ZoneOffset.UTC) + 0d; + return LocalDate.parse(input, DATE_FORMATTER_LONG_YEAR); } catch (DateTimeParseException ignored) { - return null; + return null; + } + } + + /** + * Adds N months to period P (in the format YYMM or YYYYMM). + * Returns a value in the format YYYYMM. + * + * @param period Period in the format YYMM or YYYYMM. + * @param months Amount of months to add. + * @return ExprIntegerValue. + */ + private ExprValue exprPeriodAdd(ExprValue period, ExprValue months) { + // We should add a day to make string parsable and remove it afterwards + var input = period.integerValue() * 100 + 1; // adds 01 to end of the string + var parsedDate = parseDatePeriod(input); + if (parsedDate == null) { + return ExprNullValue.of(); + } + var res = DATE_FORMATTER_LONG_YEAR.format(parsedDate.plusMonths(months.integerValue())); + return new ExprIntegerValue(Integer.parseInt(res.substring(0, res.length() - 2))); // Remove the day part, .eg. 20070101 -> 200701 + } + + /** + * Returns the number of months between periods P1 and P2. + * P1 and P2 should be in the format YYMM or YYYYMM. + * + * @param period1 Period in the format YYMM or YYYYMM. + * @param period2 Period in the format YYMM or YYYYMM. + * @return ExprIntegerValue. + */ + private ExprValue exprPeriodDiff(ExprValue period1, ExprValue period2) { + var parsedDate1 = parseDatePeriod(period1.integerValue() * 100 + 1); + var parsedDate2 = parseDatePeriod(period2.integerValue() * 100 + 1); + if (parsedDate1 == null || parsedDate2 == null) { + return ExprNullValue.of(); + } + return new ExprIntegerValue(MONTHS.between(parsedDate2, parsedDate1)); + } + + /** + * Quarter for date implementation for ExprValue. + * + * @param date ExprValue of Date/String type. + * @return ExprValue. + */ + private ExprValue exprQuarter(ExprValue date) { + int month = date.dateValue().getMonthValue(); + return new ExprIntegerValue((month / 3) + ((month % 3) == 0 ? 0 : 1)); + } + + /** + * Returns TIME value of sec_to_time function for an INTEGER or LONG arguments. + * @param totalSeconds The total number of seconds + * @return A TIME value + */ + private ExprValue exprSecToTime(ExprValue totalSeconds) { + return new ExprTimeValue(LocalTime.MIN.plus(Duration.ofSeconds(totalSeconds.longValue()))); + } + + /** + * Helper function which obtains the decimal portion of the seconds value passed in. + * Uses BigDecimal to prevent issues with math on floating point numbers. + * Return is formatted to be used with Duration.ofSeconds(); + * + * @param seconds and ExprDoubleValue or ExprFloatValue for the seconds + * @return A LONG representing the nanoseconds portion + */ + private long formatNanos(ExprValue seconds) { + // Convert ExprValue to BigDecimal + BigDecimal formattedNanos = BigDecimal.valueOf(seconds.doubleValue()); + // Extract only the nanosecond part + formattedNanos = formattedNanos.subtract(BigDecimal.valueOf(formattedNanos.intValue())); + + return formattedNanos.scaleByPowerOfTen(9).longValue(); + } + + /** + * Returns TIME value of sec_to_time function for FLOAT or DOUBLE arguments. + * @param totalSeconds The total number of seconds + * @return A TIME value + */ + private ExprValue exprSecToTimeWithNanos(ExprValue totalSeconds) { + long nanos = formatNanos(totalSeconds); + + return new ExprTimeValue(LocalTime.MIN.plus(Duration.ofSeconds(totalSeconds.longValue(), nanos))); + } + + /** + * Second implementation for ExprValue. + * + * @param time ExprValue of Time/String type. + * @return ExprValue. + */ + private ExprValue exprSecond(ExprValue time) { + return new ExprIntegerValue((SECONDS.between(LocalTime.MIN, time.timeValue()) % 60)); + } + + /** + * SUBDATE function implementation for ExprValue. + * + * @param functionProperties An FunctionProperties object. + * @param date ExprValue of Time/Date/Datetime/Timestamp type. + * @param days ExprValue of Long type, representing the number of days to subtract. + * @return Date/Datetime resulted from days subtracted to date. + */ + private ExprValue exprSubDateDays(FunctionProperties functionProperties, ExprValue date, ExprValue days) { + return exprDateApplyDays(functionProperties, date, days.longValue(), false); + } + + /** + * DATE_SUB function implementation for ExprValue. + * + * @param functionProperties An FunctionProperties object. + * @param datetime ExprValue of Time/Date/Datetime/Timestamp type. + * @param expr ExprValue of Interval type, the temporal amount to subtract. + * @return Datetime resulted from expr subtracted to `datetime`. + */ + private ExprValue exprSubDateInterval(FunctionProperties functionProperties, ExprValue datetime, ExprValue expr) { + return exprDateApplyInterval(functionProperties, datetime, expr.intervalValue(), false); + } + + /** + * Subtracts expr2 from expr1 and returns the result. + * + * @param temporal A Date/Time/Datetime/Timestamp value to change. + * @param temporalDelta A Date/Time/Datetime/Timestamp to subtract time from. + * @return A value calculated. + */ + private ExprValue exprSubTime(FunctionProperties functionProperties, ExprValue temporal, ExprValue temporalDelta) { + return exprApplyTime(functionProperties, temporal, temporalDelta, false); + } + + private ExprValue exprStrToDate(FunctionProperties fp, ExprValue dateTimeExpr, ExprValue formatStringExp) { + return DateTimeFormatterUtil.parseStringWithDateOrTime(fp, dateTimeExpr, formatStringExp); + } + + /** + * Time implementation for ExprValue. + * + * @param exprValue ExprValue of Time type or String. + * @return ExprValue. + */ + private ExprValue exprTime(ExprValue exprValue) { + if (exprValue instanceof ExprStringValue) { + return new ExprTimeValue(exprValue.stringValue()); + } else { + return new ExprTimeValue(exprValue.timeValue()); + } + } + + /** + * Calculate the time difference between two times. + * + * @param first The first value. + * @param second The second value. + * @return The diff. + */ + private ExprValue exprTimeDiff(ExprValue first, ExprValue second) { + // java inverses the value, so we have to swap 1 and 2 + return new ExprTimeValue(LocalTime.MIN.plus(Duration.between(second.timeValue(), first.timeValue()))); + } + + /** + * Time To Sec implementation for ExprValue. + * + * @param time ExprValue of Time/String type. + * @return ExprValue. + */ + private ExprValue exprTimeToSec(ExprValue time) { + return new ExprLongValue(time.timeValue().toSecondOfDay()); + } + + private ExprValue exprTimestampAdd(ExprValue partExpr, ExprValue amountExpr, ExprValue datetimeExpr) { + String part = partExpr.stringValue(); + int amount = amountExpr.integerValue(); + LocalDateTime datetime = datetimeExpr.datetimeValue(); + ChronoUnit temporalUnit; + + switch (part) { + case "MICROSECOND": + temporalUnit = MICROS; + break; + case "SECOND": + temporalUnit = SECONDS; + break; + case "MINUTE": + temporalUnit = MINUTES; + break; + case "HOUR": + temporalUnit = HOURS; + break; + case "DAY": + temporalUnit = DAYS; + break; + case "WEEK": + temporalUnit = WEEKS; + break; + case "MONTH": + temporalUnit = MONTHS; + break; + case "QUARTER": + temporalUnit = MONTHS; + amount *= 3; + break; + case "YEAR": + temporalUnit = YEARS; + break; + default: + return ExprNullValue.of(); + } + return new ExprDatetimeValue(datetime.plus(amount, temporalUnit)); + } + + private ExprValue exprTimestampAddForTimeType(Clock clock, ExprValue partExpr, ExprValue amountExpr, ExprValue timeExpr) { + LocalDateTime datetime = LocalDateTime.of(formatNow(clock).toLocalDate(), timeExpr.timeValue()); + return exprTimestampAdd(partExpr, amountExpr, new ExprDatetimeValue(datetime)); + } + + private ExprValue getTimeDifference(String part, LocalDateTime startTime, LocalDateTime endTime) { + long returnVal; + switch (part) { + case "MICROSECOND": + returnVal = MICROS.between(startTime, endTime); + break; + case "SECOND": + returnVal = SECONDS.between(startTime, endTime); + break; + case "MINUTE": + returnVal = MINUTES.between(startTime, endTime); + break; + case "HOUR": + returnVal = HOURS.between(startTime, endTime); + break; + case "DAY": + returnVal = DAYS.between(startTime, endTime); + break; + case "WEEK": + returnVal = WEEKS.between(startTime, endTime); + break; + case "MONTH": + returnVal = MONTHS.between(startTime, endTime); + break; + case "QUARTER": + returnVal = MONTHS.between(startTime, endTime) / 3; + break; + case "YEAR": + returnVal = YEARS.between(startTime, endTime); + break; + default: + return ExprNullValue.of(); + } + return new ExprLongValue(returnVal); + } + + private ExprValue exprTimestampDiff(ExprValue partExpr, ExprValue startTimeExpr, ExprValue endTimeExpr) { + return getTimeDifference(partExpr.stringValue(), startTimeExpr.datetimeValue(), endTimeExpr.datetimeValue()); + } + + private ExprValue exprTimestampDiffForTimeType( + FunctionProperties fp, + ExprValue partExpr, + ExprValue startTimeExpr, + ExprValue endTimeExpr + ) { + return getTimeDifference(partExpr.stringValue(), extractDateTime(startTimeExpr, fp), extractDateTime(endTimeExpr, fp)); + } + + /** + * UTC_DATE implementation for ExprValue. + * + * @param functionProperties FunctionProperties. + * @return ExprValue. + */ + private ExprValue exprUtcDate(FunctionProperties functionProperties) { + return new ExprDateValue(exprUtcTimeStamp(functionProperties).dateValue()); + } + + /** + * UTC_TIME implementation for ExprValue. + * + * @param functionProperties FunctionProperties. + * @return ExprValue. + */ + private ExprValue exprUtcTime(FunctionProperties functionProperties) { + return new ExprTimeValue(exprUtcTimeStamp(functionProperties).timeValue()); + } + + /** + * UTC_TIMESTAMP implementation for ExprValue. + * + * @param functionProperties FunctionProperties. + * @return ExprValue. + */ + private ExprValue exprUtcTimeStamp(FunctionProperties functionProperties) { + var zdt = ZonedDateTime.now(functionProperties.getQueryStartClock()).withZoneSameInstant(UTC_ZONE_ID); + return new ExprDatetimeValue(zdt.toLocalDateTime()); + } + + /** + * To_days implementation for ExprValue. + * + * @param date ExprValue of Date/String type. + * @return ExprValue. + */ + private ExprValue exprToDays(ExprValue date) { + return new ExprLongValue(date.dateValue().toEpochDay() + DAYS_0000_TO_1970); + } + + /** + * To_seconds implementation for ExprValue. + * + * @param date ExprValue of Date/Datetime/Timestamp/String type. + * @return ExprValue. + */ + private ExprValue exprToSeconds(ExprValue date) { + return new ExprLongValue(date.datetimeValue().toEpochSecond(ZoneOffset.UTC) + DAYS_0000_TO_1970 * SECONDS_PER_DAY); + } + + /** + * Helper function to determine the correct formatter for date arguments passed in as integers. + * + * @param dateAsInt is an integer formatted as one of YYYYMMDD, YYMMDD, YMMDD, MMDD, MDD + * @return is a DateTimeFormatter that can parse the input. + */ + private DateTimeFormatter getFormatter(int dateAsInt) { + int length = String.format("%d", dateAsInt).length(); + + if (length > 8) { + throw new DateTimeException("Integer argument was out of range"); + } + + // Check below from YYYYMMDD - MMDD which format should be used + switch (length) { + // Check if dateAsInt is at least 8 digits long + case FULL_DATE_LENGTH: + return DATE_FORMATTER_LONG_YEAR; + + // Check if dateAsInt is at least 6 digits long + case SHORT_DATE_LENGTH: + return DATE_FORMATTER_SHORT_YEAR; + + // Check if dateAsInt is at least 5 digits long + case SINGLE_DIGIT_YEAR_DATE_LENGTH: + return DATE_FORMATTER_SINGLE_DIGIT_YEAR; + + // Check if dateAsInt is at least 4 digits long + case NO_YEAR_DATE_LENGTH: + return DATE_FORMATTER_NO_YEAR; + + // Check if dateAsInt is at least 3 digits long + case SINGLE_DIGIT_MONTH_DATE_LENGTH: + return DATE_FORMATTER_SINGLE_DIGIT_MONTH; + + default: + break; + } + + throw new DateTimeException("No Matching Format"); + } + + /** + * To_seconds implementation with an integer argument for ExprValue. + * + * @param dateExpr ExprValue of an Integer/Long formatted for a date (e.g., 950501 = 1995-05-01) + * @return ExprValue. + */ + private ExprValue exprToSecondsForIntType(ExprValue dateExpr) { + try { + // Attempt to parse integer argument as date + LocalDate date = LocalDate.parse(String.valueOf(dateExpr.integerValue()), getFormatter(dateExpr.integerValue())); + + return new ExprLongValue(date.toEpochSecond(LocalTime.MIN, ZoneOffset.UTC) + DAYS_0000_TO_1970 * SECONDS_PER_DAY); + + } catch (DateTimeException ignored) { + // Return null if parsing error + return ExprNullValue.of(); + } + } + + /** + * Week for date implementation for ExprValue. + * + * @param date ExprValue of Date/Datetime/Timestamp/String type. + * @param mode ExprValue of Integer type. + */ + private ExprValue exprWeek(ExprValue date, ExprValue mode) { + return new ExprIntegerValue(CalendarLookup.getWeekNumber(mode.integerValue(), date.dateValue())); + } + + /** + * Weekday implementation for ExprValue. + * + * @param date ExprValue of Date/Datetime/String/Timstamp type. + * @return ExprValue. + */ + private ExprValue exprWeekday(ExprValue date) { + return new ExprIntegerValue(date.dateValue().getDayOfWeek().getValue() - 1); + } + + private ExprValue unixTimeStamp(Clock clock) { + return new ExprLongValue(Instant.now(clock).getEpochSecond()); + } + + private ExprValue unixTimeStampOf(ExprValue value) { + var res = unixTimeStampOfImpl(value); + if (res == null) { + return ExprNullValue.of(); + } + if (res < 0) { + // According to MySQL returns 0 if year < 1970, don't return negative values as java does. + return new ExprDoubleValue(0); + } + if (res >= MYSQL_MAX_TIMESTAMP) { + // Return 0 also for dates > '3001-01-19 03:14:07.999999' UTC (32536771199.999999 sec) + return new ExprDoubleValue(0); + } + return new ExprDoubleValue(res); + } + + private Double unixTimeStampOfImpl(ExprValue value) { + // Also, according to MySQL documentation: + // The date argument may be a DATE, DATETIME, or TIMESTAMP ... + switch ((ExprCoreType) value.type()) { + case DATE: + return value.dateValue().toEpochSecond(LocalTime.MIN, ZoneOffset.UTC) + 0d; + case DATETIME: + return value.datetimeValue().toEpochSecond(ZoneOffset.UTC) + value.datetimeValue().getNano() / 1E9; + case TIMESTAMP: + return value.timestampValue().getEpochSecond() + value.timestampValue().getNano() / 1E9; + default: + // ... or a number in YYMMDD, YYMMDDhhmmss, YYYYMMDD, or YYYYMMDDhhmmss format. + // If the argument includes a time part, it may optionally include a fractional + // seconds part. + + var format = new DecimalFormat("0.#"); + format.setMinimumFractionDigits(0); + format.setMaximumFractionDigits(6); + String input = format.format(value.doubleValue()); + double fraction = 0; + if (input.contains(".")) { + // Keeping fraction second part and adding it to the result, don't parse it + // Because `toEpochSecond` returns only `long` + // input = 12345.6789 becomes input = 12345 and fraction = 0.6789 + fraction = value.doubleValue() - Math.round(Math.ceil(value.doubleValue())); + input = input.substring(0, input.indexOf('.')); + } + try { + var res = LocalDateTime.parse(input, DATE_TIME_FORMATTER_SHORT_YEAR); + return res.toEpochSecond(ZoneOffset.UTC) + fraction; + } catch (DateTimeParseException ignored) { + // nothing to do, try another format + } + try { + var res = LocalDateTime.parse(input, DATE_TIME_FORMATTER_LONG_YEAR); + return res.toEpochSecond(ZoneOffset.UTC) + fraction; + } catch (DateTimeParseException ignored) { + // nothing to do, try another format + } + try { + var res = LocalDate.parse(input, DATE_FORMATTER_SHORT_YEAR); + return res.toEpochSecond(LocalTime.MIN, ZoneOffset.UTC) + 0d; + } catch (DateTimeParseException ignored) { + // nothing to do, try another format + } + try { + var res = LocalDate.parse(input, DATE_FORMATTER_LONG_YEAR); + return res.toEpochSecond(LocalTime.MIN, ZoneOffset.UTC) + 0d; + } catch (DateTimeParseException ignored) { + return null; + } + } + } + + /** + * Week for date implementation for ExprValue. + * When mode is not specified default value mode 0 is used for default_week_format. + * + * @param date ExprValue of Date/Datetime/Timestamp/String type. + * @return ExprValue. + */ + private ExprValue exprWeekWithoutMode(ExprValue date) { + return exprWeek(date, DEFAULT_WEEK_OF_YEAR_MODE); + } + + /** + * Year for date implementation for ExprValue. + * + * @param date ExprValue of Date/String type. + * @return ExprValue. + */ + private ExprValue exprYear(ExprValue date) { + return new ExprIntegerValue(date.dateValue().getYear()); + } + + /** + * Helper function to extract the yearweek output from a given date. + * + * @param date is a LocalDate input argument. + * @param mode is an integer containing the mode used to parse the LocalDate. + * @return is a long containing the formatted output for the yearweek function. + */ + private ExprIntegerValue extractYearweek(LocalDate date, int mode) { + // Needed to align with MySQL. Due to how modes for this function work. + // See description of modes here ... + // https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_week + int modeJava = CalendarLookup.getWeekNumber(mode, date) != 0 ? mode : mode <= 4 ? 2 : 7; + + int formatted = CalendarLookup.getYearNumber(modeJava, date) * 100 + CalendarLookup.getWeekNumber(modeJava, date); + + return new ExprIntegerValue(formatted); + } + + /** + * Yearweek for date implementation for ExprValue. + * + * @param date ExprValue of Date/Datetime/Time/Timestamp/String type. + * @param mode ExprValue of Integer type. + */ + private ExprValue exprYearweek(ExprValue date, ExprValue mode) { + return extractYearweek(date.dateValue(), mode.integerValue()); + } + + /** + * Yearweek for date implementation for ExprValue. + * When mode is not specified default value mode 0 is used. + * + * @param date ExprValue of Date/Datetime/Time/Timestamp/String type. + * @return ExprValue. + */ + private ExprValue exprYearweekWithoutMode(ExprValue date) { + return exprYearweek(date, new ExprIntegerValue(0)); + } + + private ExprValue yearweekToday(ExprValue mode, Clock clock) { + return extractYearweek(LocalDateTime.now(clock).toLocalDate(), mode.integerValue()); + } + + private ExprValue monthOfYearToday(Clock clock) { + return new ExprIntegerValue(LocalDateTime.now(clock).getMonthValue()); + } + + private LocalDateTime formatNow(Clock clock) { + return formatNow(clock, 0); + } + + /** + * Prepare LocalDateTime value. Truncate fractional second part according to the argument. + * @param fsp argument is given to specify a fractional seconds precision from 0 to 6, + * the return value includes a fractional seconds part of that many digits. + * @return LocalDateTime object. + */ + private LocalDateTime formatNow(Clock clock, Integer fsp) { + var res = LocalDateTime.now(clock); + var defaultPrecision = 9; // There are 10^9 nanoseconds in one second + if (fsp < 0 || fsp > 6) { // Check that the argument is in the allowed range [0, 6] + throw new IllegalArgumentException(String.format("Invalid `fsp` value: %d, allowed 0 to 6", fsp)); } + var nano = new BigDecimal(res.getNano()).setScale(fsp - defaultPrecision, RoundingMode.DOWN).intValue(); + return res.withNano(nano); } - } - - /** - * Week for date implementation for ExprValue. - * When mode is not specified default value mode 0 is used for default_week_format. - * - * @param date ExprValue of Date/Datetime/Timestamp/String type. - * @return ExprValue. - */ - private ExprValue exprWeekWithoutMode(ExprValue date) { - return exprWeek(date, DEFAULT_WEEK_OF_YEAR_MODE); - } - - /** - * Year for date implementation for ExprValue. - * - * @param date ExprValue of Date/String type. - * @return ExprValue. - */ - private ExprValue exprYear(ExprValue date) { - return new ExprIntegerValue(date.dateValue().getYear()); - } - - /** - * Helper function to extract the yearweek output from a given date. - * - * @param date is a LocalDate input argument. - * @param mode is an integer containing the mode used to parse the LocalDate. - * @return is a long containing the formatted output for the yearweek function. - */ - private ExprIntegerValue extractYearweek(LocalDate date, int mode) { - // Needed to align with MySQL. Due to how modes for this function work. - // See description of modes here ... - // https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_week - int modeJava = CalendarLookup.getWeekNumber(mode, date) != 0 ? mode : - mode <= 4 ? 2 : - 7; - - int formatted = CalendarLookup.getYearNumber(modeJava, date) * 100 - + CalendarLookup.getWeekNumber(modeJava, date); - - return new ExprIntegerValue(formatted); - } - - /** - * Yearweek for date implementation for ExprValue. - * - * @param date ExprValue of Date/Datetime/Time/Timestamp/String type. - * @param mode ExprValue of Integer type. - */ - private ExprValue exprYearweek(ExprValue date, ExprValue mode) { - return extractYearweek(date.dateValue(), mode.integerValue()); - } - - /** - * Yearweek for date implementation for ExprValue. - * When mode is not specified default value mode 0 is used. - * - * @param date ExprValue of Date/Datetime/Time/Timestamp/String type. - * @return ExprValue. - */ - private ExprValue exprYearweekWithoutMode(ExprValue date) { - return exprYearweek(date, new ExprIntegerValue(0)); - } - - private ExprValue yearweekToday(ExprValue mode, Clock clock) { - return extractYearweek(LocalDateTime.now(clock).toLocalDate(), mode.integerValue()); - } - - private ExprValue monthOfYearToday(Clock clock) { - return new ExprIntegerValue(LocalDateTime.now(clock).getMonthValue()); - } - - private LocalDateTime formatNow(Clock clock) { - return formatNow(clock, 0); - } - - /** - * Prepare LocalDateTime value. Truncate fractional second part according to the argument. - * @param fsp argument is given to specify a fractional seconds precision from 0 to 6, - * the return value includes a fractional seconds part of that many digits. - * @return LocalDateTime object. - */ - private LocalDateTime formatNow(Clock clock, Integer fsp) { - var res = LocalDateTime.now(clock); - var defaultPrecision = 9; // There are 10^9 nanoseconds in one second - if (fsp < 0 || fsp > 6) { // Check that the argument is in the allowed range [0, 6] - throw new IllegalArgumentException( - String.format("Invalid `fsp` value: %d, allowed 0 to 6", fsp)); - } - var nano = new BigDecimal(res.getNano()) - .setScale(fsp - defaultPrecision, RoundingMode.DOWN).intValue(); - return res.withNano(nano); - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/datetime/IntervalClause.java b/core/src/main/java/org/opensearch/sql/expression/datetime/IntervalClause.java index 3df8489b20..92e7e3a1c3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/datetime/IntervalClause.java +++ b/core/src/main/java/org/opensearch/sql/expression/datetime/IntervalClause.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.datetime; import static org.opensearch.sql.data.model.ExprValueUtils.getIntegerValue; @@ -30,85 +29,86 @@ @UtilityClass public class IntervalClause { - private static final String MICRO_SECOND = "microsecond"; - private static final String SECOND = "second"; - private static final String MINUTE = "minute"; - private static final String HOUR = "hour"; - private static final String DAY = "day"; - private static final String WEEK = "week"; - private static final String MONTH = "month"; - private static final String QUARTER = "quarter"; - private static final String YEAR = "year"; - - public void register(BuiltinFunctionRepository repository) { - repository.register(interval()); - } - - private DefaultFunctionResolver interval() { - return define(BuiltinFunctionName.INTERVAL.getName(), - impl(nullMissingHandling(IntervalClause::interval), INTERVAL, INTEGER, STRING), - impl(nullMissingHandling(IntervalClause::interval), INTERVAL, LONG, STRING)); - } - - private ExprValue interval(ExprValue value, ExprValue unit) { - switch (getStringValue(unit).toLowerCase()) { - case MICRO_SECOND: - return microsecond(value); - case SECOND: - return second(value); - case MINUTE: - return minute(value); - case HOUR: - return hour(value); - case DAY: - return day(value); - case WEEK: - return week(value); - case MONTH: - return month(value); - case QUARTER: - return quarter(value); - case YEAR: - return year(value); - default: - throw new ExpressionEvaluationException( - String.format("interval unit %s is not supported", getStringValue(unit))); + private static final String MICRO_SECOND = "microsecond"; + private static final String SECOND = "second"; + private static final String MINUTE = "minute"; + private static final String HOUR = "hour"; + private static final String DAY = "day"; + private static final String WEEK = "week"; + private static final String MONTH = "month"; + private static final String QUARTER = "quarter"; + private static final String YEAR = "year"; + + public void register(BuiltinFunctionRepository repository) { + repository.register(interval()); + } + + private DefaultFunctionResolver interval() { + return define( + BuiltinFunctionName.INTERVAL.getName(), + impl(nullMissingHandling(IntervalClause::interval), INTERVAL, INTEGER, STRING), + impl(nullMissingHandling(IntervalClause::interval), INTERVAL, LONG, STRING) + ); } - } - private ExprValue microsecond(ExprValue value) { - return new ExprIntervalValue(Duration.ofNanos(getLongValue(value) * 1000)); - } + private ExprValue interval(ExprValue value, ExprValue unit) { + switch (getStringValue(unit).toLowerCase()) { + case MICRO_SECOND: + return microsecond(value); + case SECOND: + return second(value); + case MINUTE: + return minute(value); + case HOUR: + return hour(value); + case DAY: + return day(value); + case WEEK: + return week(value); + case MONTH: + return month(value); + case QUARTER: + return quarter(value); + case YEAR: + return year(value); + default: + throw new ExpressionEvaluationException(String.format("interval unit %s is not supported", getStringValue(unit))); + } + } - private ExprValue second(ExprValue value) { - return new ExprIntervalValue(Duration.ofSeconds(getLongValue(value))); - } + private ExprValue microsecond(ExprValue value) { + return new ExprIntervalValue(Duration.ofNanos(getLongValue(value) * 1000)); + } + + private ExprValue second(ExprValue value) { + return new ExprIntervalValue(Duration.ofSeconds(getLongValue(value))); + } - private ExprValue minute(ExprValue value) { - return new ExprIntervalValue(Duration.ofMinutes(getLongValue(value))); - } + private ExprValue minute(ExprValue value) { + return new ExprIntervalValue(Duration.ofMinutes(getLongValue(value))); + } - private ExprValue hour(ExprValue value) { - return new ExprIntervalValue(Duration.ofHours(getLongValue(value))); - } + private ExprValue hour(ExprValue value) { + return new ExprIntervalValue(Duration.ofHours(getLongValue(value))); + } - private ExprValue day(ExprValue value) { - return new ExprIntervalValue(Period.ofDays(getIntegerValue(value))); - } + private ExprValue day(ExprValue value) { + return new ExprIntervalValue(Period.ofDays(getIntegerValue(value))); + } - private ExprValue week(ExprValue value) { - return new ExprIntervalValue(Period.ofWeeks(getIntegerValue(value))); - } + private ExprValue week(ExprValue value) { + return new ExprIntervalValue(Period.ofWeeks(getIntegerValue(value))); + } - private ExprValue month(ExprValue value) { - return new ExprIntervalValue(Period.ofMonths(getIntegerValue(value))); - } + private ExprValue month(ExprValue value) { + return new ExprIntervalValue(Period.ofMonths(getIntegerValue(value))); + } - private ExprValue quarter(ExprValue value) { - return new ExprIntervalValue(Period.ofMonths(getIntegerValue(value) * 3)); - } + private ExprValue quarter(ExprValue value) { + return new ExprIntervalValue(Period.ofMonths(getIntegerValue(value) * 3)); + } - private ExprValue year(ExprValue value) { - return new ExprIntervalValue(Period.ofYears(getIntegerValue(value))); - } + private ExprValue year(ExprValue value) { + return new ExprIntervalValue(Period.ofYears(getIntegerValue(value))); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/env/Environment.java b/core/src/main/java/org/opensearch/sql/expression/env/Environment.java index d96d0c0a50..9c001ee310 100644 --- a/core/src/main/java/org/opensearch/sql/expression/env/Environment.java +++ b/core/src/main/java/org/opensearch/sql/expression/env/Environment.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.env; /** @@ -13,29 +12,28 @@ */ public interface Environment { - /** - * resolve the value of expression from the environment. - */ - V resolve(E var); + /** + * resolve the value of expression from the environment. + */ + V resolve(E var); - /** - * Extend the environment. - * - * @param env environment - * @param expr expression. - * @param value expression value. - * @param the type of expression - * @param the type of expression value - * @return extended environment. - */ - static Environment extendEnv( - Environment env, E expr, V value) { - return var -> { - if (var.equals(expr)) { - return value; - } else { - return env.resolve(var); - } - }; - } + /** + * Extend the environment. + * + * @param env environment + * @param expr expression. + * @param value expression value. + * @param the type of expression + * @param the type of expression value + * @return extended environment. + */ + static Environment extendEnv(Environment env, E expr, V value) { + return var -> { + if (var.equals(expr)) { + return value; + } else { + return env.resolve(var); + } + }; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 728712f537..6fe457b688 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -18,299 +18,297 @@ @Getter @RequiredArgsConstructor public enum BuiltinFunctionName { - /** - * Mathematical Functions. - */ - ABS(FunctionName.of("abs")), - CEIL(FunctionName.of("ceil")), - CEILING(FunctionName.of("ceiling")), - CONV(FunctionName.of("conv")), - CRC32(FunctionName.of("crc32")), - E(FunctionName.of("e")), - EXP(FunctionName.of("exp")), - EXPM1(FunctionName.of("expm1")), - FLOOR(FunctionName.of("floor")), - LN(FunctionName.of("ln")), - LOG(FunctionName.of("log")), - LOG10(FunctionName.of("log10")), - LOG2(FunctionName.of("log2")), - PI(FunctionName.of("pi")), - POW(FunctionName.of("pow")), - POWER(FunctionName.of("power")), - RAND(FunctionName.of("rand")), - RINT(FunctionName.of("rint")), - ROUND(FunctionName.of("round")), - SIGN(FunctionName.of("sign")), - SIGNUM(FunctionName.of("signum")), - SINH(FunctionName.of("sinh")), - SQRT(FunctionName.of("sqrt")), - CBRT(FunctionName.of("cbrt")), - TRUNCATE(FunctionName.of("truncate")), + /** + * Mathematical Functions. + */ + ABS(FunctionName.of("abs")), + CEIL(FunctionName.of("ceil")), + CEILING(FunctionName.of("ceiling")), + CONV(FunctionName.of("conv")), + CRC32(FunctionName.of("crc32")), + E(FunctionName.of("e")), + EXP(FunctionName.of("exp")), + EXPM1(FunctionName.of("expm1")), + FLOOR(FunctionName.of("floor")), + LN(FunctionName.of("ln")), + LOG(FunctionName.of("log")), + LOG10(FunctionName.of("log10")), + LOG2(FunctionName.of("log2")), + PI(FunctionName.of("pi")), + POW(FunctionName.of("pow")), + POWER(FunctionName.of("power")), + RAND(FunctionName.of("rand")), + RINT(FunctionName.of("rint")), + ROUND(FunctionName.of("round")), + SIGN(FunctionName.of("sign")), + SIGNUM(FunctionName.of("signum")), + SINH(FunctionName.of("sinh")), + SQRT(FunctionName.of("sqrt")), + CBRT(FunctionName.of("cbrt")), + TRUNCATE(FunctionName.of("truncate")), - ACOS(FunctionName.of("acos")), - ASIN(FunctionName.of("asin")), - ATAN(FunctionName.of("atan")), - ATAN2(FunctionName.of("atan2")), - COS(FunctionName.of("cos")), - COSH(FunctionName.of("cosh")), - COT(FunctionName.of("cot")), - DEGREES(FunctionName.of("degrees")), - RADIANS(FunctionName.of("radians")), - SIN(FunctionName.of("sin")), - TAN(FunctionName.of("tan")), + ACOS(FunctionName.of("acos")), + ASIN(FunctionName.of("asin")), + ATAN(FunctionName.of("atan")), + ATAN2(FunctionName.of("atan2")), + COS(FunctionName.of("cos")), + COSH(FunctionName.of("cosh")), + COT(FunctionName.of("cot")), + DEGREES(FunctionName.of("degrees")), + RADIANS(FunctionName.of("radians")), + SIN(FunctionName.of("sin")), + TAN(FunctionName.of("tan")), - /** - * Date and Time Functions. - */ - ADDDATE(FunctionName.of("adddate")), - ADDTIME(FunctionName.of("addtime")), - CONVERT_TZ(FunctionName.of("convert_tz")), - DATE(FunctionName.of("date")), - DATEDIFF(FunctionName.of("datediff")), - DATETIME(FunctionName.of("datetime")), - DATE_ADD(FunctionName.of("date_add")), - DATE_FORMAT(FunctionName.of("date_format")), - DATE_SUB(FunctionName.of("date_sub")), - DAY(FunctionName.of("day")), - DAYNAME(FunctionName.of("dayname")), - DAYOFMONTH(FunctionName.of("dayofmonth")), - DAY_OF_MONTH(FunctionName.of("day_of_month")), - DAYOFWEEK(FunctionName.of("dayofweek")), - DAYOFYEAR(FunctionName.of("dayofyear")), - DAY_OF_WEEK(FunctionName.of("day_of_week")), - DAY_OF_YEAR(FunctionName.of("day_of_year")), - EXTRACT(FunctionName.of("extract")), - FROM_DAYS(FunctionName.of("from_days")), - FROM_UNIXTIME(FunctionName.of("from_unixtime")), - GET_FORMAT(FunctionName.of("get_format")), - HOUR(FunctionName.of("hour")), - HOUR_OF_DAY(FunctionName.of("hour_of_day")), - LAST_DAY(FunctionName.of("last_day")), - MAKEDATE(FunctionName.of("makedate")), - MAKETIME(FunctionName.of("maketime")), - MICROSECOND(FunctionName.of("microsecond")), - MINUTE(FunctionName.of("minute")), - MINUTE_OF_DAY(FunctionName.of("minute_of_day")), - MINUTE_OF_HOUR(FunctionName.of("minute_of_hour")), - MONTH(FunctionName.of("month")), - MONTH_OF_YEAR(FunctionName.of("month_of_year")), - MONTHNAME(FunctionName.of("monthname")), - PERIOD_ADD(FunctionName.of("period_add")), - PERIOD_DIFF(FunctionName.of("period_diff")), - QUARTER(FunctionName.of("quarter")), - SEC_TO_TIME(FunctionName.of("sec_to_time")), - SECOND(FunctionName.of("second")), - SECOND_OF_MINUTE(FunctionName.of("second_of_minute")), - STR_TO_DATE(FunctionName.of("str_to_date")), - SUBDATE(FunctionName.of("subdate")), - SUBTIME(FunctionName.of("subtime")), - TIME(FunctionName.of("time")), - TIMEDIFF(FunctionName.of("timediff")), - TIME_TO_SEC(FunctionName.of("time_to_sec")), - TIMESTAMP(FunctionName.of("timestamp")), - TIMESTAMPADD(FunctionName.of("timestampadd")), - TIMESTAMPDIFF(FunctionName.of("timestampdiff")), - TIME_FORMAT(FunctionName.of("time_format")), - TO_DAYS(FunctionName.of("to_days")), - TO_SECONDS(FunctionName.of("to_seconds")), - UTC_DATE(FunctionName.of("utc_date")), - UTC_TIME(FunctionName.of("utc_time")), - UTC_TIMESTAMP(FunctionName.of("utc_timestamp")), - UNIX_TIMESTAMP(FunctionName.of("unix_timestamp")), - WEEK(FunctionName.of("week")), - WEEKDAY(FunctionName.of("weekday")), - WEEKOFYEAR(FunctionName.of("weekofyear")), - WEEK_OF_YEAR(FunctionName.of("week_of_year")), - YEAR(FunctionName.of("year")), - YEARWEEK(FunctionName.of("yearweek")), + /** + * Date and Time Functions. + */ + ADDDATE(FunctionName.of("adddate")), + ADDTIME(FunctionName.of("addtime")), + CONVERT_TZ(FunctionName.of("convert_tz")), + DATE(FunctionName.of("date")), + DATEDIFF(FunctionName.of("datediff")), + DATETIME(FunctionName.of("datetime")), + DATE_ADD(FunctionName.of("date_add")), + DATE_FORMAT(FunctionName.of("date_format")), + DATE_SUB(FunctionName.of("date_sub")), + DAY(FunctionName.of("day")), + DAYNAME(FunctionName.of("dayname")), + DAYOFMONTH(FunctionName.of("dayofmonth")), + DAY_OF_MONTH(FunctionName.of("day_of_month")), + DAYOFWEEK(FunctionName.of("dayofweek")), + DAYOFYEAR(FunctionName.of("dayofyear")), + DAY_OF_WEEK(FunctionName.of("day_of_week")), + DAY_OF_YEAR(FunctionName.of("day_of_year")), + EXTRACT(FunctionName.of("extract")), + FROM_DAYS(FunctionName.of("from_days")), + FROM_UNIXTIME(FunctionName.of("from_unixtime")), + GET_FORMAT(FunctionName.of("get_format")), + HOUR(FunctionName.of("hour")), + HOUR_OF_DAY(FunctionName.of("hour_of_day")), + LAST_DAY(FunctionName.of("last_day")), + MAKEDATE(FunctionName.of("makedate")), + MAKETIME(FunctionName.of("maketime")), + MICROSECOND(FunctionName.of("microsecond")), + MINUTE(FunctionName.of("minute")), + MINUTE_OF_DAY(FunctionName.of("minute_of_day")), + MINUTE_OF_HOUR(FunctionName.of("minute_of_hour")), + MONTH(FunctionName.of("month")), + MONTH_OF_YEAR(FunctionName.of("month_of_year")), + MONTHNAME(FunctionName.of("monthname")), + PERIOD_ADD(FunctionName.of("period_add")), + PERIOD_DIFF(FunctionName.of("period_diff")), + QUARTER(FunctionName.of("quarter")), + SEC_TO_TIME(FunctionName.of("sec_to_time")), + SECOND(FunctionName.of("second")), + SECOND_OF_MINUTE(FunctionName.of("second_of_minute")), + STR_TO_DATE(FunctionName.of("str_to_date")), + SUBDATE(FunctionName.of("subdate")), + SUBTIME(FunctionName.of("subtime")), + TIME(FunctionName.of("time")), + TIMEDIFF(FunctionName.of("timediff")), + TIME_TO_SEC(FunctionName.of("time_to_sec")), + TIMESTAMP(FunctionName.of("timestamp")), + TIMESTAMPADD(FunctionName.of("timestampadd")), + TIMESTAMPDIFF(FunctionName.of("timestampdiff")), + TIME_FORMAT(FunctionName.of("time_format")), + TO_DAYS(FunctionName.of("to_days")), + TO_SECONDS(FunctionName.of("to_seconds")), + UTC_DATE(FunctionName.of("utc_date")), + UTC_TIME(FunctionName.of("utc_time")), + UTC_TIMESTAMP(FunctionName.of("utc_timestamp")), + UNIX_TIMESTAMP(FunctionName.of("unix_timestamp")), + WEEK(FunctionName.of("week")), + WEEKDAY(FunctionName.of("weekday")), + WEEKOFYEAR(FunctionName.of("weekofyear")), + WEEK_OF_YEAR(FunctionName.of("week_of_year")), + YEAR(FunctionName.of("year")), + YEARWEEK(FunctionName.of("yearweek")), - // `now`-like functions - NOW(FunctionName.of("now")), - CURDATE(FunctionName.of("curdate")), - CURRENT_DATE(FunctionName.of("current_date")), - CURTIME(FunctionName.of("curtime")), - CURRENT_TIME(FunctionName.of("current_time")), - LOCALTIME(FunctionName.of("localtime")), - CURRENT_TIMESTAMP(FunctionName.of("current_timestamp")), - LOCALTIMESTAMP(FunctionName.of("localtimestamp")), - SYSDATE(FunctionName.of("sysdate")), + // `now`-like functions + NOW(FunctionName.of("now")), + CURDATE(FunctionName.of("curdate")), + CURRENT_DATE(FunctionName.of("current_date")), + CURTIME(FunctionName.of("curtime")), + CURRENT_TIME(FunctionName.of("current_time")), + LOCALTIME(FunctionName.of("localtime")), + CURRENT_TIMESTAMP(FunctionName.of("current_timestamp")), + LOCALTIMESTAMP(FunctionName.of("localtimestamp")), + SYSDATE(FunctionName.of("sysdate")), - /** - * Text Functions. - */ - TOSTRING(FunctionName.of("tostring")), + /** + * Text Functions. + */ + TOSTRING(FunctionName.of("tostring")), - /** - * Arithmetic Operators. - */ - ADD(FunctionName.of("+")), - ADDFUNCTION(FunctionName.of("add")), - DIVIDE(FunctionName.of("/")), - DIVIDEFUNCTION(FunctionName.of("divide")), - MOD(FunctionName.of("mod")), - MODULUS(FunctionName.of("%")), - MODULUSFUNCTION(FunctionName.of("modulus")), - MULTIPLY(FunctionName.of("*")), - MULTIPLYFUNCTION(FunctionName.of("multiply")), - SUBTRACT(FunctionName.of("-")), - SUBTRACTFUNCTION(FunctionName.of("subtract")), + /** + * Arithmetic Operators. + */ + ADD(FunctionName.of("+")), + ADDFUNCTION(FunctionName.of("add")), + DIVIDE(FunctionName.of("/")), + DIVIDEFUNCTION(FunctionName.of("divide")), + MOD(FunctionName.of("mod")), + MODULUS(FunctionName.of("%")), + MODULUSFUNCTION(FunctionName.of("modulus")), + MULTIPLY(FunctionName.of("*")), + MULTIPLYFUNCTION(FunctionName.of("multiply")), + SUBTRACT(FunctionName.of("-")), + SUBTRACTFUNCTION(FunctionName.of("subtract")), - /** - * Boolean Operators. - */ - AND(FunctionName.of("and")), - OR(FunctionName.of("or")), - XOR(FunctionName.of("xor")), - NOT(FunctionName.of("not")), - EQUAL(FunctionName.of("=")), - NOTEQUAL(FunctionName.of("!=")), - LESS(FunctionName.of("<")), - LTE(FunctionName.of("<=")), - GREATER(FunctionName.of(">")), - GTE(FunctionName.of(">=")), - LIKE(FunctionName.of("like")), - NOT_LIKE(FunctionName.of("not like")), + /** + * Boolean Operators. + */ + AND(FunctionName.of("and")), + OR(FunctionName.of("or")), + XOR(FunctionName.of("xor")), + NOT(FunctionName.of("not")), + EQUAL(FunctionName.of("=")), + NOTEQUAL(FunctionName.of("!=")), + LESS(FunctionName.of("<")), + LTE(FunctionName.of("<=")), + GREATER(FunctionName.of(">")), + GTE(FunctionName.of(">=")), + LIKE(FunctionName.of("like")), + NOT_LIKE(FunctionName.of("not like")), - /** - * Aggregation Function. - */ - AVG(FunctionName.of("avg")), - SUM(FunctionName.of("sum")), - COUNT(FunctionName.of("count")), - MIN(FunctionName.of("min")), - MAX(FunctionName.of("max")), - // sample variance - VARSAMP(FunctionName.of("var_samp")), - // population standard variance - VARPOP(FunctionName.of("var_pop")), - // sample standard deviation. - STDDEV_SAMP(FunctionName.of("stddev_samp")), - // population standard deviation. - STDDEV_POP(FunctionName.of("stddev_pop")), - // take top documents from aggregation bucket. - TAKE(FunctionName.of("take")), - // Not always an aggregation query - NESTED(FunctionName.of("nested")), + /** + * Aggregation Function. + */ + AVG(FunctionName.of("avg")), + SUM(FunctionName.of("sum")), + COUNT(FunctionName.of("count")), + MIN(FunctionName.of("min")), + MAX(FunctionName.of("max")), + // sample variance + VARSAMP(FunctionName.of("var_samp")), + // population standard variance + VARPOP(FunctionName.of("var_pop")), + // sample standard deviation. + STDDEV_SAMP(FunctionName.of("stddev_samp")), + // population standard deviation. + STDDEV_POP(FunctionName.of("stddev_pop")), + // take top documents from aggregation bucket. + TAKE(FunctionName.of("take")), + // Not always an aggregation query + NESTED(FunctionName.of("nested")), - /** - * Text Functions. - */ - ASCII(FunctionName.of("ascii")), - CONCAT(FunctionName.of("concat")), - CONCAT_WS(FunctionName.of("concat_ws")), - LEFT(FunctionName.of("left")), - LENGTH(FunctionName.of("length")), - LOCATE(FunctionName.of("locate")), - LOWER(FunctionName.of("lower")), - LTRIM(FunctionName.of("ltrim")), - POSITION(FunctionName.of("position")), - REGEXP(FunctionName.of("regexp")), - REPLACE(FunctionName.of("replace")), - REVERSE(FunctionName.of("reverse")), - RIGHT(FunctionName.of("right")), - RTRIM(FunctionName.of("rtrim")), - STRCMP(FunctionName.of("strcmp")), - SUBSTR(FunctionName.of("substr")), - SUBSTRING(FunctionName.of("substring")), - TRIM(FunctionName.of("trim")), - UPPER(FunctionName.of("upper")), + /** + * Text Functions. + */ + ASCII(FunctionName.of("ascii")), + CONCAT(FunctionName.of("concat")), + CONCAT_WS(FunctionName.of("concat_ws")), + LEFT(FunctionName.of("left")), + LENGTH(FunctionName.of("length")), + LOCATE(FunctionName.of("locate")), + LOWER(FunctionName.of("lower")), + LTRIM(FunctionName.of("ltrim")), + POSITION(FunctionName.of("position")), + REGEXP(FunctionName.of("regexp")), + REPLACE(FunctionName.of("replace")), + REVERSE(FunctionName.of("reverse")), + RIGHT(FunctionName.of("right")), + RTRIM(FunctionName.of("rtrim")), + STRCMP(FunctionName.of("strcmp")), + SUBSTR(FunctionName.of("substr")), + SUBSTRING(FunctionName.of("substring")), + TRIM(FunctionName.of("trim")), + UPPER(FunctionName.of("upper")), - /** - * NULL Test. - */ - IS_NULL(FunctionName.of("is null")), - IS_NOT_NULL(FunctionName.of("is not null")), - IFNULL(FunctionName.of("ifnull")), - IF(FunctionName.of("if")), - NULLIF(FunctionName.of("nullif")), - ISNULL(FunctionName.of("isnull")), + /** + * NULL Test. + */ + IS_NULL(FunctionName.of("is null")), + IS_NOT_NULL(FunctionName.of("is not null")), + IFNULL(FunctionName.of("ifnull")), + IF(FunctionName.of("if")), + NULLIF(FunctionName.of("nullif")), + ISNULL(FunctionName.of("isnull")), - ROW_NUMBER(FunctionName.of("row_number")), - RANK(FunctionName.of("rank")), - DENSE_RANK(FunctionName.of("dense_rank")), + ROW_NUMBER(FunctionName.of("row_number")), + RANK(FunctionName.of("rank")), + DENSE_RANK(FunctionName.of("dense_rank")), - INTERVAL(FunctionName.of("interval")), + INTERVAL(FunctionName.of("interval")), - /** - * Data Type Convert Function. - */ - CAST_TO_STRING(FunctionName.of("cast_to_string")), - CAST_TO_BYTE(FunctionName.of("cast_to_byte")), - CAST_TO_SHORT(FunctionName.of("cast_to_short")), - CAST_TO_INT(FunctionName.of("cast_to_int")), - CAST_TO_LONG(FunctionName.of("cast_to_long")), - CAST_TO_FLOAT(FunctionName.of("cast_to_float")), - CAST_TO_DOUBLE(FunctionName.of("cast_to_double")), - CAST_TO_BOOLEAN(FunctionName.of("cast_to_boolean")), - CAST_TO_DATE(FunctionName.of("cast_to_date")), - CAST_TO_TIME(FunctionName.of("cast_to_time")), - CAST_TO_TIMESTAMP(FunctionName.of("cast_to_timestamp")), - CAST_TO_DATETIME(FunctionName.of("cast_to_datetime")), - TYPEOF(FunctionName.of("typeof")), + /** + * Data Type Convert Function. + */ + CAST_TO_STRING(FunctionName.of("cast_to_string")), + CAST_TO_BYTE(FunctionName.of("cast_to_byte")), + CAST_TO_SHORT(FunctionName.of("cast_to_short")), + CAST_TO_INT(FunctionName.of("cast_to_int")), + CAST_TO_LONG(FunctionName.of("cast_to_long")), + CAST_TO_FLOAT(FunctionName.of("cast_to_float")), + CAST_TO_DOUBLE(FunctionName.of("cast_to_double")), + CAST_TO_BOOLEAN(FunctionName.of("cast_to_boolean")), + CAST_TO_DATE(FunctionName.of("cast_to_date")), + CAST_TO_TIME(FunctionName.of("cast_to_time")), + CAST_TO_TIMESTAMP(FunctionName.of("cast_to_timestamp")), + CAST_TO_DATETIME(FunctionName.of("cast_to_datetime")), + TYPEOF(FunctionName.of("typeof")), - /** - * Relevance Function. - */ - MATCH(FunctionName.of("match")), - SIMPLE_QUERY_STRING(FunctionName.of("simple_query_string")), - MATCH_PHRASE(FunctionName.of("match_phrase")), - MATCHPHRASE(FunctionName.of("matchphrase")), - MATCHPHRASEQUERY(FunctionName.of("matchphrasequery")), - QUERY_STRING(FunctionName.of("query_string")), - MATCH_BOOL_PREFIX(FunctionName.of("match_bool_prefix")), - HIGHLIGHT(FunctionName.of("highlight")), - MATCH_PHRASE_PREFIX(FunctionName.of("match_phrase_prefix")), - SCORE(FunctionName.of("score")), - SCOREQUERY(FunctionName.of("scorequery")), - SCORE_QUERY(FunctionName.of("score_query")), + /** + * Relevance Function. + */ + MATCH(FunctionName.of("match")), + SIMPLE_QUERY_STRING(FunctionName.of("simple_query_string")), + MATCH_PHRASE(FunctionName.of("match_phrase")), + MATCHPHRASE(FunctionName.of("matchphrase")), + MATCHPHRASEQUERY(FunctionName.of("matchphrasequery")), + QUERY_STRING(FunctionName.of("query_string")), + MATCH_BOOL_PREFIX(FunctionName.of("match_bool_prefix")), + HIGHLIGHT(FunctionName.of("highlight")), + MATCH_PHRASE_PREFIX(FunctionName.of("match_phrase_prefix")), + SCORE(FunctionName.of("score")), + SCOREQUERY(FunctionName.of("scorequery")), + SCORE_QUERY(FunctionName.of("score_query")), - /** - * Legacy Relevance Function. - */ - QUERY(FunctionName.of("query")), - MATCH_QUERY(FunctionName.of("match_query")), - MATCHQUERY(FunctionName.of("matchquery")), - MULTI_MATCH(FunctionName.of("multi_match")), - MULTIMATCH(FunctionName.of("multimatch")), - MULTIMATCHQUERY(FunctionName.of("multimatchquery")), - WILDCARDQUERY(FunctionName.of("wildcardquery")), - WILDCARD_QUERY(FunctionName.of("wildcard_query")); + /** + * Legacy Relevance Function. + */ + QUERY(FunctionName.of("query")), + MATCH_QUERY(FunctionName.of("match_query")), + MATCHQUERY(FunctionName.of("matchquery")), + MULTI_MATCH(FunctionName.of("multi_match")), + MULTIMATCH(FunctionName.of("multimatch")), + MULTIMATCHQUERY(FunctionName.of("multimatchquery")), + WILDCARDQUERY(FunctionName.of("wildcardquery")), + WILDCARD_QUERY(FunctionName.of("wildcard_query")); - private final FunctionName name; + private final FunctionName name; - private static final Map ALL_NATIVE_FUNCTIONS; + private static final Map ALL_NATIVE_FUNCTIONS; - static { - ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); - for (BuiltinFunctionName func : BuiltinFunctionName.values()) { - builder.put(func.getName(), func); + static { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + for (BuiltinFunctionName func : BuiltinFunctionName.values()) { + builder.put(func.getName(), func); + } + ALL_NATIVE_FUNCTIONS = builder.build(); } - ALL_NATIVE_FUNCTIONS = builder.build(); - } - private static final Map AGGREGATION_FUNC_MAPPING = - new ImmutableMap.Builder() - .put("max", BuiltinFunctionName.MAX) - .put("min", BuiltinFunctionName.MIN) - .put("avg", BuiltinFunctionName.AVG) - .put("count", BuiltinFunctionName.COUNT) - .put("sum", BuiltinFunctionName.SUM) - .put("var_pop", BuiltinFunctionName.VARPOP) - .put("var_samp", BuiltinFunctionName.VARSAMP) - .put("variance", BuiltinFunctionName.VARPOP) - .put("std", BuiltinFunctionName.STDDEV_POP) - .put("stddev", BuiltinFunctionName.STDDEV_POP) - .put("stddev_pop", BuiltinFunctionName.STDDEV_POP) - .put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP) - .put("take", BuiltinFunctionName.TAKE) - .build(); + private static final Map AGGREGATION_FUNC_MAPPING = new ImmutableMap.Builder() + .put("max", BuiltinFunctionName.MAX) + .put("min", BuiltinFunctionName.MIN) + .put("avg", BuiltinFunctionName.AVG) + .put("count", BuiltinFunctionName.COUNT) + .put("sum", BuiltinFunctionName.SUM) + .put("var_pop", BuiltinFunctionName.VARPOP) + .put("var_samp", BuiltinFunctionName.VARSAMP) + .put("variance", BuiltinFunctionName.VARPOP) + .put("std", BuiltinFunctionName.STDDEV_POP) + .put("stddev", BuiltinFunctionName.STDDEV_POP) + .put("stddev_pop", BuiltinFunctionName.STDDEV_POP) + .put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP) + .put("take", BuiltinFunctionName.TAKE) + .build(); - public static Optional of(String str) { - return Optional.ofNullable(ALL_NATIVE_FUNCTIONS.getOrDefault(FunctionName.of(str), null)); - } + public static Optional of(String str) { + return Optional.ofNullable(ALL_NATIVE_FUNCTIONS.getOrDefault(FunctionName.of(str), null)); + } - public static Optional ofAggregation(String functionName) { - return Optional.ofNullable( - AGGREGATION_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null)); - } + public static Optional ofAggregation(String functionName) { + return Optional.ofNullable(AGGREGATION_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null)); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index 0eb11a9280..4b8c4e1dc9 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -45,173 +45,159 @@ */ public class BuiltinFunctionRepository { - private final Map functionResolverMap; - - /** The singleton instance. */ - private static BuiltinFunctionRepository instance; - - /** - * Construct a function repository with the given function registered. This is only used in test. - * - * @param functionResolverMap function supported - */ - @VisibleForTesting - BuiltinFunctionRepository(Map functionResolverMap) { - this.functionResolverMap = functionResolverMap; - } - - /** - * Get singleton instance of the function repository. Initialize it with all built-in functions - * for the first time in synchronized way. - * - * @return singleton instance - */ - public static synchronized BuiltinFunctionRepository getInstance() { - if (instance == null) { - instance = new BuiltinFunctionRepository(new HashMap<>()); - - // Register all built-in functions - ArithmeticFunction.register(instance); - BinaryPredicateOperator.register(instance); - MathematicalFunction.register(instance); - UnaryPredicateOperator.register(instance); - AggregatorFunction.register(instance); - DateTimeFunction.register(instance); - IntervalClause.register(instance); - WindowFunctions.register(instance); - TextFunction.register(instance); - TypeCastOperator.register(instance); - SystemFunctions.register(instance); - OpenSearchFunctions.register(instance); + private final Map functionResolverMap; + + /** The singleton instance. */ + private static BuiltinFunctionRepository instance; + + /** + * Construct a function repository with the given function registered. This is only used in test. + * + * @param functionResolverMap function supported + */ + @VisibleForTesting + BuiltinFunctionRepository(Map functionResolverMap) { + this.functionResolverMap = functionResolverMap; } - return instance; - } - - /** - * Register {@link DefaultFunctionResolver} to the Builtin Function Repository. - * - * @param resolver {@link DefaultFunctionResolver} to be registered - */ - public void register(FunctionResolver resolver) { - functionResolverMap.put(resolver.getFunctionName(), resolver); - } - - /** - * Compile FunctionExpression using core function resolver. - * - */ - public FunctionImplementation compile(FunctionProperties functionProperties, - FunctionName functionName, List expressions) { - return compile(functionProperties, Collections.emptyList(), functionName, expressions); - } - - - /** - * Compile FunctionExpression within {@link StorageEngine} provided {@link FunctionResolver}. - */ - public FunctionImplementation compile(FunctionProperties functionProperties, - Collection dataSourceFunctionResolver, - FunctionName functionName, - List expressions) { - FunctionBuilder resolvedFunctionBuilder = - resolve( + + /** + * Get singleton instance of the function repository. Initialize it with all built-in functions + * for the first time in synchronized way. + * + * @return singleton instance + */ + public static synchronized BuiltinFunctionRepository getInstance() { + if (instance == null) { + instance = new BuiltinFunctionRepository(new HashMap<>()); + + // Register all built-in functions + ArithmeticFunction.register(instance); + BinaryPredicateOperator.register(instance); + MathematicalFunction.register(instance); + UnaryPredicateOperator.register(instance); + AggregatorFunction.register(instance); + DateTimeFunction.register(instance); + IntervalClause.register(instance); + WindowFunctions.register(instance); + TextFunction.register(instance); + TypeCastOperator.register(instance); + SystemFunctions.register(instance); + OpenSearchFunctions.register(instance); + } + return instance; + } + + /** + * Register {@link DefaultFunctionResolver} to the Builtin Function Repository. + * + * @param resolver {@link DefaultFunctionResolver} to be registered + */ + public void register(FunctionResolver resolver) { + functionResolverMap.put(resolver.getFunctionName(), resolver); + } + + /** + * Compile FunctionExpression using core function resolver. + * + */ + public FunctionImplementation compile(FunctionProperties functionProperties, FunctionName functionName, List expressions) { + return compile(functionProperties, Collections.emptyList(), functionName, expressions); + } + + /** + * Compile FunctionExpression within {@link StorageEngine} provided {@link FunctionResolver}. + */ + public FunctionImplementation compile( + FunctionProperties functionProperties, + Collection dataSourceFunctionResolver, + FunctionName functionName, + List expressions + ) { + FunctionBuilder resolvedFunctionBuilder = resolve( dataSourceFunctionResolver, - new FunctionSignature( - functionName, - expressions.stream().map(Expression::type).collect(Collectors.toList()))); - return resolvedFunctionBuilder.apply(functionProperties, expressions); - } - - /** - * Resolve the {@link FunctionBuilder} in repository under a list of namespaces. Returns the First - * FunctionBuilder found. So list of namespaces is also the priority of namespaces. - * - * @param functionSignature {@link FunctionSignature} functionsignature. - * @return Original function builder if it's a cast function or all arguments have expected types - * or otherwise wrap its arguments by cast function as needed. - */ - @VisibleForTesting - public FunctionBuilder resolve( - Collection dataSourceFunctionResolver, - FunctionSignature functionSignature) { - Map dataSourceFunctionMap = dataSourceFunctionResolver.stream() - .collect(Collectors.toMap(FunctionResolver::getFunctionName, t -> t)); - - // first, resolve in datasource provide function resolver. - // second, resolve in builtin function resolver. - return resolve(functionSignature, dataSourceFunctionMap) - .or(() -> resolve(functionSignature, functionResolverMap)) - .orElseThrow( - () -> - new ExpressionEvaluationException( - String.format( - "unsupported function name: %s", functionSignature.getFunctionName()))); - } - - private Optional resolve( - FunctionSignature functionSignature, - Map functionResolverMap) { - FunctionName functionName = functionSignature.getFunctionName(); - if (functionResolverMap.containsKey(functionName)) { - Pair resolvedSignature = - functionResolverMap.get(functionName).resolve(functionSignature); - - List sourceTypes = functionSignature.getParamTypeList(); - List targetTypes = resolvedSignature.getKey().getParamTypeList(); - FunctionBuilder funcBuilder = resolvedSignature.getValue(); - if (isCastFunction(functionName) - || FunctionSignature.isVarArgFunction(targetTypes) - || sourceTypes.equals(targetTypes)) { - return Optional.of(funcBuilder); - } - return Optional.of(castArguments(sourceTypes, targetTypes, funcBuilder)); - } else { - return Optional.empty(); + new FunctionSignature(functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())) + ); + return resolvedFunctionBuilder.apply(functionProperties, expressions); + } + + /** + * Resolve the {@link FunctionBuilder} in repository under a list of namespaces. Returns the First + * FunctionBuilder found. So list of namespaces is also the priority of namespaces. + * + * @param functionSignature {@link FunctionSignature} functionsignature. + * @return Original function builder if it's a cast function or all arguments have expected types + * or otherwise wrap its arguments by cast function as needed. + */ + @VisibleForTesting + public FunctionBuilder resolve(Collection dataSourceFunctionResolver, FunctionSignature functionSignature) { + Map dataSourceFunctionMap = dataSourceFunctionResolver.stream() + .collect(Collectors.toMap(FunctionResolver::getFunctionName, t -> t)); + + // first, resolve in datasource provide function resolver. + // second, resolve in builtin function resolver. + return resolve(functionSignature, dataSourceFunctionMap).or(() -> resolve(functionSignature, functionResolverMap)) + .orElseThrow( + () -> new ExpressionEvaluationException(String.format("unsupported function name: %s", functionSignature.getFunctionName())) + ); } - } - - /** - * Wrap resolved function builder's arguments by cast function to cast input expression value - * to value of target type at runtime. For example, suppose unresolved signature is - * equal(BOOL,STRING) and its resolved function builder is F with signature equal(BOOL,BOOL). - * In this case, wrap F and return equal(BOOL, cast_to_bool(STRING)). - */ - private FunctionBuilder castArguments(List sourceTypes, - List targetTypes, - FunctionBuilder funcBuilder) { - return (fp, arguments) -> { - List argsCasted = new ArrayList<>(); - for (int i = 0; i < arguments.size(); i++) { - Expression arg = arguments.get(i); - ExprType sourceType = sourceTypes.get(i); - ExprType targetType = targetTypes.get(i); - - if (isCastRequired(sourceType, targetType)) { - argsCasted.add(cast(arg, targetType).apply(fp)); + + private Optional resolve( + FunctionSignature functionSignature, + Map functionResolverMap + ) { + FunctionName functionName = functionSignature.getFunctionName(); + if (functionResolverMap.containsKey(functionName)) { + Pair resolvedSignature = functionResolverMap.get(functionName).resolve(functionSignature); + + List sourceTypes = functionSignature.getParamTypeList(); + List targetTypes = resolvedSignature.getKey().getParamTypeList(); + FunctionBuilder funcBuilder = resolvedSignature.getValue(); + if (isCastFunction(functionName) || FunctionSignature.isVarArgFunction(targetTypes) || sourceTypes.equals(targetTypes)) { + return Optional.of(funcBuilder); + } + return Optional.of(castArguments(sourceTypes, targetTypes, funcBuilder)); } else { - argsCasted.add(arg); + return Optional.empty(); } - } - return funcBuilder.apply(fp, argsCasted); - }; - } - - private boolean isCastRequired(ExprType sourceType, ExprType targetType) { - // TODO: Remove this special case after fixing all failed UTs - if (ExprCoreType.numberTypes().contains(sourceType) - && ExprCoreType.numberTypes().contains(targetType)) { - return false; } - return sourceType.shouldCast(targetType); - } - - private Function cast(Expression arg, ExprType targetType) { - FunctionName castFunctionName = getCastFunctionName(targetType); - if (castFunctionName == null) { - throw new ExpressionEvaluationException(StringUtils.format( - "Type conversion to type %s is not supported", targetType)); + + /** + * Wrap resolved function builder's arguments by cast function to cast input expression value + * to value of target type at runtime. For example, suppose unresolved signature is + * equal(BOOL,STRING) and its resolved function builder is F with signature equal(BOOL,BOOL). + * In this case, wrap F and return equal(BOOL, cast_to_bool(STRING)). + */ + private FunctionBuilder castArguments(List sourceTypes, List targetTypes, FunctionBuilder funcBuilder) { + return (fp, arguments) -> { + List argsCasted = new ArrayList<>(); + for (int i = 0; i < arguments.size(); i++) { + Expression arg = arguments.get(i); + ExprType sourceType = sourceTypes.get(i); + ExprType targetType = targetTypes.get(i); + + if (isCastRequired(sourceType, targetType)) { + argsCasted.add(cast(arg, targetType).apply(fp)); + } else { + argsCasted.add(arg); + } + } + return funcBuilder.apply(fp, argsCasted); + }; + } + + private boolean isCastRequired(ExprType sourceType, ExprType targetType) { + // TODO: Remove this special case after fixing all failed UTs + if (ExprCoreType.numberTypes().contains(sourceType) && ExprCoreType.numberTypes().contains(targetType)) { + return false; + } + return sourceType.shouldCast(targetType); + } + + private Function cast(Expression arg, ExprType targetType) { + FunctionName castFunctionName = getCastFunctionName(targetType); + if (castFunctionName == null) { + throw new ExpressionEvaluationException(StringUtils.format("Type conversion to type %s is not supported", targetType)); + } + return functionProperties -> (Expression) compile(functionProperties, castFunctionName, List.of(arg)); } - return functionProperties -> (Expression) compile(functionProperties, - castFunctionName, List.of(arg)); - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java index a28fa7e0ad..bcbef76896 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java @@ -26,52 +26,50 @@ @Builder @RequiredArgsConstructor public class DefaultFunctionResolver implements FunctionResolver { - @Getter - private final FunctionName functionName; - @Singular("functionBundle") - private final Map functionBundle; + @Getter + private final FunctionName functionName; + @Singular("functionBundle") + private final Map functionBundle; - /** - * Resolve the {@link FunctionBuilder} by using input {@link FunctionSignature}. - * If the {@link FunctionBuilder} exactly match the input {@link FunctionSignature}, return it. - * If applying the widening rule, found the most match one, return it. - * If nothing found, throw {@link ExpressionEvaluationException} - * - * @return function signature and its builder - */ - @Override - public Pair resolve(FunctionSignature unresolvedSignature) { - PriorityQueue> functionMatchQueue = new PriorityQueue<>( - Map.Entry.comparingByKey()); + /** + * Resolve the {@link FunctionBuilder} by using input {@link FunctionSignature}. + * If the {@link FunctionBuilder} exactly match the input {@link FunctionSignature}, return it. + * If applying the widening rule, found the most match one, return it. + * If nothing found, throw {@link ExpressionEvaluationException} + * + * @return function signature and its builder + */ + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + PriorityQueue> functionMatchQueue = new PriorityQueue<>(Map.Entry.comparingByKey()); - for (FunctionSignature functionSignature : functionBundle.keySet()) { - functionMatchQueue.add( - new AbstractMap.SimpleEntry<>(unresolvedSignature.match(functionSignature), - functionSignature)); - } - Map.Entry bestMatchEntry = functionMatchQueue.peek(); - if (FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList()) - && (unresolvedSignature.getParamTypeList().isEmpty() - || unresolvedSignature.getParamTypeList().size() > 9)) { - throw new ExpressionEvaluationException( - String.format("%s function expected 1-9 arguments, but got %d", - functionName, unresolvedSignature.getParamTypeList().size())); - } - if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey()) + for (FunctionSignature functionSignature : functionBundle.keySet()) { + functionMatchQueue.add(new AbstractMap.SimpleEntry<>(unresolvedSignature.match(functionSignature), functionSignature)); + } + Map.Entry bestMatchEntry = functionMatchQueue.peek(); + if (FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList()) + && (unresolvedSignature.getParamTypeList().isEmpty() || unresolvedSignature.getParamTypeList().size() > 9)) { + throw new ExpressionEvaluationException( + String.format("%s function expected 1-9 arguments, but got %d", functionName, unresolvedSignature.getParamTypeList().size()) + ); + } + if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey()) && !FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())) { - throw new ExpressionEvaluationException( - String.format("%s function expected %s, but get %s", functionName, - formatFunctions(functionBundle.keySet()), - unresolvedSignature.formatTypes() - )); - } else { - FunctionSignature resolvedSignature = bestMatchEntry.getValue(); - return Pair.of(resolvedSignature, functionBundle.get(resolvedSignature)); + throw new ExpressionEvaluationException( + String.format( + "%s function expected %s, but get %s", + functionName, + formatFunctions(functionBundle.keySet()), + unresolvedSignature.formatTypes() + ) + ); + } else { + FunctionSignature resolvedSignature = bestMatchEntry.getValue(); + return Pair.of(resolvedSignature, functionBundle.get(resolvedSignature)); + } } - } - private String formatFunctions(Set functionSignatures) { - return functionSignatures.stream().map(FunctionSignature::formatTypes) - .collect(Collectors.joining(",", "{", "}")); - } + private String formatFunctions(Set functionSignatures) { + return functionSignatures.stream().map(FunctionSignature::formatTypes).collect(Collectors.joining(",", "{", "}")); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionBuilder.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionBuilder.java index b6e32a1d27..33c7f443f4 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionBuilder.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionBuilder.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.function; import java.util.List; @@ -15,12 +14,12 @@ */ public interface FunctionBuilder { - /** - * Create {@link FunctionImplementation} from input {@link Expression} list. - * - * @param functionProperties context for function execution. - * @param arguments {@link Expression} list. - * @return {@link FunctionImplementation} - */ - FunctionImplementation apply(FunctionProperties functionProperties, List arguments); + /** + * Create {@link FunctionImplementation} from input {@link Expression} list. + * + * @param functionProperties context for function execution. + * @param arguments {@link Expression} list. + * @return {@link FunctionImplementation} + */ + FunctionImplementation apply(FunctionProperties functionProperties, List arguments); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java index c57d96caea..9431c075bb 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.function; import java.util.Arrays; @@ -26,473 +25,469 @@ */ @UtilityClass public class FunctionDSL { - /** - * Define overloaded function with implementation. - * - * @param functionName function name. - * @param functions a list of function implementation. - * @return FunctionResolver. - */ - public static DefaultFunctionResolver define(FunctionName functionName, - SerializableFunction>... functions) { - return define(functionName, List.of(functions)); - } - - /** - * Define overloaded function with implementation. - * - * @param functionName function name. - * @param functions a list of function implementation. - * @return FunctionResolver. - */ - public static DefaultFunctionResolver define(FunctionName functionName, List< - SerializableFunction>> functions) { - - DefaultFunctionResolverBuilder builder = DefaultFunctionResolver.builder(); - builder.functionName(functionName); - for (Function> func : functions) { - Pair functionBuilder = func.apply(functionName); - builder.functionBundle(functionBuilder.getKey(), functionBuilder.getValue()); + /** + * Define overloaded function with implementation. + * + * @param functionName function name. + * @param functions a list of function implementation. + * @return FunctionResolver. + */ + public static DefaultFunctionResolver define( + FunctionName functionName, + SerializableFunction>... functions + ) { + return define(functionName, List.of(functions)); + } + + /** + * Define overloaded function with implementation. + * + * @param functionName function name. + * @param functions a list of function implementation. + * @return FunctionResolver. + */ + public static DefaultFunctionResolver define( + FunctionName functionName, + List>> functions + ) { + + DefaultFunctionResolverBuilder builder = DefaultFunctionResolver.builder(); + builder.functionName(functionName); + for (Function> func : functions) { + Pair functionBuilder = func.apply(functionName); + builder.functionBundle(functionBuilder.getKey(), functionBuilder.getValue()); + } + return builder.build(); } - return builder.build(); - } - - - /** - * Implementation of no args function that uses FunctionProperties. - * - * @param function {@link ExprValue} based no args function. - * @param returnType function return type. - * @return no args function implementation. - */ - public static SerializableFunction> - implWithProperties(SerializableFunction function, - ExprType returnType) { - return functionName -> { - FunctionSignature functionSignature = - new FunctionSignature(functionName, Collections.emptyList()); - FunctionBuilder functionBuilder = - (functionProperties, arguments) -> - new FunctionExpression(functionName, Collections.emptyList()) { + + /** + * Implementation of no args function that uses FunctionProperties. + * + * @param function {@link ExprValue} based no args function. + * @param returnType function return type. + * @return no args function implementation. + */ + public static SerializableFunction> implWithProperties( + SerializableFunction function, + ExprType returnType + ) { + return functionName -> { + FunctionSignature functionSignature = new FunctionSignature(functionName, Collections.emptyList()); + FunctionBuilder functionBuilder = (functionProperties, arguments) -> new FunctionExpression( + functionName, + Collections.emptyList() + ) { @Override public ExprValue valueOf(Environment valueEnv) { - return function.apply(functionProperties); + return function.apply(functionProperties); } @Override public ExprType type() { - return returnType; + return returnType; } @Override public String toString() { - return String.format("%s()", functionName); + return String.format("%s()", functionName); } - }; - return Pair.of(functionSignature, functionBuilder); - }; - } - - /** - * Implementation of a function that takes one argument, returns a value, and - * requires FunctionProperties to complete. - * - * @param function {@link ExprValue} based unary function. - * @param returnType return type. - * @param argsType argument type. - * @return Unary Function Implementation. - */ - public static SerializableFunction> - implWithProperties( - SerializableBiFunction function, - ExprType returnType, - ExprType argsType) { - - return functionName -> { - FunctionSignature functionSignature = - new FunctionSignature(functionName, Collections.singletonList(argsType)); - FunctionBuilder functionBuilder = - (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { - @Override - public ExprValue valueOf(Environment valueEnv) { - ExprValue value = arguments.get(0).valueOf(valueEnv); - return function.apply(functionProperties, value); - } + }; + return Pair.of(functionSignature, functionBuilder); + }; + } - @Override - public ExprType type() { - return returnType; - } + /** + * Implementation of a function that takes one argument, returns a value, and + * requires FunctionProperties to complete. + * + * @param function {@link ExprValue} based unary function. + * @param returnType return type. + * @param argsType argument type. + * @return Unary Function Implementation. + */ + public static SerializableFunction> implWithProperties( + SerializableBiFunction function, + ExprType returnType, + ExprType argsType + ) { + + return functionName -> { + FunctionSignature functionSignature = new FunctionSignature(functionName, Collections.singletonList(argsType)); + FunctionBuilder functionBuilder = (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + ExprValue value = arguments.get(0).valueOf(valueEnv); + return function.apply(functionProperties, value); + } - @Override - public String toString() { - return String.format("%s(%s)", functionName, - arguments.stream() - .map(Object::toString) - .collect(Collectors.joining(", "))); - } - }; - return Pair.of(functionSignature, functionBuilder); - }; - } - - /** - * Implementation of a function that takes two arguments, returns a value, and - * requires FunctionProperties to complete. - * - * @param function {@link ExprValue} based Binary function. - * @param returnType return type. - * @param args1Type first argument type. - * @param args2Type second argument type. - * @return Binary Function Implementation. - */ - public static SerializableFunction> - implWithProperties( - SerializableTriFunction function, - ExprType returnType, - ExprType args1Type, - ExprType args2Type) { - - return functionName -> { - FunctionSignature functionSignature = - new FunctionSignature(functionName, Arrays.asList(args1Type, args2Type)); - FunctionBuilder functionBuilder = - (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { - @Override - public ExprValue valueOf(Environment valueEnv) { - ExprValue arg1 = arguments.get(0).valueOf(valueEnv); - ExprValue arg2 = arguments.get(1).valueOf(valueEnv); - return function.apply(functionProperties, arg1, arg2); - } + @Override + public ExprType type() { + return returnType; + } - @Override - public ExprType type() { - return returnType; - } + @Override + public String toString() { + return String.format( + "%s(%s)", + functionName, + arguments.stream().map(Object::toString).collect(Collectors.joining(", ")) + ); + } + }; + return Pair.of(functionSignature, functionBuilder); + }; + } - @Override - public String toString() { - return String.format("%s(%s)", functionName, - arguments.stream() - .map(Object::toString) - .collect(Collectors.joining(", "))); - } - }; - return Pair.of(functionSignature, functionBuilder); - }; - } - - /** - * Implementation of a function that takes three arguments, returns a value, and - * requires FunctionProperties to complete. - * - * @param function {@link ExprValue} based Binary function. - * @param returnType return type. - * @param args1Type first argument type. - * @param args2Type second argument type. - * @param args3Type third argument type. - * @return Binary Function Implementation. - */ - public static SerializableFunction> - implWithProperties( - SerializableQuadFunction< - FunctionProperties, - ExprValue, - ExprValue, - ExprValue, - ExprValue> function, - ExprType returnType, - ExprType args1Type, - ExprType args2Type, - ExprType args3Type) { - - return functionName -> { - FunctionSignature functionSignature = - new FunctionSignature(functionName, Arrays.asList(args1Type, args2Type, args3Type)); - FunctionBuilder functionBuilder = - (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { - @Override - public ExprValue valueOf(Environment valueEnv) { - ExprValue arg1 = arguments.get(0).valueOf(valueEnv); - ExprValue arg2 = arguments.get(1).valueOf(valueEnv); - ExprValue arg3 = arguments.get(2).valueOf(valueEnv); - return function.apply(functionProperties, arg1, arg2, arg3); - } + /** + * Implementation of a function that takes two arguments, returns a value, and + * requires FunctionProperties to complete. + * + * @param function {@link ExprValue} based Binary function. + * @param returnType return type. + * @param args1Type first argument type. + * @param args2Type second argument type. + * @return Binary Function Implementation. + */ + public static SerializableFunction> implWithProperties( + SerializableTriFunction function, + ExprType returnType, + ExprType args1Type, + ExprType args2Type + ) { + + return functionName -> { + FunctionSignature functionSignature = new FunctionSignature(functionName, Arrays.asList(args1Type, args2Type)); + FunctionBuilder functionBuilder = (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + ExprValue arg1 = arguments.get(0).valueOf(valueEnv); + ExprValue arg2 = arguments.get(1).valueOf(valueEnv); + return function.apply(functionProperties, arg1, arg2); + } - @Override - public ExprType type() { - return returnType; - } + @Override + public ExprType type() { + return returnType; + } - @Override - public String toString() { - return String.format("%s(%s)", functionName, - arguments.stream() - .map(Object::toString) - .collect(Collectors.joining(", "))); + @Override + public String toString() { + return String.format( + "%s(%s)", + functionName, + arguments.stream().map(Object::toString).collect(Collectors.joining(", ")) + ); + } + }; + return Pair.of(functionSignature, functionBuilder); + }; + } + + /** + * Implementation of a function that takes three arguments, returns a value, and + * requires FunctionProperties to complete. + * + * @param function {@link ExprValue} based Binary function. + * @param returnType return type. + * @param args1Type first argument type. + * @param args2Type second argument type. + * @param args3Type third argument type. + * @return Binary Function Implementation. + */ + public static SerializableFunction> implWithProperties( + SerializableQuadFunction function, + ExprType returnType, + ExprType args1Type, + ExprType args2Type, + ExprType args3Type + ) { + + return functionName -> { + FunctionSignature functionSignature = new FunctionSignature(functionName, Arrays.asList(args1Type, args2Type, args3Type)); + FunctionBuilder functionBuilder = (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + ExprValue arg1 = arguments.get(0).valueOf(valueEnv); + ExprValue arg2 = arguments.get(1).valueOf(valueEnv); + ExprValue arg3 = arguments.get(2).valueOf(valueEnv); + return function.apply(functionProperties, arg1, arg2, arg3); + } + + @Override + public ExprType type() { + return returnType; + } + + @Override + public String toString() { + return String.format( + "%s(%s)", + functionName, + arguments.stream().map(Object::toString).collect(Collectors.joining(", ")) + ); + } + }; + return Pair.of(functionSignature, functionBuilder); + }; + } + + /** + * No Arg Function Implementation. + * + * @param function {@link ExprValue} based unary function. + * @param returnType return type. + * @return Unary Function Implementation. + */ + public static SerializableFunction> impl( + SerializableNoArgFunction function, + ExprType returnType + ) { + return implWithProperties(fp -> function.get(), returnType); + } + + /** + * Unary Function Implementation. + * + * @param function {@link ExprValue} based unary function. + * @param returnType return type. + * @param argsType argument type. + * @return Unary Function Implementation. + */ + public static SerializableFunction> impl( + SerializableFunction function, + ExprType returnType, + ExprType argsType + ) { + + return implWithProperties((fp, arg) -> function.apply(arg), returnType, argsType); + } + + /** + * Binary Function Implementation. + * + * @param function {@link ExprValue} based unary function. + * @param returnType return type. + * @param args1Type argument type. + * @param args2Type argument type. + * @return Binary Function Implementation. + */ + public static SerializableFunction> impl( + SerializableBiFunction function, + ExprType returnType, + ExprType args1Type, + ExprType args2Type + ) { + + return implWithProperties((fp, arg1, arg2) -> function.apply(arg1, arg2), returnType, args1Type, args2Type); + } + + /** + * Triple Function Implementation. + * + * @param function {@link ExprValue} based unary function. + * @param returnType return type. + * @param args1Type argument type. + * @param args2Type argument type. + * @return Binary Function Implementation. + */ + public static SerializableFunction> impl( + SerializableTriFunction function, + ExprType returnType, + ExprType args1Type, + ExprType args2Type, + ExprType args3Type + ) { + + return functionName -> { + FunctionSignature functionSignature = new FunctionSignature(functionName, Arrays.asList(args1Type, args2Type, args3Type)); + FunctionBuilder functionBuilder = (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + ExprValue arg1 = arguments.get(0).valueOf(valueEnv); + ExprValue arg2 = arguments.get(1).valueOf(valueEnv); + ExprValue arg3 = arguments.get(2).valueOf(valueEnv); + return function.apply(arg1, arg2, arg3); + } + + @Override + public ExprType type() { + return returnType; + } + + @Override + public String toString() { + return String.format( + "%s(%s, %s, %s)", + functionName, + arguments.get(0).toString(), + arguments.get(1).toString(), + arguments.get(2).toString() + ); + } + }; + return Pair.of(functionSignature, functionBuilder); + }; + } + + /** + * Quadruple Function Implementation. + * + * @param function {@link ExprValue} based unary function. + * @param returnType return type. + * @param args1Type argument type. + * @param args2Type argument type. + * @param args3Type argument type. + * @return Quadruple Function Implementation. + */ + public static SerializableFunction> impl( + SerializableQuadFunction function, + ExprType returnType, + ExprType args1Type, + ExprType args2Type, + ExprType args3Type, + ExprType args4Type + ) { + + return functionName -> { + FunctionSignature functionSignature = new FunctionSignature( + functionName, + Arrays.asList(args1Type, args2Type, args3Type, args4Type) + ); + FunctionBuilder functionBuilder = (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + ExprValue arg1 = arguments.get(0).valueOf(valueEnv); + ExprValue arg2 = arguments.get(1).valueOf(valueEnv); + ExprValue arg3 = arguments.get(2).valueOf(valueEnv); + ExprValue arg4 = arguments.get(3).valueOf(valueEnv); + return function.apply(arg1, arg2, arg3, arg4); + } + + @Override + public ExprType type() { + return returnType; + } + + @Override + public String toString() { + return String.format( + "%s(%s, %s, %s, %s)", + functionName, + arguments.get(0).toString(), + arguments.get(1).toString(), + arguments.get(2).toString(), + arguments.get(3).toString() + ); + } + }; + return Pair.of(functionSignature, functionBuilder); + }; + } + + /** + * Wrapper the unary ExprValue function with default NULL and MISSING handling. + */ + public static SerializableFunction nullMissingHandling(SerializableFunction function) { + return value -> { + if (value.isMissing()) { + return ExprValueUtils.missingValue(); + } else if (value.isNull()) { + return ExprValueUtils.nullValue(); + } else { + return function.apply(value); } - }; - return Pair.of(functionSignature, functionBuilder); - }; - } - - /** - * No Arg Function Implementation. - * - * @param function {@link ExprValue} based unary function. - * @param returnType return type. - * @return Unary Function Implementation. - */ - public static SerializableFunction> impl( - SerializableNoArgFunction function, - ExprType returnType) { - return implWithProperties(fp -> function.get(), returnType); - } - - /** - * Unary Function Implementation. - * - * @param function {@link ExprValue} based unary function. - * @param returnType return type. - * @param argsType argument type. - * @return Unary Function Implementation. - */ - public static SerializableFunction> impl( - SerializableFunction function, - ExprType returnType, - ExprType argsType) { - - return implWithProperties((fp, arg) -> function.apply(arg), returnType, argsType); - } - - /** - * Binary Function Implementation. - * - * @param function {@link ExprValue} based unary function. - * @param returnType return type. - * @param args1Type argument type. - * @param args2Type argument type. - * @return Binary Function Implementation. - */ - public static SerializableFunction> impl( - SerializableBiFunction function, - ExprType returnType, - ExprType args1Type, - ExprType args2Type) { - - return implWithProperties((fp, arg1, arg2) -> - function.apply(arg1, arg2), returnType, args1Type, args2Type); - } - - /** - * Triple Function Implementation. - * - * @param function {@link ExprValue} based unary function. - * @param returnType return type. - * @param args1Type argument type. - * @param args2Type argument type. - * @return Binary Function Implementation. - */ - public static SerializableFunction> impl( - SerializableTriFunction function, - ExprType returnType, - ExprType args1Type, - ExprType args2Type, - ExprType args3Type) { - - return functionName -> { - FunctionSignature functionSignature = - new FunctionSignature(functionName, Arrays.asList(args1Type, args2Type, args3Type)); - FunctionBuilder functionBuilder = - (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { - @Override - public ExprValue valueOf(Environment valueEnv) { - ExprValue arg1 = arguments.get(0).valueOf(valueEnv); - ExprValue arg2 = arguments.get(1).valueOf(valueEnv); - ExprValue arg3 = arguments.get(2).valueOf(valueEnv); - return function.apply(arg1, arg2, arg3); + }; + } + + /** + * Wrapper the binary ExprValue function with default NULL and MISSING handling. + */ + public static SerializableBiFunction nullMissingHandling( + SerializableBiFunction function + ) { + return (v1, v2) -> { + if (v1.isMissing() || v2.isMissing()) { + return ExprValueUtils.missingValue(); + } else if (v1.isNull() || v2.isNull()) { + return ExprValueUtils.nullValue(); + } else { + return function.apply(v1, v2); } + }; + } - @Override - public ExprType type() { - return returnType; + /** + * Wrapper the triple ExprValue function with default NULL and MISSING handling. + */ + public SerializableTriFunction nullMissingHandling( + SerializableTriFunction function + ) { + return (v1, v2, v3) -> { + if (v1.isMissing() || v2.isMissing() || v3.isMissing()) { + return ExprValueUtils.missingValue(); + } else if (v1.isNull() || v2.isNull() || v3.isNull()) { + return ExprValueUtils.nullValue(); + } else { + return function.apply(v1, v2, v3); } + }; + } - @Override - public String toString() { - return String.format("%s(%s, %s, %s)", functionName, arguments.get(0).toString(), - arguments.get(1).toString(), arguments.get(2).toString()); + /** + * Wrapper the unary ExprValue function that is aware of FunctionProperties, + * with default NULL and MISSING handling. + */ + public static SerializableBiFunction nullMissingHandlingWithProperties( + SerializableBiFunction implementation + ) { + return (functionProperties, v1) -> { + if (v1.isMissing()) { + return ExprValueUtils.missingValue(); + } else if (v1.isNull()) { + return ExprValueUtils.nullValue(); + } else { + return implementation.apply(functionProperties, v1); } - }; - return Pair.of(functionSignature, functionBuilder); - }; - } - - /** - * Quadruple Function Implementation. - * - * @param function {@link ExprValue} based unary function. - * @param returnType return type. - * @param args1Type argument type. - * @param args2Type argument type. - * @param args3Type argument type. - * @return Quadruple Function Implementation. - */ - public static SerializableFunction> impl( - SerializableQuadFunction function, - ExprType returnType, - ExprType args1Type, - ExprType args2Type, - ExprType args3Type, - ExprType args4Type) { - - return functionName -> { - FunctionSignature functionSignature = - new FunctionSignature(functionName, Arrays.asList( - args1Type, - args2Type, - args3Type, - args4Type)); - FunctionBuilder functionBuilder = - (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { - @Override - public ExprValue valueOf(Environment valueEnv) { - ExprValue arg1 = arguments.get(0).valueOf(valueEnv); - ExprValue arg2 = arguments.get(1).valueOf(valueEnv); - ExprValue arg3 = arguments.get(2).valueOf(valueEnv); - ExprValue arg4 = arguments.get(3).valueOf(valueEnv); - return function.apply(arg1, arg2, arg3, arg4); + }; + } + + /** + * Wrapper for the ExprValue function that takes 2 arguments and is aware of FunctionProperties, + * with default NULL and MISSING handling. + */ + public static SerializableTriFunction nullMissingHandlingWithProperties( + SerializableTriFunction implementation + ) { + return (functionProperties, v1, v2) -> { + if (v1.isMissing() || v2.isMissing()) { + return ExprValueUtils.missingValue(); + } else if (v1.isNull() || v2.isNull()) { + return ExprValueUtils.nullValue(); + } else { + return implementation.apply(functionProperties, v1, v2); } + }; + } - @Override - public ExprType type() { - return returnType; + /** + * Wrapper for the ExprValue function that takes 3 arguments and is aware of FunctionProperties, + * with default NULL and MISSING handling. + */ + public static + SerializableQuadFunction + nullMissingHandlingWithProperties( + SerializableQuadFunction implementation + ) { + return (functionProperties, v1, v2, v3) -> { + if (v1.isMissing() || v2.isMissing() || v3.isMissing()) { + return ExprValueUtils.missingValue(); } - @Override - public String toString() { - return String.format("%s(%s, %s, %s, %s)", functionName, arguments.get(0).toString(), - arguments.get(1).toString(), - arguments.get(2).toString(), - arguments.get(3).toString()); + if (v1.isNull() || v2.isNull() || v3.isNull()) { + return ExprValueUtils.nullValue(); } - }; - return Pair.of(functionSignature, functionBuilder); - }; - } - - /** - * Wrapper the unary ExprValue function with default NULL and MISSING handling. - */ - public static SerializableFunction nullMissingHandling( - SerializableFunction function) { - return value -> { - if (value.isMissing()) { - return ExprValueUtils.missingValue(); - } else if (value.isNull()) { - return ExprValueUtils.nullValue(); - } else { - return function.apply(value); - } - }; - } - - /** - * Wrapper the binary ExprValue function with default NULL and MISSING handling. - */ - public static SerializableBiFunction nullMissingHandling( - SerializableBiFunction function) { - return (v1, v2) -> { - if (v1.isMissing() || v2.isMissing()) { - return ExprValueUtils.missingValue(); - } else if (v1.isNull() || v2.isNull()) { - return ExprValueUtils.nullValue(); - } else { - return function.apply(v1, v2); - } - }; - } - - /** - * Wrapper the triple ExprValue function with default NULL and MISSING handling. - */ - public SerializableTriFunction nullMissingHandling( - SerializableTriFunction function) { - return (v1, v2, v3) -> { - if (v1.isMissing() || v2.isMissing() || v3.isMissing()) { - return ExprValueUtils.missingValue(); - } else if (v1.isNull() || v2.isNull() || v3.isNull()) { - return ExprValueUtils.nullValue(); - } else { - return function.apply(v1, v2, v3); - } - }; - } - - /** - * Wrapper the unary ExprValue function that is aware of FunctionProperties, - * with default NULL and MISSING handling. - */ - public static SerializableBiFunction - nullMissingHandlingWithProperties( - SerializableBiFunction implementation) { - return (functionProperties, v1) -> { - if (v1.isMissing()) { - return ExprValueUtils.missingValue(); - } else if (v1.isNull()) { - return ExprValueUtils.nullValue(); - } else { - return implementation.apply(functionProperties, v1); - } - }; - } - - /** - * Wrapper for the ExprValue function that takes 2 arguments and is aware of FunctionProperties, - * with default NULL and MISSING handling. - */ - public static SerializableTriFunction - nullMissingHandlingWithProperties( - SerializableTriFunction implementation) { - return (functionProperties, v1, v2) -> { - if (v1.isMissing() || v2.isMissing()) { - return ExprValueUtils.missingValue(); - } else if (v1.isNull() || v2.isNull()) { - return ExprValueUtils.nullValue(); - } else { - return implementation.apply(functionProperties, v1, v2); - } - }; - } - - /** - * Wrapper for the ExprValue function that takes 3 arguments and is aware of FunctionProperties, - * with default NULL and MISSING handling. - */ - public static SerializableQuadFunction< - FunctionProperties, - ExprValue, - ExprValue, - ExprValue, - ExprValue> - nullMissingHandlingWithProperties( - SerializableQuadFunction< - FunctionProperties, - ExprValue, - ExprValue, - ExprValue, - ExprValue> implementation) { - return (functionProperties, v1, v2, v3) -> { - if (v1.isMissing() || v2.isMissing() || v3.isMissing()) { - return ExprValueUtils.missingValue(); - } - - if (v1.isNull() || v2.isNull() || v3.isNull()) { - return ExprValueUtils.nullValue(); - } - - return implementation.apply(functionProperties, v1, v2, v3); - }; - } + + return implementation.apply(functionProperties, v1, v2, v3); + }; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionImplementation.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionImplementation.java index d829e01225..3a65c4dd9f 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionImplementation.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionImplementation.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.function; import java.util.List; @@ -14,13 +13,13 @@ */ public interface FunctionImplementation { - /** - * Get Function Name. - */ - FunctionName getFunctionName(); + /** + * Get Function Name. + */ + FunctionName getFunctionName(); - /** - * Get Function Arguments. - */ - List getArguments(); + /** + * Get Function Arguments. + */ + List getArguments(); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionName.java index cb3d5fab92..5c0e989662 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionName.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.function; import java.io.Serializable; @@ -17,15 +16,15 @@ @EqualsAndHashCode @RequiredArgsConstructor public class FunctionName implements Serializable { - @Getter - private final String functionName; + @Getter + private final String functionName; - public static FunctionName of(String functionName) { - return new FunctionName(functionName.toLowerCase()); - } + public static FunctionName of(String functionName) { + return new FunctionName(functionName.toLowerCase()); + } - @Override - public String toString() { - return functionName; - } + @Override + public String toString() { + return functionName; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionProperties.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionProperties.java index 4222748051..804cd861a1 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionProperties.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionProperties.java @@ -16,53 +16,53 @@ @EqualsAndHashCode public class FunctionProperties implements Serializable { - private final Instant nowInstant; - private final ZoneId currentZoneId; + private final Instant nowInstant; + private final ZoneId currentZoneId; - /** - * By default, use current time and current timezone. - */ - public FunctionProperties() { - nowInstant = Instant.now(); - currentZoneId = ZoneId.systemDefault(); - } - - /** - * Method to access current system clock. - * @return a ticking clock that tells the time. - */ - public Clock getSystemClock() { - return Clock.system(currentZoneId); - } - - /** - * Method to get time when query began execution. - * Clock class combines an instant Supplier and a time zone. - * @return a fixed clock that returns the time execution started at. - * - */ - public Clock getQueryStartClock() { - return Clock.fixed(nowInstant, currentZoneId); - } + /** + * By default, use current time and current timezone. + */ + public FunctionProperties() { + nowInstant = Instant.now(); + currentZoneId = ZoneId.systemDefault(); + } - /** - * Use when compiling functions that do not rely on function properties. - */ - public static final FunctionProperties None = new FunctionProperties() { - @Override + /** + * Method to access current system clock. + * @return a ticking clock that tells the time. + */ public Clock getSystemClock() { - throw new UnexpectedCallException(); + return Clock.system(currentZoneId); } - @Override + /** + * Method to get time when query began execution. + * Clock class combines an instant Supplier and a time zone. + * @return a fixed clock that returns the time execution started at. + * + */ public Clock getQueryStartClock() { - throw new UnexpectedCallException(); + return Clock.fixed(nowInstant, currentZoneId); } - }; - class UnexpectedCallException extends RuntimeException { - public UnexpectedCallException() { - super("FunctionProperties.None is a null object and not meant to be accessed."); + /** + * Use when compiling functions that do not rely on function properties. + */ + public static final FunctionProperties None = new FunctionProperties() { + @Override + public Clock getSystemClock() { + throw new UnexpectedCallException(); + } + + @Override + public Clock getQueryStartClock() { + throw new UnexpectedCallException(); + } + }; + + class UnexpectedCallException extends RuntimeException { + public UnexpectedCallException() { + super("FunctionProperties.None is a null object and not meant to be accessed."); + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java index 1635b6f846..3e2a920ff8 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java @@ -12,7 +12,7 @@ * given a {@ref FunctionSignature}. */ public interface FunctionResolver { - Pair resolve(FunctionSignature unresolvedSignature); + Pair resolve(FunctionSignature unresolvedSignature); - FunctionName getFunctionName(); + FunctionName getFunctionName(); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java index 0c59d71c25..939dbb18aa 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionSignature.java @@ -22,57 +22,54 @@ @RequiredArgsConstructor @EqualsAndHashCode public class FunctionSignature { - public static final Integer NOT_MATCH = Integer.MAX_VALUE; - public static final Integer EXACTLY_MATCH = 0; + public static final Integer NOT_MATCH = Integer.MAX_VALUE; + public static final Integer EXACTLY_MATCH = 0; - private final FunctionName functionName; - private final List paramTypeList; + private final FunctionName functionName; + private final List paramTypeList; - /** - * calculate the function signature match degree. - * - * @return EXACTLY_MATCH: exactly match - * NOT_MATCH: not match - * By widening rule, the small number means better match - */ - public int match(FunctionSignature functionSignature) { - List functionTypeList = functionSignature.getParamTypeList(); - if (!functionName.equals(functionSignature.getFunctionName()) - || paramTypeList.size() != functionTypeList.size()) { - return NOT_MATCH; - } - // TODO: improve to support regular and array type mixed, ex. func(int,string,array) - if (isVarArgFunction(functionTypeList)) { - return EXACTLY_MATCH; - } + /** + * calculate the function signature match degree. + * + * @return EXACTLY_MATCH: exactly match + * NOT_MATCH: not match + * By widening rule, the small number means better match + */ + public int match(FunctionSignature functionSignature) { + List functionTypeList = functionSignature.getParamTypeList(); + if (!functionName.equals(functionSignature.getFunctionName()) || paramTypeList.size() != functionTypeList.size()) { + return NOT_MATCH; + } + // TODO: improve to support regular and array type mixed, ex. func(int,string,array) + if (isVarArgFunction(functionTypeList)) { + return EXACTLY_MATCH; + } - int matchDegree = EXACTLY_MATCH; - for (int i = 0; i < paramTypeList.size(); i++) { - ExprType paramType = paramTypeList.get(i); - ExprType funcType = functionTypeList.get(i); - int match = WideningTypeRule.distance(paramType, funcType); - if (match == WideningTypeRule.IMPOSSIBLE_WIDENING) { - return NOT_MATCH; - } else { - matchDegree += match; - } + int matchDegree = EXACTLY_MATCH; + for (int i = 0; i < paramTypeList.size(); i++) { + ExprType paramType = paramTypeList.get(i); + ExprType funcType = functionTypeList.get(i); + int match = WideningTypeRule.distance(paramType, funcType); + if (match == WideningTypeRule.IMPOSSIBLE_WIDENING) { + return NOT_MATCH; + } else { + matchDegree += match; + } + } + return matchDegree; } - return matchDegree; - } - /** - * util function for formatted arguments list. - */ - public String formatTypes() { - return getParamTypeList().stream() - .map(ExprType::typeName) - .collect(Collectors.joining(",", "[", "]")); - } + /** + * util function for formatted arguments list. + */ + public String formatTypes() { + return getParamTypeList().stream().map(ExprType::typeName).collect(Collectors.joining(",", "[", "]")); + } - /** - * util function - returns true if function has variable arguments. - */ - protected static boolean isVarArgFunction(List argTypes) { - return argTypes.size() == 1 && argTypes.get(0) == ARRAY; - } + /** + * util function - returns true if function has variable arguments. + */ + protected static boolean isVarArgFunction(List argTypes) { + return argTypes.size() == 1 && argTypes.get(0) == ARRAY; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java index c5fcb010f5..34128f3ae3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java @@ -22,154 +22,156 @@ @UtilityClass public class OpenSearchFunctions { - /** - * Add functions specific to OpenSearch to repository. - */ - public void register(BuiltinFunctionRepository repository) { - repository.register(match_bool_prefix()); - repository.register(multi_match(BuiltinFunctionName.MULTI_MATCH)); - repository.register(multi_match(BuiltinFunctionName.MULTIMATCH)); - repository.register(multi_match(BuiltinFunctionName.MULTIMATCHQUERY)); - repository.register(match(BuiltinFunctionName.MATCH)); - repository.register(match(BuiltinFunctionName.MATCHQUERY)); - repository.register(match(BuiltinFunctionName.MATCH_QUERY)); - repository.register(simple_query_string()); - repository.register(query()); - repository.register(query_string()); - - // Register MATCHPHRASE as MATCH_PHRASE as well for backwards - // compatibility. - repository.register(match_phrase(BuiltinFunctionName.MATCH_PHRASE)); - repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASE)); - repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASEQUERY)); - repository.register(match_phrase_prefix()); - repository.register(wildcard_query(BuiltinFunctionName.WILDCARD_QUERY)); - repository.register(wildcard_query(BuiltinFunctionName.WILDCARDQUERY)); - repository.register(score(BuiltinFunctionName.SCORE)); - repository.register(score(BuiltinFunctionName.SCOREQUERY)); - repository.register(score(BuiltinFunctionName.SCORE_QUERY)); - // Functions supported in SELECT clause - repository.register(nested()); - } - - private static FunctionResolver match_bool_prefix() { - FunctionName name = BuiltinFunctionName.MATCH_BOOL_PREFIX.getName(); - return new RelevanceFunctionResolver(name); - } - - private static FunctionResolver match(BuiltinFunctionName match) { - FunctionName funcName = match.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver match_phrase_prefix() { - FunctionName funcName = BuiltinFunctionName.MATCH_PHRASE_PREFIX.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver match_phrase(BuiltinFunctionName matchPhrase) { - FunctionName funcName = matchPhrase.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver multi_match(BuiltinFunctionName multiMatchName) { - return new RelevanceFunctionResolver(multiMatchName.getName()); - } - - private static FunctionResolver simple_query_string() { - FunctionName funcName = BuiltinFunctionName.SIMPLE_QUERY_STRING.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver query() { - FunctionName funcName = BuiltinFunctionName.QUERY.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver query_string() { - FunctionName funcName = BuiltinFunctionName.QUERY_STRING.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver wildcard_query(BuiltinFunctionName wildcardQuery) { - FunctionName funcName = wildcardQuery.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver nested() { - return new FunctionResolver() { - @Override - public Pair resolve( - FunctionSignature unresolvedSignature) { - return Pair.of(unresolvedSignature, - (functionProperties, arguments) -> - new FunctionExpression(BuiltinFunctionName.NESTED.getName(), arguments) { - @Override - public ExprValue valueOf(Environment valueEnv) { - return valueEnv.resolve(getArguments().get(0)); - } - - @Override - public ExprType type() { - return getArguments().get(0).type(); - } - }); - } - - @Override - public FunctionName getFunctionName() { - return BuiltinFunctionName.NESTED.getName(); - } - }; - } - - - - - private static FunctionResolver score(BuiltinFunctionName score) { - FunctionName funcName = score.getName(); - return new RelevanceFunctionResolver(funcName); - } - - public static class OpenSearchFunction extends FunctionExpression { - private final FunctionName functionName; - private final List arguments; - - @Getter - @Setter - private boolean isScoreTracked; - /** - * Required argument constructor. - * @param functionName name of the function - * @param arguments a list of expressions + * Add functions specific to OpenSearch to repository. */ - public OpenSearchFunction(FunctionName functionName, List arguments) { - super(functionName, arguments); - this.functionName = functionName; - this.arguments = arguments; - this.isScoreTracked = false; + public void register(BuiltinFunctionRepository repository) { + repository.register(match_bool_prefix()); + repository.register(multi_match(BuiltinFunctionName.MULTI_MATCH)); + repository.register(multi_match(BuiltinFunctionName.MULTIMATCH)); + repository.register(multi_match(BuiltinFunctionName.MULTIMATCHQUERY)); + repository.register(match(BuiltinFunctionName.MATCH)); + repository.register(match(BuiltinFunctionName.MATCHQUERY)); + repository.register(match(BuiltinFunctionName.MATCH_QUERY)); + repository.register(simple_query_string()); + repository.register(query()); + repository.register(query_string()); + + // Register MATCHPHRASE as MATCH_PHRASE as well for backwards + // compatibility. + repository.register(match_phrase(BuiltinFunctionName.MATCH_PHRASE)); + repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASE)); + repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASEQUERY)); + repository.register(match_phrase_prefix()); + repository.register(wildcard_query(BuiltinFunctionName.WILDCARD_QUERY)); + repository.register(wildcard_query(BuiltinFunctionName.WILDCARDQUERY)); + repository.register(score(BuiltinFunctionName.SCORE)); + repository.register(score(BuiltinFunctionName.SCOREQUERY)); + repository.register(score(BuiltinFunctionName.SCORE_QUERY)); + // Functions supported in SELECT clause + repository.register(nested()); + } + + private static FunctionResolver match_bool_prefix() { + FunctionName name = BuiltinFunctionName.MATCH_BOOL_PREFIX.getName(); + return new RelevanceFunctionResolver(name); + } + + private static FunctionResolver match(BuiltinFunctionName match) { + FunctionName funcName = match.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver match_phrase_prefix() { + FunctionName funcName = BuiltinFunctionName.MATCH_PHRASE_PREFIX.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver match_phrase(BuiltinFunctionName matchPhrase) { + FunctionName funcName = matchPhrase.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver multi_match(BuiltinFunctionName multiMatchName) { + return new RelevanceFunctionResolver(multiMatchName.getName()); + } + + private static FunctionResolver simple_query_string() { + FunctionName funcName = BuiltinFunctionName.SIMPLE_QUERY_STRING.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver query() { + FunctionName funcName = BuiltinFunctionName.QUERY.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver query_string() { + FunctionName funcName = BuiltinFunctionName.QUERY_STRING.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver wildcard_query(BuiltinFunctionName wildcardQuery) { + FunctionName funcName = wildcardQuery.getName(); + return new RelevanceFunctionResolver(funcName); } - @Override - public ExprValue valueOf(Environment valueEnv) { - throw new UnsupportedOperationException(String.format( - "OpenSearch defined function [%s] is only supported in WHERE and HAVING clause.", - functionName)); + private static FunctionResolver nested() { + return new FunctionResolver() { + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + return Pair.of( + unresolvedSignature, + (functionProperties, arguments) -> new FunctionExpression(BuiltinFunctionName.NESTED.getName(), arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + return valueEnv.resolve(getArguments().get(0)); + } + + @Override + public ExprType type() { + return getArguments().get(0).type(); + } + } + ); + } + + @Override + public FunctionName getFunctionName() { + return BuiltinFunctionName.NESTED.getName(); + } + }; } - @Override - public ExprType type() { - return BOOLEAN; + private static FunctionResolver score(BuiltinFunctionName score) { + FunctionName funcName = score.getName(); + return new RelevanceFunctionResolver(funcName); } - @Override - public String toString() { - List args = arguments.stream() - .map(arg -> String.format("%s=%s", ((NamedArgumentExpression) arg) - .getArgName(), ((NamedArgumentExpression) arg).getValue().toString())) - .collect(Collectors.toList()); - return String.format("%s(%s)", functionName, String.join(", ", args)); + public static class OpenSearchFunction extends FunctionExpression { + private final FunctionName functionName; + private final List arguments; + + @Getter + @Setter + private boolean isScoreTracked; + + /** + * Required argument constructor. + * @param functionName name of the function + * @param arguments a list of expressions + */ + public OpenSearchFunction(FunctionName functionName, List arguments) { + super(functionName, arguments); + this.functionName = functionName; + this.arguments = arguments; + this.isScoreTracked = false; + } + + @Override + public ExprValue valueOf(Environment valueEnv) { + throw new UnsupportedOperationException( + String.format("OpenSearch defined function [%s] is only supported in WHERE and HAVING clause.", functionName) + ); + } + + @Override + public ExprType type() { + return BOOLEAN; + } + + @Override + public String toString() { + List args = arguments.stream() + .map( + arg -> String.format( + "%s=%s", + ((NamedArgumentExpression) arg).getArgName(), + ((NamedArgumentExpression) arg).getValue().toString() + ) + ) + .collect(Collectors.toList()); + return String.format("%s(%s)", functionName, String.join(", ", args)); + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java index ef0ac9226c..4ce42406f2 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java @@ -14,43 +14,44 @@ import org.opensearch.sql.exception.SemanticCheckException; @RequiredArgsConstructor -public class RelevanceFunctionResolver - implements FunctionResolver { +public class RelevanceFunctionResolver implements FunctionResolver { - @Getter - private final FunctionName functionName; + @Getter + private final FunctionName functionName; - @Override - public Pair resolve(FunctionSignature unresolvedSignature) { - if (!unresolvedSignature.getFunctionName().equals(functionName)) { - throw new SemanticCheckException(String.format("Expected '%s' but got '%s'", - functionName.getFunctionName(), unresolvedSignature.getFunctionName().getFunctionName())); - } - List paramTypes = unresolvedSignature.getParamTypeList(); - // Check if all but the first parameter are of type STRING. - for (int i = 1; i < paramTypes.size(); i++) { - ExprType paramType = paramTypes.get(i); - if (!ExprCoreType.STRING.equals(paramType)) { - throw new SemanticCheckException( - getWrongParameterErrorMessage(i, paramType, ExprCoreType.STRING)); - } - } + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + if (!unresolvedSignature.getFunctionName().equals(functionName)) { + throw new SemanticCheckException( + String.format( + "Expected '%s' but got '%s'", + functionName.getFunctionName(), + unresolvedSignature.getFunctionName().getFunctionName() + ) + ); + } + List paramTypes = unresolvedSignature.getParamTypeList(); + // Check if all but the first parameter are of type STRING. + for (int i = 1; i < paramTypes.size(); i++) { + ExprType paramType = paramTypes.get(i); + if (!ExprCoreType.STRING.equals(paramType)) { + throw new SemanticCheckException(getWrongParameterErrorMessage(i, paramType, ExprCoreType.STRING)); + } + } - FunctionBuilder buildFunction = (functionProperties, args) - -> new OpenSearchFunctions.OpenSearchFunction(functionName, args); - return Pair.of(unresolvedSignature, buildFunction); - } + FunctionBuilder buildFunction = (functionProperties, args) -> new OpenSearchFunctions.OpenSearchFunction(functionName, args); + return Pair.of(unresolvedSignature, buildFunction); + } - /** Returns a helpful error message when expected parameter type does not match the - * specified parameter type. - * - * @param i 0-based index of the parameter in a function signature. - * @param paramType the type of the ith parameter at run-time. - * @param expectedType the expected type of the ith parameter - * @return A user-friendly error message that informs of the type difference. - */ - private String getWrongParameterErrorMessage(int i, ExprType paramType, ExprType expectedType) { - return String.format("Expected type %s instead of %s for parameter #%d", - expectedType.typeName(), paramType.typeName(), i + 1); - } + /** Returns a helpful error message when expected parameter type does not match the + * specified parameter type. + * + * @param i 0-based index of the parameter in a function signature. + * @param paramType the type of the ith parameter at run-time. + * @param expectedType the expected type of the ith parameter + * @return A user-friendly error message that informs of the type difference. + */ + private String getWrongParameterErrorMessage(int i, ExprType paramType, ExprType expectedType) { + return String.format("Expected type %s instead of %s for parameter #%d", expectedType.typeName(), paramType.typeName(), i + 1); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/SerializableBiFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/SerializableBiFunction.java index 5b3aaf31f3..e0c59c14b9 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/SerializableBiFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/SerializableBiFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.function; import java.io.Serializable; @@ -12,5 +11,4 @@ /** * Serializable BiFunction. */ -public interface SerializableBiFunction extends BiFunction, Serializable { -} +public interface SerializableBiFunction extends BiFunction, Serializable {} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/SerializableFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/SerializableFunction.java index 467c034c39..fb3e2f2cfb 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/SerializableFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/SerializableFunction.java @@ -3,11 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.function; import java.io.Serializable; import java.util.function.Function; -public interface SerializableFunction extends Function, Serializable { -} +public interface SerializableFunction extends Function, Serializable {} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/SerializableNoArgFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/SerializableNoArgFunction.java index e68d6084b4..23c47527fb 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/SerializableNoArgFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/SerializableNoArgFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.function; import java.io.Serializable; @@ -12,5 +11,4 @@ /** * Serializable no argument function. */ -public interface SerializableNoArgFunction extends Supplier, Serializable { -} +public interface SerializableNoArgFunction extends Supplier, Serializable {} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/SerializableQuadFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/SerializableQuadFunction.java index 056a17d5b3..3a810860d3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/SerializableQuadFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/SerializableQuadFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.function; import java.io.Serializable; @@ -18,14 +17,14 @@ * @param the type of the result of the function */ public interface SerializableQuadFunction extends Serializable { - /** - * Applies this function to the given arguments. - * - * @param t the first function argument - * @param u the second function argument - * @param v the third function argument - * @param w the fourth function argument - * @return the function result - */ - R apply(T t, U u, V v, W w); + /** + * Applies this function to the given arguments. + * + * @param t the first function argument + * @param u the second function argument + * @param v the third function argument + * @param w the fourth function argument + * @return the function result + */ + R apply(T t, U u, V v, W w); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/SerializableTriFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/SerializableTriFunction.java index 911012fcdb..7ad009561b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/SerializableTriFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/SerializableTriFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.function; import java.io.Serializable; @@ -17,13 +16,13 @@ * @param the type of the result of the function */ public interface SerializableTriFunction extends Serializable { - /** - * Applies this function to the given arguments. - * - * @param t the first function argument - * @param u the second function argument - * @param v the third function argument - * @return the function result - */ - R apply(T t, U u, V v); + /** + * Applies this function to the given arguments. + * + * @param t the first function argument + * @param u the second function argument + * @param v the third function argument + * @return the function result + */ + R apply(T t, U u, V v); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/TableFunctionImplementation.java b/core/src/main/java/org/opensearch/sql/expression/function/TableFunctionImplementation.java index f35ffe4898..38496a538b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/TableFunctionImplementation.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/TableFunctionImplementation.java @@ -14,6 +14,6 @@ */ public interface TableFunctionImplementation extends FunctionImplementation { - Table applyArguments(); + Table applyArguments(); } diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java index 1f4ac3943c..0c56b2dc61 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.operator.arthmetic; import static org.opensearch.sql.data.type.ExprCoreType.BYTE; @@ -39,230 +38,260 @@ */ @UtilityClass public class ArithmeticFunction { - /** - * Register Arithmetic Function. - * - * @param repository {@link BuiltinFunctionRepository}. - */ - public static void register(BuiltinFunctionRepository repository) { - repository.register(add()); - repository.register(addFunction()); - repository.register(divide()); - repository.register(divideFunction()); - repository.register(mod()); - repository.register(modulus()); - repository.register(modulusFunction()); - repository.register(multiply()); - repository.register(multiplyFunction()); - repository.register(subtract()); - repository.register(subtractFunction()); - } + /** + * Register Arithmetic Function. + * + * @param repository {@link BuiltinFunctionRepository}. + */ + public static void register(BuiltinFunctionRepository repository) { + repository.register(add()); + repository.register(addFunction()); + repository.register(divide()); + repository.register(divideFunction()); + repository.register(mod()); + repository.register(modulus()); + repository.register(modulusFunction()); + repository.register(multiply()); + repository.register(multiplyFunction()); + repository.register(subtract()); + repository.register(subtractFunction()); + } - /** - * Definition of add(x, y) function. - * Returns the number x plus number y - * The supported signature of add function is - * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) - * -> wider type between types of x and y - */ - private static DefaultFunctionResolver addBase(FunctionName functionName) { - return define(functionName, - impl(nullMissingHandling( - (v1, v2) -> new ExprByteValue(v1.byteValue() + v2.byteValue())), - BYTE, BYTE, BYTE), - impl(nullMissingHandling( - (v1, v2) -> new ExprShortValue(v1.shortValue() + v2.shortValue())), - SHORT, SHORT, SHORT), - impl(nullMissingHandling( - (v1, v2) -> new ExprIntegerValue(Math.addExact(v1.integerValue(), v2.integerValue()))), - INTEGER, INTEGER, INTEGER), - impl(nullMissingHandling( - (v1, v2) -> new ExprLongValue(Math.addExact(v1.longValue(), v2.longValue()))), - LONG, LONG, LONG), - impl(nullMissingHandling( - (v1, v2) -> new ExprFloatValue(v1.floatValue() + v2.floatValue())), - FLOAT, FLOAT, FLOAT), - impl(nullMissingHandling( - (v1, v2) -> new ExprDoubleValue(v1.doubleValue() + v2.doubleValue())), - DOUBLE, DOUBLE, DOUBLE) - ); - } + /** + * Definition of add(x, y) function. + * Returns the number x plus number y + * The supported signature of add function is + * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) + * -> wider type between types of x and y + */ + private static DefaultFunctionResolver addBase(FunctionName functionName) { + return define( + functionName, + impl(nullMissingHandling((v1, v2) -> new ExprByteValue(v1.byteValue() + v2.byteValue())), BYTE, BYTE, BYTE), + impl(nullMissingHandling((v1, v2) -> new ExprShortValue(v1.shortValue() + v2.shortValue())), SHORT, SHORT, SHORT), + impl( + nullMissingHandling((v1, v2) -> new ExprIntegerValue(Math.addExact(v1.integerValue(), v2.integerValue()))), + INTEGER, + INTEGER, + INTEGER + ), + impl(nullMissingHandling((v1, v2) -> new ExprLongValue(Math.addExact(v1.longValue(), v2.longValue()))), LONG, LONG, LONG), + impl(nullMissingHandling((v1, v2) -> new ExprFloatValue(v1.floatValue() + v2.floatValue())), FLOAT, FLOAT, FLOAT), + impl(nullMissingHandling((v1, v2) -> new ExprDoubleValue(v1.doubleValue() + v2.doubleValue())), DOUBLE, DOUBLE, DOUBLE) + ); + } - private static DefaultFunctionResolver add() { - return addBase(BuiltinFunctionName.ADD.getName()); - } + private static DefaultFunctionResolver add() { + return addBase(BuiltinFunctionName.ADD.getName()); + } - private static DefaultFunctionResolver addFunction() { - return addBase(BuiltinFunctionName.ADDFUNCTION.getName()); - } + private static DefaultFunctionResolver addFunction() { + return addBase(BuiltinFunctionName.ADDFUNCTION.getName()); + } - /** - * Definition of divide(x, y) function. - * Returns the number x divided by number y - * The supported signature of divide function is - * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) - * -> wider type between types of x and y - */ - private static DefaultFunctionResolver divideBase(FunctionName functionName) { - return define(functionName, - impl(nullMissingHandling( - (v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() : - new ExprByteValue(v1.byteValue() / v2.byteValue())), - BYTE, BYTE, BYTE), - impl(nullMissingHandling( - (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprShortValue(v1.shortValue() / v2.shortValue())), - SHORT, SHORT, SHORT), - impl(nullMissingHandling( - (v1, v2) -> v2.integerValue() == 0 ? ExprNullValue.of() : - new ExprIntegerValue(v1.integerValue() / v2.integerValue())), - INTEGER, INTEGER, INTEGER), - impl(nullMissingHandling( - (v1, v2) -> v2.longValue() == 0 ? ExprNullValue.of() : - new ExprLongValue(v1.longValue() / v2.longValue())), - LONG, LONG, LONG), - impl(nullMissingHandling( - (v1, v2) -> v2.floatValue() == 0 ? ExprNullValue.of() : - new ExprFloatValue(v1.floatValue() / v2.floatValue())), - FLOAT, FLOAT, FLOAT), - impl(nullMissingHandling( - (v1, v2) -> v2.doubleValue() == 0 ? ExprNullValue.of() : - new ExprDoubleValue(v1.doubleValue() / v2.doubleValue())), - DOUBLE, DOUBLE, DOUBLE) - ); - } + /** + * Definition of divide(x, y) function. + * Returns the number x divided by number y + * The supported signature of divide function is + * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) + * -> wider type between types of x and y + */ + private static DefaultFunctionResolver divideBase(FunctionName functionName) { + return define( + functionName, + impl( + nullMissingHandling( + (v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() : new ExprByteValue(v1.byteValue() / v2.byteValue()) + ), + BYTE, + BYTE, + BYTE + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : new ExprShortValue(v1.shortValue() / v2.shortValue()) + ), + SHORT, + SHORT, + SHORT + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.integerValue() == 0 ? ExprNullValue.of() : new ExprIntegerValue(v1.integerValue() / v2.integerValue()) + ), + INTEGER, + INTEGER, + INTEGER + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.longValue() == 0 ? ExprNullValue.of() : new ExprLongValue(v1.longValue() / v2.longValue()) + ), + LONG, + LONG, + LONG + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.floatValue() == 0 ? ExprNullValue.of() : new ExprFloatValue(v1.floatValue() / v2.floatValue()) + ), + FLOAT, + FLOAT, + FLOAT + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.doubleValue() == 0 ? ExprNullValue.of() : new ExprDoubleValue(v1.doubleValue() / v2.doubleValue()) + ), + DOUBLE, + DOUBLE, + DOUBLE + ) + ); + } - private static DefaultFunctionResolver divide() { - return divideBase(BuiltinFunctionName.DIVIDE.getName()); - } + private static DefaultFunctionResolver divide() { + return divideBase(BuiltinFunctionName.DIVIDE.getName()); + } - private static DefaultFunctionResolver divideFunction() { - return divideBase(BuiltinFunctionName.DIVIDEFUNCTION.getName()); - } + private static DefaultFunctionResolver divideFunction() { + return divideBase(BuiltinFunctionName.DIVIDEFUNCTION.getName()); + } - /** - * Definition of modulus(x, y) function. - * Returns the number x modulo by number y - * The supported signature of modulo function is - * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) - * -> wider type between types of x and y - */ - private static DefaultFunctionResolver modulusBase(FunctionName functionName) { - return define(functionName, - impl(nullMissingHandling( - (v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() : - new ExprByteValue(v1.byteValue() % v2.byteValue())), - BYTE, BYTE, BYTE), - impl(nullMissingHandling( - (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprShortValue(v1.shortValue() % v2.shortValue())), - SHORT, SHORT, SHORT), - impl(nullMissingHandling( - (v1, v2) -> v2.integerValue() == 0 ? ExprNullValue.of() : - new ExprIntegerValue(v1.integerValue() % v2.integerValue())), - INTEGER, INTEGER, INTEGER), - impl(nullMissingHandling( - (v1, v2) -> v2.longValue() == 0 ? ExprNullValue.of() : - new ExprLongValue(v1.longValue() % v2.longValue())), - LONG, LONG, LONG), - impl(nullMissingHandling( - (v1, v2) -> v2.floatValue() == 0 ? ExprNullValue.of() : - new ExprFloatValue(v1.floatValue() % v2.floatValue())), - FLOAT, FLOAT, FLOAT), - impl(nullMissingHandling( - (v1, v2) -> v2.doubleValue() == 0 ? ExprNullValue.of() : - new ExprDoubleValue(v1.doubleValue() % v2.doubleValue())), - DOUBLE, DOUBLE, DOUBLE) - ); - } + /** + * Definition of modulus(x, y) function. + * Returns the number x modulo by number y + * The supported signature of modulo function is + * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) + * -> wider type between types of x and y + */ + private static DefaultFunctionResolver modulusBase(FunctionName functionName) { + return define( + functionName, + impl( + nullMissingHandling( + (v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() : new ExprByteValue(v1.byteValue() % v2.byteValue()) + ), + BYTE, + BYTE, + BYTE + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : new ExprShortValue(v1.shortValue() % v2.shortValue()) + ), + SHORT, + SHORT, + SHORT + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.integerValue() == 0 ? ExprNullValue.of() : new ExprIntegerValue(v1.integerValue() % v2.integerValue()) + ), + INTEGER, + INTEGER, + INTEGER + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.longValue() == 0 ? ExprNullValue.of() : new ExprLongValue(v1.longValue() % v2.longValue()) + ), + LONG, + LONG, + LONG + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.floatValue() == 0 ? ExprNullValue.of() : new ExprFloatValue(v1.floatValue() % v2.floatValue()) + ), + FLOAT, + FLOAT, + FLOAT + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.doubleValue() == 0 ? ExprNullValue.of() : new ExprDoubleValue(v1.doubleValue() % v2.doubleValue()) + ), + DOUBLE, + DOUBLE, + DOUBLE + ) + ); + } - private static DefaultFunctionResolver mod() { - return modulusBase(BuiltinFunctionName.MOD.getName()); - } + private static DefaultFunctionResolver mod() { + return modulusBase(BuiltinFunctionName.MOD.getName()); + } - private static DefaultFunctionResolver modulus() { - return modulusBase(BuiltinFunctionName.MODULUS.getName()); - } + private static DefaultFunctionResolver modulus() { + return modulusBase(BuiltinFunctionName.MODULUS.getName()); + } - private static DefaultFunctionResolver modulusFunction() { - return modulusBase(BuiltinFunctionName.MODULUSFUNCTION.getName()); - } + private static DefaultFunctionResolver modulusFunction() { + return modulusBase(BuiltinFunctionName.MODULUSFUNCTION.getName()); + } - /** - * Definition of multiply(x, y) function. - * Returns the number x multiplied by number y - * The supported signature of multiply function is - * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) - * -> wider type between types of x and y - */ - private static DefaultFunctionResolver multiplyBase(FunctionName functionName) { - return define(functionName, - impl(nullMissingHandling( - (v1, v2) -> new ExprByteValue(v1.byteValue() * v2.byteValue())), - BYTE, BYTE, BYTE), - impl(nullMissingHandling( - (v1, v2) -> new ExprShortValue(v1.shortValue() * v2.shortValue())), - SHORT, SHORT, SHORT), - impl(nullMissingHandling( - (v1, v2) -> new ExprIntegerValue(Math.multiplyExact(v1.integerValue(), - v2.integerValue()))), - INTEGER, INTEGER, INTEGER), - impl(nullMissingHandling( - (v1, v2) -> new ExprLongValue(Math.multiplyExact(v1.longValue(), v2.longValue()))), - LONG, LONG, LONG), - impl(nullMissingHandling( - (v1, v2) -> new ExprFloatValue(v1.floatValue() * v2.floatValue())), - FLOAT, FLOAT, FLOAT), - impl(nullMissingHandling( - (v1, v2) -> new ExprDoubleValue(v1.doubleValue() * v2.doubleValue())), - DOUBLE, DOUBLE, DOUBLE) - ); - } + /** + * Definition of multiply(x, y) function. + * Returns the number x multiplied by number y + * The supported signature of multiply function is + * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) + * -> wider type between types of x and y + */ + private static DefaultFunctionResolver multiplyBase(FunctionName functionName) { + return define( + functionName, + impl(nullMissingHandling((v1, v2) -> new ExprByteValue(v1.byteValue() * v2.byteValue())), BYTE, BYTE, BYTE), + impl(nullMissingHandling((v1, v2) -> new ExprShortValue(v1.shortValue() * v2.shortValue())), SHORT, SHORT, SHORT), + impl( + nullMissingHandling((v1, v2) -> new ExprIntegerValue(Math.multiplyExact(v1.integerValue(), v2.integerValue()))), + INTEGER, + INTEGER, + INTEGER + ), + impl(nullMissingHandling((v1, v2) -> new ExprLongValue(Math.multiplyExact(v1.longValue(), v2.longValue()))), LONG, LONG, LONG), + impl(nullMissingHandling((v1, v2) -> new ExprFloatValue(v1.floatValue() * v2.floatValue())), FLOAT, FLOAT, FLOAT), + impl(nullMissingHandling((v1, v2) -> new ExprDoubleValue(v1.doubleValue() * v2.doubleValue())), DOUBLE, DOUBLE, DOUBLE) + ); + } - private static DefaultFunctionResolver multiply() { - return multiplyBase(BuiltinFunctionName.MULTIPLY.getName()); - } + private static DefaultFunctionResolver multiply() { + return multiplyBase(BuiltinFunctionName.MULTIPLY.getName()); + } - private static DefaultFunctionResolver multiplyFunction() { - return multiplyBase(BuiltinFunctionName.MULTIPLYFUNCTION.getName()); - } + private static DefaultFunctionResolver multiplyFunction() { + return multiplyBase(BuiltinFunctionName.MULTIPLYFUNCTION.getName()); + } - /** - * Definition of subtract(x, y) function. - * Returns the number x minus number y - * The supported signature of subtract function is - * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) - * -> wider type between types of x and y - */ - private static DefaultFunctionResolver subtractBase(FunctionName functionName) { - return define(functionName, - impl(nullMissingHandling( - (v1, v2) -> new ExprByteValue(v1.byteValue() - v2.byteValue())), - BYTE, BYTE, BYTE), - impl(nullMissingHandling( - (v1, v2) -> new ExprShortValue(v1.shortValue() - v2.shortValue())), - SHORT, SHORT, SHORT), - impl(nullMissingHandling( - (v1, v2) -> new ExprIntegerValue(Math.subtractExact(v1.integerValue(), - v2.integerValue()))), - INTEGER, INTEGER, INTEGER), - impl(nullMissingHandling( - (v1, v2) -> new ExprLongValue(Math.subtractExact(v1.longValue(), v2.longValue()))), - LONG, LONG, LONG), - impl(nullMissingHandling( - (v1, v2) -> new ExprFloatValue(v1.floatValue() - v2.floatValue())), - FLOAT, FLOAT, FLOAT), - impl(nullMissingHandling( - (v1, v2) -> new ExprDoubleValue(v1.doubleValue() - v2.doubleValue())), - DOUBLE, DOUBLE, DOUBLE) - ); - } + /** + * Definition of subtract(x, y) function. + * Returns the number x minus number y + * The supported signature of subtract function is + * (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE) + * -> wider type between types of x and y + */ + private static DefaultFunctionResolver subtractBase(FunctionName functionName) { + return define( + functionName, + impl(nullMissingHandling((v1, v2) -> new ExprByteValue(v1.byteValue() - v2.byteValue())), BYTE, BYTE, BYTE), + impl(nullMissingHandling((v1, v2) -> new ExprShortValue(v1.shortValue() - v2.shortValue())), SHORT, SHORT, SHORT), + impl( + nullMissingHandling((v1, v2) -> new ExprIntegerValue(Math.subtractExact(v1.integerValue(), v2.integerValue()))), + INTEGER, + INTEGER, + INTEGER + ), + impl(nullMissingHandling((v1, v2) -> new ExprLongValue(Math.subtractExact(v1.longValue(), v2.longValue()))), LONG, LONG, LONG), + impl(nullMissingHandling((v1, v2) -> new ExprFloatValue(v1.floatValue() - v2.floatValue())), FLOAT, FLOAT, FLOAT), + impl(nullMissingHandling((v1, v2) -> new ExprDoubleValue(v1.doubleValue() - v2.doubleValue())), DOUBLE, DOUBLE, DOUBLE) + ); + } - private static DefaultFunctionResolver subtract() { - return subtractBase(BuiltinFunctionName.SUBTRACT.getName()); - } + private static DefaultFunctionResolver subtract() { + return subtractBase(BuiltinFunctionName.SUBTRACT.getName()); + } - private static DefaultFunctionResolver subtractFunction() { - return subtractBase(BuiltinFunctionName.SUBTRACTFUNCTION.getName()); - } -} \ No newline at end of file + private static DefaultFunctionResolver subtractFunction() { + return subtractBase(BuiltinFunctionName.SUBTRACTFUNCTION.getName()); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java index 810d292ca2..a561c02513 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.operator.arthmetic; import static org.opensearch.sql.data.type.ExprCoreType.BYTE; @@ -48,629 +47,761 @@ @UtilityClass public class MathematicalFunction { - /** - * Register Mathematical Functions. - * - * @param repository {@link BuiltinFunctionRepository}. - */ - public static void register(BuiltinFunctionRepository repository) { - repository.register(abs()); - repository.register(acos()); - repository.register(asin()); - repository.register(atan()); - repository.register(atan2()); - repository.register(cbrt()); - repository.register(ceil()); - repository.register(ceiling()); - repository.register(conv()); - repository.register(cos()); - repository.register(cosh()); - repository.register(cot()); - repository.register(crc32()); - repository.register(degrees()); - repository.register(euler()); - repository.register(exp()); - repository.register(expm1()); - repository.register(floor()); - repository.register(ln()); - repository.register(log()); - repository.register(log10()); - repository.register(log2()); - repository.register(mod()); - repository.register(pi()); - repository.register(pow()); - repository.register(power()); - repository.register(radians()); - repository.register(rand()); - repository.register(rint()); - repository.register(round()); - repository.register(sign()); - repository.register(signum()); - repository.register(sin()); - repository.register(sinh()); - repository.register(sqrt()); - repository.register(tan()); - repository.register(truncate()); - } - - /** - * Base function for math functions with similar formats that return DOUBLE. - * - * @param functionName BuiltinFunctionName of math function. - * @param formula lambda function of math formula. - * @param returnType data type return type of the calling function - * @return DefaultFunctionResolver for math functions. - */ - private static DefaultFunctionResolver baseMathFunction( - FunctionName functionName, SerializableFunction formula, ExprCoreType returnType) { - return define(functionName, ExprCoreType.numberTypes().stream().map(type -> - impl(nullMissingHandling(formula), returnType, type)).collect(Collectors.toList())); - } - - /** - * Definition of abs() function. The supported signature of abs() function are INT -> INT LONG -> - * LONG FLOAT -> FLOAT DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver abs() { - return define(BuiltinFunctionName.ABS.getName(), - impl(nullMissingHandling(v -> new ExprByteValue(Math.abs(v.byteValue()))), - BYTE, BYTE), - impl(nullMissingHandling(v -> new ExprShortValue(Math.abs(v.shortValue()))), - SHORT, SHORT), - impl(nullMissingHandling(v -> new ExprIntegerValue(Math.abs(v.integerValue()))), - INTEGER, INTEGER), - impl(nullMissingHandling(v -> new ExprLongValue(Math.abs(v.longValue()))), - LONG, LONG), - impl(nullMissingHandling(v -> new ExprFloatValue(Math.abs(v.floatValue()))), - FLOAT, FLOAT), - impl(nullMissingHandling(v -> new ExprDoubleValue(Math.abs(v.doubleValue()))), - DOUBLE, DOUBLE) - ); - } - - /** - * Definition of ceil(x)/ceiling(x) function. Calculate the next highest integer that x rounds up - * to The supported signature of ceil/ceiling function is DOUBLE -> INTEGER - */ - private static DefaultFunctionResolver ceil() { - return define(BuiltinFunctionName.CEIL.getName(), - impl(nullMissingHandling(v -> new ExprLongValue(Math.ceil(v.doubleValue()))), - LONG, DOUBLE) - ); - } - - private static DefaultFunctionResolver ceiling() { - return define(BuiltinFunctionName.CEILING.getName(), - impl(nullMissingHandling(v -> new ExprLongValue(Math.ceil(v.doubleValue()))), - LONG, DOUBLE) - ); - } - - /** - * Definition of conv(x, a, b) function. - * Convert number x from base a to base b - * The supported signature of floor function is - * (STRING, INTEGER, INTEGER) -> STRING - * (INTEGER, INTEGER, INTEGER) -> STRING - */ - private static DefaultFunctionResolver conv() { - return define(BuiltinFunctionName.CONV.getName(), - impl(nullMissingHandling((x, a, b) -> new ExprStringValue( - Integer.toString(Integer.parseInt(x.stringValue(), a.integerValue()), - b.integerValue()))), - STRING, STRING, INTEGER, INTEGER), - impl(nullMissingHandling((x, a, b) -> new ExprStringValue( - Integer.toString(Integer.parseInt(x.integerValue().toString(), a.integerValue()), - b.integerValue()))), - STRING, INTEGER, INTEGER, INTEGER) - ); - } - - /** - * Definition of crc32(x) function. - * Calculate a cyclic redundancy check value and returns a 32-bit unsigned value - * The supported signature of crc32 function is - * STRING -> LONG - */ - private static DefaultFunctionResolver crc32() { - return define(BuiltinFunctionName.CRC32.getName(), - impl(nullMissingHandling(v -> { - CRC32 crc = new CRC32(); - crc.update(v.stringValue().getBytes()); - return new ExprLongValue(crc.getValue()); - }), - LONG, STRING) - ); - } - - /** - * Definition of e() function. - * Get the Euler's number. - * () -> DOUBLE - */ - private static DefaultFunctionResolver euler() { - return define(BuiltinFunctionName.E.getName(), - impl(() -> new ExprDoubleValue(Math.E), DOUBLE) - ); - } - - /** - * Definition of exp(x) function. Calculate exponent function e to the x - * The supported signature of exp function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver exp() { - return baseMathFunction(BuiltinFunctionName.EXP.getName(), - v -> new ExprDoubleValue(Math.exp(v.doubleValue())), DOUBLE); - } - - /** - * Definition of expm1(x) function. Calculate exponent function e to the x, minus 1 - * The supported signature of exp function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver expm1() { - return baseMathFunction(BuiltinFunctionName.EXPM1.getName(), - v -> new ExprDoubleValue(Math.expm1(v.doubleValue())), DOUBLE); - } - - /** - * Definition of floor(x) function. Calculate the next nearest whole integer that x rounds down to - * The supported signature of floor function is DOUBLE -> INTEGER - */ - private static DefaultFunctionResolver floor() { - return define(BuiltinFunctionName.FLOOR.getName(), - impl(nullMissingHandling(v -> new ExprLongValue(Math.floor(v.doubleValue()))), - LONG, DOUBLE) - ); - } - - /** - * Definition of ln(x) function. Calculate the natural logarithm of x The supported signature of - * ln function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver ln() { - return baseMathFunction(BuiltinFunctionName.LN.getName(), - v -> v.doubleValue() <= 0 ? ExprNullValue.of() : - new ExprDoubleValue(Math.log(v.doubleValue())), DOUBLE); - } - - /** - * Definition of log(b, x) function. Calculate the logarithm of x using b as the base The - * supported signature of log function is (b: INTEGER/LONG/FLOAT/DOUBLE, x: - * INTEGER/LONG/FLOAT/DOUBLE]) -> DOUBLE - */ - private static DefaultFunctionResolver log() { - ImmutableList.Builder>> builder = new ImmutableList.Builder<>(); - - // build unary log(x), SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - for (ExprType type : ExprCoreType.numberTypes()) { - builder.add(impl(nullMissingHandling(v -> v.doubleValue() <= 0 ? ExprNullValue.of() : - new ExprDoubleValue(Math.log(v.doubleValue()))), - DOUBLE, type)); - } - - // build binary function log(b, x) - for (ExprType baseType : ExprCoreType.numberTypes()) { - for (ExprType numberType : ExprCoreType.numberTypes()) { - builder.add(impl(nullMissingHandling((b, x) -> b.doubleValue() <= 0 || x.doubleValue() <= 0 - ? ExprNullValue.of() : new ExprDoubleValue( - Math.log(x.doubleValue()) / Math.log(b.doubleValue()))), - DOUBLE, baseType, numberType)); - } - } - return define(BuiltinFunctionName.LOG.getName(), builder.build()); - } - - - /** - * Definition of log10(x) function. Calculate base-10 logarithm of x The supported signature of - * log function is SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver log10() { - return baseMathFunction(BuiltinFunctionName.LOG10.getName(), - v -> v.doubleValue() <= 0 ? ExprNullValue.of() : - new ExprDoubleValue(Math.log10(v.doubleValue())), DOUBLE); - } - - /** - * Definition of log2(x) function. Calculate base-2 logarithm of x The supported signature of log - * function is SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver log2() { - return baseMathFunction(BuiltinFunctionName.LOG2.getName(), - v -> v.doubleValue() <= 0 ? ExprNullValue.of() : - new ExprDoubleValue(Math.log(v.doubleValue()) / Math.log(2)), DOUBLE); - } - - /** - * Definition of mod(x, y) function. - * Calculate the remainder of x divided by y - * The supported signature of mod function is - * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) - * -> wider type between types of x and y - */ - private static DefaultFunctionResolver mod() { - return define(BuiltinFunctionName.MOD.getName(), - impl(nullMissingHandling((v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() : - new ExprByteValue(v1.byteValue() % v2.byteValue())), - BYTE, BYTE, BYTE), - impl(nullMissingHandling((v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprShortValue(v1.shortValue() % v2.shortValue())), - SHORT, SHORT, SHORT), - impl(nullMissingHandling((v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprIntegerValue(Math.floorMod(v1.integerValue(), v2.integerValue()))), - INTEGER, INTEGER, INTEGER), - impl(nullMissingHandling((v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprLongValue(Math.floorMod(v1.longValue(), v2.longValue()))), - LONG, LONG, LONG), - impl(nullMissingHandling((v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprFloatValue(v1.floatValue() % v2.floatValue())), - FLOAT, FLOAT, FLOAT), - impl(nullMissingHandling((v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : - new ExprDoubleValue(v1.doubleValue() % v2.doubleValue())), - DOUBLE, DOUBLE, DOUBLE) - ); - } - - /** - * Definition of pi() function. - * Get the value of pi. - * () -> DOUBLE - */ - private static DefaultFunctionResolver pi() { - return define(BuiltinFunctionName.PI.getName(), - impl(() -> new ExprDoubleValue(Math.PI), DOUBLE) - ); - } - - /** - * Definition of pow(x, y)/power(x, y) function. - * Calculate the value of x raised to the power of y - * The supported signature of pow/power function is - * (INTEGER, INTEGER) -> DOUBLE - * (LONG, LONG) -> DOUBLE - * (FLOAT, FLOAT) -> DOUBLE - * (DOUBLE, DOUBLE) -> DOUBLE - */ - private static DefaultFunctionResolver pow() { - return define(BuiltinFunctionName.POW.getName(), powerFunctionImpl()); - } - - private static DefaultFunctionResolver power() { - return define(BuiltinFunctionName.POWER.getName(), powerFunctionImpl()); - } - - private List>> powerFunctionImpl() { - return Arrays.asList( - impl(nullMissingHandling( - (v1, v2) -> new ExprDoubleValue(Math.pow(v1.shortValue(), v2.shortValue()))), - DOUBLE, SHORT, SHORT), - impl(nullMissingHandling( - (v1, v2) -> new ExprDoubleValue(Math.pow(v1.integerValue(), v2.integerValue()))), - DOUBLE, INTEGER, INTEGER), - impl(nullMissingHandling( - (v1, v2) -> new ExprDoubleValue(Math.pow(v1.longValue(), v2.longValue()))), - DOUBLE, LONG, LONG), - impl(nullMissingHandling( - (v1, v2) -> v1.floatValue() <= 0 && v2.floatValue() != Math.floor(v2.floatValue()) - ? ExprNullValue.of() : - new ExprDoubleValue(Math.pow(v1.floatValue(), v2.floatValue()))), - DOUBLE, FLOAT, FLOAT), - impl(nullMissingHandling( - (v1, v2) -> v1.doubleValue() <= 0 && v2.doubleValue() != Math.floor(v2.doubleValue()) - ? ExprNullValue.of() : - new ExprDoubleValue(Math.pow(v1.doubleValue(), v2.doubleValue()))), - DOUBLE, DOUBLE, DOUBLE)); - } - - /** - * Definition of rand() and rand(N) function. - * rand() returns a random floating-point value in the range 0 <= value < 1.0 - * If integer N is specified, the seed is initialized prior to execution. - * One implication of this behavior is with identical argument N,rand(N) returns the same value - * each time, and thus produces a repeatable sequence of column values. - * The supported signature of rand function is - * ([INTEGER]) -> FLOAT - */ - private static DefaultFunctionResolver rand() { - return define(BuiltinFunctionName.RAND.getName(), - impl(() -> new ExprFloatValue(new Random().nextFloat()), FLOAT), - impl(nullMissingHandling( - v -> new ExprFloatValue(new Random(v.integerValue()).nextFloat())), FLOAT, INTEGER) - ); - } - - /** - * Definition of rint(x) function. - * Returns the closest whole integer value to x - * The supported signature is - * BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver rint() { - return baseMathFunction(BuiltinFunctionName.RINT.getName(), - v -> new ExprDoubleValue(Math.rint(v.doubleValue())), DOUBLE); - } - - /** - * Definition of round(x)/round(x, d) function. - * Rounds the argument x to d decimal places, d defaults to 0 if not specified. - * The supported signature of round function is - * (x: INTEGER [, y: INTEGER]) -> INTEGER - * (x: LONG [, y: INTEGER]) -> LONG - * (x: FLOAT [, y: INTEGER]) -> FLOAT - * (x: DOUBLE [, y: INTEGER]) -> DOUBLE - */ - private static DefaultFunctionResolver round() { - return define(BuiltinFunctionName.ROUND.getName(), - // rand(x) - impl(nullMissingHandling(v -> new ExprLongValue((long) Math.round(v.integerValue()))), - LONG, INTEGER), - impl(nullMissingHandling(v -> new ExprLongValue((long) Math.round(v.longValue()))), - LONG, LONG), - impl(nullMissingHandling(v -> new ExprDoubleValue((double) Math.round(v.floatValue()))), - DOUBLE, FLOAT), - impl(nullMissingHandling(v -> new ExprDoubleValue(new BigDecimal(v.doubleValue()) - .setScale(0, RoundingMode.HALF_UP).doubleValue())), - DOUBLE, DOUBLE), - - // rand(x, d) - impl(nullMissingHandling((x, d) -> new ExprLongValue(new BigDecimal(x.integerValue()) - .setScale(d.integerValue(), RoundingMode.HALF_UP).longValue())), - LONG, INTEGER, INTEGER), - impl(nullMissingHandling((x, d) -> new ExprLongValue(new BigDecimal(x.longValue()) - .setScale(d.integerValue(), RoundingMode.HALF_UP).longValue())), - LONG, LONG, INTEGER), - impl(nullMissingHandling((x, d) -> new ExprDoubleValue(new BigDecimal(x.floatValue()) - .setScale(d.integerValue(), RoundingMode.HALF_UP).doubleValue())), - DOUBLE, FLOAT, INTEGER), - impl(nullMissingHandling((x, d) -> new ExprDoubleValue(new BigDecimal(x.doubleValue()) - .setScale(d.integerValue(), RoundingMode.HALF_UP).doubleValue())), - DOUBLE, DOUBLE, INTEGER)); - } - - /** - * Definition of sign(x) function. - * Returns the sign of the argument as -1, 0, or 1 - * depending on whether x is negative, zero, or positive - * The supported signature is - * SHORT/INTEGER/LONG/FLOAT/DOUBLE -> INTEGER - */ - private static DefaultFunctionResolver sign() { - return baseMathFunction(BuiltinFunctionName.SIGN.getName(), - v -> new ExprIntegerValue(Math.signum(v.doubleValue())), INTEGER); - } - - /** - * Definition of signum(x) function. - * Returns the sign of the argument as -1.0, 0, or 1.0 - * depending on whether x is negative, zero, or positive - * The supported signature is - * BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -> INTEGER - */ - private static DefaultFunctionResolver signum() { - return baseMathFunction(BuiltinFunctionName.SIGNUM.getName(), - v -> new ExprIntegerValue(Math.signum(v.doubleValue())), INTEGER); - } - - /** - * Definition of sinh(x) function. - * Returns the hyperbolix sine of x, defined as (((e^x) - (e^(-x))) / 2) - * The supported signature is - * BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver sinh() { - return baseMathFunction(BuiltinFunctionName.SINH.getName(), - v -> new ExprDoubleValue(Math.sinh(v.doubleValue())), DOUBLE); - } - - /** - * Definition of sqrt(x) function. - * Calculate the square root of a non-negative number x - * The supported signature is - * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver sqrt() { - return baseMathFunction(BuiltinFunctionName.SQRT.getName(), - v -> v.doubleValue() < 0 ? ExprNullValue.of() : - new ExprDoubleValue(Math.sqrt(v.doubleValue())), DOUBLE); - } - - /** - * Definition of cbrt(x) function. - * Calculate the cube root of a number x - * The supported signature is - * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver cbrt() { - return baseMathFunction(BuiltinFunctionName.CBRT.getName(), - v -> new ExprDoubleValue(Math.cbrt(v.doubleValue())), DOUBLE); - } - - /** - * Definition of truncate(x, d) function. - * Returns the number x, truncated to d decimal places - * The supported signature of round function is - * (x: INTEGER, y: INTEGER) -> LONG - * (x: LONG, y: INTEGER) -> LONG - * (x: FLOAT, y: INTEGER) -> DOUBLE - * (x: DOUBLE, y: INTEGER) -> DOUBLE - */ - private static DefaultFunctionResolver truncate() { - return define(BuiltinFunctionName.TRUNCATE.getName(), - impl(nullMissingHandling((x, y) -> new ExprLongValue(BigDecimal.valueOf(x.integerValue()) - .setScale(y.integerValue(), RoundingMode.DOWN).longValue())), - LONG, INTEGER, INTEGER), - impl(nullMissingHandling((x, y) -> new ExprLongValue(BigDecimal.valueOf(x.longValue()) - .setScale(y.integerValue(), RoundingMode.DOWN).longValue())), - LONG, LONG, INTEGER), - impl(nullMissingHandling((x, y) -> new ExprDoubleValue(BigDecimal.valueOf(x.floatValue()) - .setScale(y.integerValue(), RoundingMode.DOWN).doubleValue())), - DOUBLE, FLOAT, INTEGER), - impl(nullMissingHandling((x, y) -> new ExprDoubleValue(BigDecimal.valueOf(x.doubleValue()) - .setScale(y.integerValue(), RoundingMode.DOWN).doubleValue())), - DOUBLE, DOUBLE, INTEGER)); - } - - /** - * Definition of acos(x) function. - * Calculates the arc cosine of x, that is, the value whose cosine is x. - * Returns NULL if x is not in the range -1 to 1. - * The supported signature of acos function is - * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver acos() { - return define(BuiltinFunctionName.ACOS.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> impl(nullMissingHandling( - v -> v.doubleValue() < -1 || v.doubleValue() > 1 ? ExprNullValue.of() : - new ExprDoubleValue(Math.acos(v.doubleValue()))), - DOUBLE, type)).collect(Collectors.toList())); - } - - /** - * Definition of asin(x) function. - * Calculates the arc sine of x, that is, the value whose sine is x. - * Returns NULL if x is not in the range -1 to 1. - * The supported signature of asin function is - * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver asin() { - return define(BuiltinFunctionName.ASIN.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> impl(nullMissingHandling( - v -> v.doubleValue() < -1 || v.doubleValue() > 1 ? ExprNullValue.of() : - new ExprDoubleValue(Math.asin(v.doubleValue()))), - DOUBLE, type)).collect(Collectors.toList())); - } - - /** - * Definition of atan(x) and atan(y, x) function. - * atan(x) calculates the arc tangent of x, that is, the value whose tangent is x. - * atan(y, x) calculates the arc tangent of y / x, except that the signs of both arguments - * are used to determine the quadrant of the result. - * The supported signature of atan function is - * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) -> DOUBLE - */ - private static DefaultFunctionResolver atan() { - ImmutableList.Builder>> builder = new ImmutableList.Builder<>(); - - for (ExprType type : ExprCoreType.numberTypes()) { - builder.add(impl(nullMissingHandling(x -> new ExprDoubleValue(Math.atan(x.doubleValue()))), - type, DOUBLE)); - builder.add(impl(nullMissingHandling((y, x) -> new ExprDoubleValue(Math.atan2(y.doubleValue(), - x.doubleValue()))), - DOUBLE, type, type)); - } - - return define(BuiltinFunctionName.ATAN.getName(), builder.build()); - } - - /** - * Definition of atan2(y, x) function. - * Calculates the arc tangent of y / x, except that the signs of both arguments - * are used to determine the quadrant of the result. - * The supported signature of atan2 function is - * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) -> DOUBLE - */ - private static DefaultFunctionResolver atan2() { - ImmutableList.Builder>> builder = new ImmutableList.Builder<>(); - - for (ExprType type : ExprCoreType.numberTypes()) { - builder.add(impl(nullMissingHandling((y, x) -> new ExprDoubleValue(Math.atan2(y.doubleValue(), - x.doubleValue()))), DOUBLE, type, type)); - } - - return define(BuiltinFunctionName.ATAN2.getName(), builder.build()); - } - - /** - * Definition of cos(x) function. - * Calculates the cosine of X, where X is given in radians - * The supported signature of cos function is - * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver cos() { - return baseMathFunction(BuiltinFunctionName.COS.getName(), - v -> new ExprDoubleValue(Math.cos(v.doubleValue())), DOUBLE); - } - - /** - * Definition of cosh(x) function. - * Returns the hyperbolic cosine of x, defined as (((e^x) + (e^(-x))) / 2) - * The supported signature is - * BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver cosh() { - return baseMathFunction(BuiltinFunctionName.COSH.getName(), - v -> new ExprDoubleValue(Math.cosh(v.doubleValue())), DOUBLE); - } - - /** - * Definition of cot(x) function. - * Calculates the cotangent of x - * The supported signature of cot function is - * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver cot() { - return define(BuiltinFunctionName.COT.getName(), - ExprCoreType.numberTypes().stream() - .map(type -> impl(nullMissingHandling( - v -> { - Double value = v.doubleValue(); - if (value == 0) { - throw new ArithmeticException( - String.format("Out of range value for cot(%s)", value)); - } - return new ExprDoubleValue(1 / Math.tan(value)); - }), - DOUBLE, type)).collect(Collectors.toList())); - } - - /** - * Definition of degrees(x) function. - * Converts x from radians to degrees - * The supported signature of degrees function is - * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver degrees() { - return baseMathFunction(BuiltinFunctionName.DEGREES.getName(), - v -> new ExprDoubleValue(Math.toDegrees(v.doubleValue())), DOUBLE); - } - - /** - * Definition of radians(x) function. - * Converts x from degrees to radians - * The supported signature of radians function is - * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver radians() { - return baseMathFunction(BuiltinFunctionName.RADIANS.getName(), - v -> new ExprDoubleValue(Math.toRadians(v.doubleValue())), DOUBLE); - } - - /** - * Definition of sin(x) function. - * Calculates the sine of x, where x is given in radians - * The supported signature of sin function is - * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver sin() { - return baseMathFunction(BuiltinFunctionName.SIN.getName(), - v -> new ExprDoubleValue(Math.sin(v.doubleValue())), DOUBLE); - } - - /** - * Definition of tan(x) function. - * Calculates the tangent of x, where x is given in radians - * The supported signature of tan function is - * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE - */ - private static DefaultFunctionResolver tan() { - return baseMathFunction(BuiltinFunctionName.TAN.getName(), - v -> new ExprDoubleValue(Math.tan(v.doubleValue())), DOUBLE); - } + /** + * Register Mathematical Functions. + * + * @param repository {@link BuiltinFunctionRepository}. + */ + public static void register(BuiltinFunctionRepository repository) { + repository.register(abs()); + repository.register(acos()); + repository.register(asin()); + repository.register(atan()); + repository.register(atan2()); + repository.register(cbrt()); + repository.register(ceil()); + repository.register(ceiling()); + repository.register(conv()); + repository.register(cos()); + repository.register(cosh()); + repository.register(cot()); + repository.register(crc32()); + repository.register(degrees()); + repository.register(euler()); + repository.register(exp()); + repository.register(expm1()); + repository.register(floor()); + repository.register(ln()); + repository.register(log()); + repository.register(log10()); + repository.register(log2()); + repository.register(mod()); + repository.register(pi()); + repository.register(pow()); + repository.register(power()); + repository.register(radians()); + repository.register(rand()); + repository.register(rint()); + repository.register(round()); + repository.register(sign()); + repository.register(signum()); + repository.register(sin()); + repository.register(sinh()); + repository.register(sqrt()); + repository.register(tan()); + repository.register(truncate()); + } + + /** + * Base function for math functions with similar formats that return DOUBLE. + * + * @param functionName BuiltinFunctionName of math function. + * @param formula lambda function of math formula. + * @param returnType data type return type of the calling function + * @return DefaultFunctionResolver for math functions. + */ + private static DefaultFunctionResolver baseMathFunction( + FunctionName functionName, + SerializableFunction formula, + ExprCoreType returnType + ) { + return define( + functionName, + ExprCoreType.numberTypes() + .stream() + .map(type -> impl(nullMissingHandling(formula), returnType, type)) + .collect(Collectors.toList()) + ); + } + + /** + * Definition of abs() function. The supported signature of abs() function are INT -> INT LONG -> + * LONG FLOAT -> FLOAT DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver abs() { + return define( + BuiltinFunctionName.ABS.getName(), + impl(nullMissingHandling(v -> new ExprByteValue(Math.abs(v.byteValue()))), BYTE, BYTE), + impl(nullMissingHandling(v -> new ExprShortValue(Math.abs(v.shortValue()))), SHORT, SHORT), + impl(nullMissingHandling(v -> new ExprIntegerValue(Math.abs(v.integerValue()))), INTEGER, INTEGER), + impl(nullMissingHandling(v -> new ExprLongValue(Math.abs(v.longValue()))), LONG, LONG), + impl(nullMissingHandling(v -> new ExprFloatValue(Math.abs(v.floatValue()))), FLOAT, FLOAT), + impl(nullMissingHandling(v -> new ExprDoubleValue(Math.abs(v.doubleValue()))), DOUBLE, DOUBLE) + ); + } + + /** + * Definition of ceil(x)/ceiling(x) function. Calculate the next highest integer that x rounds up + * to The supported signature of ceil/ceiling function is DOUBLE -> INTEGER + */ + private static DefaultFunctionResolver ceil() { + return define( + BuiltinFunctionName.CEIL.getName(), + impl(nullMissingHandling(v -> new ExprLongValue(Math.ceil(v.doubleValue()))), LONG, DOUBLE) + ); + } + + private static DefaultFunctionResolver ceiling() { + return define( + BuiltinFunctionName.CEILING.getName(), + impl(nullMissingHandling(v -> new ExprLongValue(Math.ceil(v.doubleValue()))), LONG, DOUBLE) + ); + } + + /** + * Definition of conv(x, a, b) function. + * Convert number x from base a to base b + * The supported signature of floor function is + * (STRING, INTEGER, INTEGER) -> STRING + * (INTEGER, INTEGER, INTEGER) -> STRING + */ + private static DefaultFunctionResolver conv() { + return define( + BuiltinFunctionName.CONV.getName(), + impl( + nullMissingHandling( + (x, a, b) -> new ExprStringValue( + Integer.toString(Integer.parseInt(x.stringValue(), a.integerValue()), b.integerValue()) + ) + ), + STRING, + STRING, + INTEGER, + INTEGER + ), + impl( + nullMissingHandling( + (x, a, b) -> new ExprStringValue( + Integer.toString(Integer.parseInt(x.integerValue().toString(), a.integerValue()), b.integerValue()) + ) + ), + STRING, + INTEGER, + INTEGER, + INTEGER + ) + ); + } + + /** + * Definition of crc32(x) function. + * Calculate a cyclic redundancy check value and returns a 32-bit unsigned value + * The supported signature of crc32 function is + * STRING -> LONG + */ + private static DefaultFunctionResolver crc32() { + return define(BuiltinFunctionName.CRC32.getName(), impl(nullMissingHandling(v -> { + CRC32 crc = new CRC32(); + crc.update(v.stringValue().getBytes()); + return new ExprLongValue(crc.getValue()); + }), LONG, STRING)); + } + + /** + * Definition of e() function. + * Get the Euler's number. + * () -> DOUBLE + */ + private static DefaultFunctionResolver euler() { + return define(BuiltinFunctionName.E.getName(), impl(() -> new ExprDoubleValue(Math.E), DOUBLE)); + } + + /** + * Definition of exp(x) function. Calculate exponent function e to the x + * The supported signature of exp function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver exp() { + return baseMathFunction(BuiltinFunctionName.EXP.getName(), v -> new ExprDoubleValue(Math.exp(v.doubleValue())), DOUBLE); + } + + /** + * Definition of expm1(x) function. Calculate exponent function e to the x, minus 1 + * The supported signature of exp function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver expm1() { + return baseMathFunction(BuiltinFunctionName.EXPM1.getName(), v -> new ExprDoubleValue(Math.expm1(v.doubleValue())), DOUBLE); + } + + /** + * Definition of floor(x) function. Calculate the next nearest whole integer that x rounds down to + * The supported signature of floor function is DOUBLE -> INTEGER + */ + private static DefaultFunctionResolver floor() { + return define( + BuiltinFunctionName.FLOOR.getName(), + impl(nullMissingHandling(v -> new ExprLongValue(Math.floor(v.doubleValue()))), LONG, DOUBLE) + ); + } + + /** + * Definition of ln(x) function. Calculate the natural logarithm of x The supported signature of + * ln function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver ln() { + return baseMathFunction( + BuiltinFunctionName.LN.getName(), + v -> v.doubleValue() <= 0 ? ExprNullValue.of() : new ExprDoubleValue(Math.log(v.doubleValue())), + DOUBLE + ); + } + + /** + * Definition of log(b, x) function. Calculate the logarithm of x using b as the base The + * supported signature of log function is (b: INTEGER/LONG/FLOAT/DOUBLE, x: + * INTEGER/LONG/FLOAT/DOUBLE]) -> DOUBLE + */ + private static DefaultFunctionResolver log() { + ImmutableList.Builder>> builder = + new ImmutableList.Builder<>(); + + // build unary log(x), SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + for (ExprType type : ExprCoreType.numberTypes()) { + builder.add( + impl( + nullMissingHandling(v -> v.doubleValue() <= 0 ? ExprNullValue.of() : new ExprDoubleValue(Math.log(v.doubleValue()))), + DOUBLE, + type + ) + ); + } + + // build binary function log(b, x) + for (ExprType baseType : ExprCoreType.numberTypes()) { + for (ExprType numberType : ExprCoreType.numberTypes()) { + builder.add( + impl( + nullMissingHandling( + (b, x) -> b.doubleValue() <= 0 || x.doubleValue() <= 0 + ? ExprNullValue.of() + : new ExprDoubleValue(Math.log(x.doubleValue()) / Math.log(b.doubleValue())) + ), + DOUBLE, + baseType, + numberType + ) + ); + } + } + return define(BuiltinFunctionName.LOG.getName(), builder.build()); + } + + /** + * Definition of log10(x) function. Calculate base-10 logarithm of x The supported signature of + * log function is SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver log10() { + return baseMathFunction( + BuiltinFunctionName.LOG10.getName(), + v -> v.doubleValue() <= 0 ? ExprNullValue.of() : new ExprDoubleValue(Math.log10(v.doubleValue())), + DOUBLE + ); + } + + /** + * Definition of log2(x) function. Calculate base-2 logarithm of x The supported signature of log + * function is SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver log2() { + return baseMathFunction( + BuiltinFunctionName.LOG2.getName(), + v -> v.doubleValue() <= 0 ? ExprNullValue.of() : new ExprDoubleValue(Math.log(v.doubleValue()) / Math.log(2)), + DOUBLE + ); + } + + /** + * Definition of mod(x, y) function. + * Calculate the remainder of x divided by y + * The supported signature of mod function is + * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) + * -> wider type between types of x and y + */ + private static DefaultFunctionResolver mod() { + return define( + BuiltinFunctionName.MOD.getName(), + impl( + nullMissingHandling( + (v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() : new ExprByteValue(v1.byteValue() % v2.byteValue()) + ), + BYTE, + BYTE, + BYTE + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : new ExprShortValue(v1.shortValue() % v2.shortValue()) + ), + SHORT, + SHORT, + SHORT + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.shortValue() == 0 + ? ExprNullValue.of() + : new ExprIntegerValue(Math.floorMod(v1.integerValue(), v2.integerValue())) + ), + INTEGER, + INTEGER, + INTEGER + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : new ExprLongValue(Math.floorMod(v1.longValue(), v2.longValue())) + ), + LONG, + LONG, + LONG + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : new ExprFloatValue(v1.floatValue() % v2.floatValue()) + ), + FLOAT, + FLOAT, + FLOAT + ), + impl( + nullMissingHandling( + (v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() : new ExprDoubleValue(v1.doubleValue() % v2.doubleValue()) + ), + DOUBLE, + DOUBLE, + DOUBLE + ) + ); + } + + /** + * Definition of pi() function. + * Get the value of pi. + * () -> DOUBLE + */ + private static DefaultFunctionResolver pi() { + return define(BuiltinFunctionName.PI.getName(), impl(() -> new ExprDoubleValue(Math.PI), DOUBLE)); + } + + /** + * Definition of pow(x, y)/power(x, y) function. + * Calculate the value of x raised to the power of y + * The supported signature of pow/power function is + * (INTEGER, INTEGER) -> DOUBLE + * (LONG, LONG) -> DOUBLE + * (FLOAT, FLOAT) -> DOUBLE + * (DOUBLE, DOUBLE) -> DOUBLE + */ + private static DefaultFunctionResolver pow() { + return define(BuiltinFunctionName.POW.getName(), powerFunctionImpl()); + } + + private static DefaultFunctionResolver power() { + return define(BuiltinFunctionName.POWER.getName(), powerFunctionImpl()); + } + + private List>> powerFunctionImpl() { + return Arrays.asList( + impl(nullMissingHandling((v1, v2) -> new ExprDoubleValue(Math.pow(v1.shortValue(), v2.shortValue()))), DOUBLE, SHORT, SHORT), + impl( + nullMissingHandling((v1, v2) -> new ExprDoubleValue(Math.pow(v1.integerValue(), v2.integerValue()))), + DOUBLE, + INTEGER, + INTEGER + ), + impl(nullMissingHandling((v1, v2) -> new ExprDoubleValue(Math.pow(v1.longValue(), v2.longValue()))), DOUBLE, LONG, LONG), + impl( + nullMissingHandling( + (v1, v2) -> v1.floatValue() <= 0 && v2.floatValue() != Math.floor(v2.floatValue()) + ? ExprNullValue.of() + : new ExprDoubleValue(Math.pow(v1.floatValue(), v2.floatValue())) + ), + DOUBLE, + FLOAT, + FLOAT + ), + impl( + nullMissingHandling( + (v1, v2) -> v1.doubleValue() <= 0 && v2.doubleValue() != Math.floor(v2.doubleValue()) + ? ExprNullValue.of() + : new ExprDoubleValue(Math.pow(v1.doubleValue(), v2.doubleValue())) + ), + DOUBLE, + DOUBLE, + DOUBLE + ) + ); + } + + /** + * Definition of rand() and rand(N) function. + * rand() returns a random floating-point value in the range 0 <= value < 1.0 + * If integer N is specified, the seed is initialized prior to execution. + * One implication of this behavior is with identical argument N,rand(N) returns the same value + * each time, and thus produces a repeatable sequence of column values. + * The supported signature of rand function is + * ([INTEGER]) -> FLOAT + */ + private static DefaultFunctionResolver rand() { + return define( + BuiltinFunctionName.RAND.getName(), + impl(() -> new ExprFloatValue(new Random().nextFloat()), FLOAT), + impl(nullMissingHandling(v -> new ExprFloatValue(new Random(v.integerValue()).nextFloat())), FLOAT, INTEGER) + ); + } + + /** + * Definition of rint(x) function. + * Returns the closest whole integer value to x + * The supported signature is + * BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver rint() { + return baseMathFunction(BuiltinFunctionName.RINT.getName(), v -> new ExprDoubleValue(Math.rint(v.doubleValue())), DOUBLE); + } + + /** + * Definition of round(x)/round(x, d) function. + * Rounds the argument x to d decimal places, d defaults to 0 if not specified. + * The supported signature of round function is + * (x: INTEGER [, y: INTEGER]) -> INTEGER + * (x: LONG [, y: INTEGER]) -> LONG + * (x: FLOAT [, y: INTEGER]) -> FLOAT + * (x: DOUBLE [, y: INTEGER]) -> DOUBLE + */ + private static DefaultFunctionResolver round() { + return define( + BuiltinFunctionName.ROUND.getName(), + // rand(x) + impl(nullMissingHandling(v -> new ExprLongValue((long) Math.round(v.integerValue()))), LONG, INTEGER), + impl(nullMissingHandling(v -> new ExprLongValue((long) Math.round(v.longValue()))), LONG, LONG), + impl(nullMissingHandling(v -> new ExprDoubleValue((double) Math.round(v.floatValue()))), DOUBLE, FLOAT), + impl( + nullMissingHandling( + v -> new ExprDoubleValue(new BigDecimal(v.doubleValue()).setScale(0, RoundingMode.HALF_UP).doubleValue()) + ), + DOUBLE, + DOUBLE + ), + + // rand(x, d) + impl( + nullMissingHandling( + (x, d) -> new ExprLongValue( + new BigDecimal(x.integerValue()).setScale(d.integerValue(), RoundingMode.HALF_UP).longValue() + ) + ), + LONG, + INTEGER, + INTEGER + ), + impl( + nullMissingHandling( + (x, d) -> new ExprLongValue(new BigDecimal(x.longValue()).setScale(d.integerValue(), RoundingMode.HALF_UP).longValue()) + ), + LONG, + LONG, + INTEGER + ), + impl( + nullMissingHandling( + (x, d) -> new ExprDoubleValue( + new BigDecimal(x.floatValue()).setScale(d.integerValue(), RoundingMode.HALF_UP).doubleValue() + ) + ), + DOUBLE, + FLOAT, + INTEGER + ), + impl( + nullMissingHandling( + (x, d) -> new ExprDoubleValue( + new BigDecimal(x.doubleValue()).setScale(d.integerValue(), RoundingMode.HALF_UP).doubleValue() + ) + ), + DOUBLE, + DOUBLE, + INTEGER + ) + ); + } + + /** + * Definition of sign(x) function. + * Returns the sign of the argument as -1, 0, or 1 + * depending on whether x is negative, zero, or positive + * The supported signature is + * SHORT/INTEGER/LONG/FLOAT/DOUBLE -> INTEGER + */ + private static DefaultFunctionResolver sign() { + return baseMathFunction(BuiltinFunctionName.SIGN.getName(), v -> new ExprIntegerValue(Math.signum(v.doubleValue())), INTEGER); + } + + /** + * Definition of signum(x) function. + * Returns the sign of the argument as -1.0, 0, or 1.0 + * depending on whether x is negative, zero, or positive + * The supported signature is + * BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -> INTEGER + */ + private static DefaultFunctionResolver signum() { + return baseMathFunction(BuiltinFunctionName.SIGNUM.getName(), v -> new ExprIntegerValue(Math.signum(v.doubleValue())), INTEGER); + } + + /** + * Definition of sinh(x) function. + * Returns the hyperbolix sine of x, defined as (((e^x) - (e^(-x))) / 2) + * The supported signature is + * BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver sinh() { + return baseMathFunction(BuiltinFunctionName.SINH.getName(), v -> new ExprDoubleValue(Math.sinh(v.doubleValue())), DOUBLE); + } + + /** + * Definition of sqrt(x) function. + * Calculate the square root of a non-negative number x + * The supported signature is + * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver sqrt() { + return baseMathFunction( + BuiltinFunctionName.SQRT.getName(), + v -> v.doubleValue() < 0 ? ExprNullValue.of() : new ExprDoubleValue(Math.sqrt(v.doubleValue())), + DOUBLE + ); + } + + /** + * Definition of cbrt(x) function. + * Calculate the cube root of a number x + * The supported signature is + * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver cbrt() { + return baseMathFunction(BuiltinFunctionName.CBRT.getName(), v -> new ExprDoubleValue(Math.cbrt(v.doubleValue())), DOUBLE); + } + + /** + * Definition of truncate(x, d) function. + * Returns the number x, truncated to d decimal places + * The supported signature of round function is + * (x: INTEGER, y: INTEGER) -> LONG + * (x: LONG, y: INTEGER) -> LONG + * (x: FLOAT, y: INTEGER) -> DOUBLE + * (x: DOUBLE, y: INTEGER) -> DOUBLE + */ + private static DefaultFunctionResolver truncate() { + return define( + BuiltinFunctionName.TRUNCATE.getName(), + impl( + nullMissingHandling( + (x, y) -> new ExprLongValue( + BigDecimal.valueOf(x.integerValue()).setScale(y.integerValue(), RoundingMode.DOWN).longValue() + ) + ), + LONG, + INTEGER, + INTEGER + ), + impl( + nullMissingHandling( + (x, y) -> new ExprLongValue(BigDecimal.valueOf(x.longValue()).setScale(y.integerValue(), RoundingMode.DOWN).longValue()) + ), + LONG, + LONG, + INTEGER + ), + impl( + nullMissingHandling( + (x, y) -> new ExprDoubleValue( + BigDecimal.valueOf(x.floatValue()).setScale(y.integerValue(), RoundingMode.DOWN).doubleValue() + ) + ), + DOUBLE, + FLOAT, + INTEGER + ), + impl( + nullMissingHandling( + (x, y) -> new ExprDoubleValue( + BigDecimal.valueOf(x.doubleValue()).setScale(y.integerValue(), RoundingMode.DOWN).doubleValue() + ) + ), + DOUBLE, + DOUBLE, + INTEGER + ) + ); + } + + /** + * Definition of acos(x) function. + * Calculates the arc cosine of x, that is, the value whose cosine is x. + * Returns NULL if x is not in the range -1 to 1. + * The supported signature of acos function is + * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver acos() { + return define( + BuiltinFunctionName.ACOS.getName(), + ExprCoreType.numberTypes() + .stream() + .map( + type -> impl( + nullMissingHandling( + v -> v.doubleValue() < -1 || v.doubleValue() > 1 + ? ExprNullValue.of() + : new ExprDoubleValue(Math.acos(v.doubleValue())) + ), + DOUBLE, + type + ) + ) + .collect(Collectors.toList()) + ); + } + + /** + * Definition of asin(x) function. + * Calculates the arc sine of x, that is, the value whose sine is x. + * Returns NULL if x is not in the range -1 to 1. + * The supported signature of asin function is + * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver asin() { + return define( + BuiltinFunctionName.ASIN.getName(), + ExprCoreType.numberTypes() + .stream() + .map( + type -> impl( + nullMissingHandling( + v -> v.doubleValue() < -1 || v.doubleValue() > 1 + ? ExprNullValue.of() + : new ExprDoubleValue(Math.asin(v.doubleValue())) + ), + DOUBLE, + type + ) + ) + .collect(Collectors.toList()) + ); + } + + /** + * Definition of atan(x) and atan(y, x) function. + * atan(x) calculates the arc tangent of x, that is, the value whose tangent is x. + * atan(y, x) calculates the arc tangent of y / x, except that the signs of both arguments + * are used to determine the quadrant of the result. + * The supported signature of atan function is + * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) -> DOUBLE + */ + private static DefaultFunctionResolver atan() { + ImmutableList.Builder>> builder = + new ImmutableList.Builder<>(); + + for (ExprType type : ExprCoreType.numberTypes()) { + builder.add(impl(nullMissingHandling(x -> new ExprDoubleValue(Math.atan(x.doubleValue()))), type, DOUBLE)); + builder.add( + impl(nullMissingHandling((y, x) -> new ExprDoubleValue(Math.atan2(y.doubleValue(), x.doubleValue()))), DOUBLE, type, type) + ); + } + + return define(BuiltinFunctionName.ATAN.getName(), builder.build()); + } + + /** + * Definition of atan2(y, x) function. + * Calculates the arc tangent of y / x, except that the signs of both arguments + * are used to determine the quadrant of the result. + * The supported signature of atan2 function is + * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) -> DOUBLE + */ + private static DefaultFunctionResolver atan2() { + ImmutableList.Builder>> builder = + new ImmutableList.Builder<>(); + + for (ExprType type : ExprCoreType.numberTypes()) { + builder.add( + impl(nullMissingHandling((y, x) -> new ExprDoubleValue(Math.atan2(y.doubleValue(), x.doubleValue()))), DOUBLE, type, type) + ); + } + + return define(BuiltinFunctionName.ATAN2.getName(), builder.build()); + } + + /** + * Definition of cos(x) function. + * Calculates the cosine of X, where X is given in radians + * The supported signature of cos function is + * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver cos() { + return baseMathFunction(BuiltinFunctionName.COS.getName(), v -> new ExprDoubleValue(Math.cos(v.doubleValue())), DOUBLE); + } + + /** + * Definition of cosh(x) function. + * Returns the hyperbolic cosine of x, defined as (((e^x) + (e^(-x))) / 2) + * The supported signature is + * BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver cosh() { + return baseMathFunction(BuiltinFunctionName.COSH.getName(), v -> new ExprDoubleValue(Math.cosh(v.doubleValue())), DOUBLE); + } + + /** + * Definition of cot(x) function. + * Calculates the cotangent of x + * The supported signature of cot function is + * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver cot() { + return define(BuiltinFunctionName.COT.getName(), ExprCoreType.numberTypes().stream().map(type -> impl(nullMissingHandling(v -> { + Double value = v.doubleValue(); + if (value == 0) { + throw new ArithmeticException(String.format("Out of range value for cot(%s)", value)); + } + return new ExprDoubleValue(1 / Math.tan(value)); + }), DOUBLE, type)).collect(Collectors.toList())); + } + + /** + * Definition of degrees(x) function. + * Converts x from radians to degrees + * The supported signature of degrees function is + * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver degrees() { + return baseMathFunction(BuiltinFunctionName.DEGREES.getName(), v -> new ExprDoubleValue(Math.toDegrees(v.doubleValue())), DOUBLE); + } + + /** + * Definition of radians(x) function. + * Converts x from degrees to radians + * The supported signature of radians function is + * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver radians() { + return baseMathFunction(BuiltinFunctionName.RADIANS.getName(), v -> new ExprDoubleValue(Math.toRadians(v.doubleValue())), DOUBLE); + } + + /** + * Definition of sin(x) function. + * Calculates the sine of x, where x is given in radians + * The supported signature of sin function is + * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver sin() { + return baseMathFunction(BuiltinFunctionName.SIN.getName(), v -> new ExprDoubleValue(Math.sin(v.doubleValue())), DOUBLE); + } + + /** + * Definition of tan(x) function. + * Calculates the tangent of x, where x is given in radians + * The supported signature of tan function is + * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver tan() { + return baseMathFunction(BuiltinFunctionName.TAN.getName(), v -> new ExprDoubleValue(Math.tan(v.doubleValue())), DOUBLE); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java b/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java index d3295a53f0..dc714e7e87 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.operator.convert; import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; @@ -46,166 +45,147 @@ @UtilityClass public class TypeCastOperator { - /** - * Register Type Cast Operator. - */ - public static void register(BuiltinFunctionRepository repository) { - repository.register(castToString()); - repository.register(castToByte()); - repository.register(castToShort()); - repository.register(castToInt()); - repository.register(castToLong()); - repository.register(castToFloat()); - repository.register(castToDouble()); - repository.register(castToBoolean()); - repository.register(castToDate()); - repository.register(castToTime()); - repository.register(castToTimestamp()); - repository.register(castToDatetime()); - } - - - private static DefaultFunctionResolver castToString() { - return FunctionDSL.define(BuiltinFunctionName.CAST_TO_STRING.getName(), - Stream.concat( - Arrays.asList(BYTE, SHORT, INTEGER, LONG, FLOAT, DOUBLE, BOOLEAN, TIME, DATE, - TIMESTAMP, DATETIME).stream() - .map(type -> impl( - nullMissingHandling((v) -> new ExprStringValue(v.value().toString())), - STRING, type)), - Stream.of(impl(nullMissingHandling((v) -> v), STRING, STRING))) - .collect(Collectors.toList()) - ); - } - - private static DefaultFunctionResolver castToByte() { - return FunctionDSL.define(BuiltinFunctionName.CAST_TO_BYTE.getName(), - impl(nullMissingHandling( - (v) -> new ExprByteValue(Byte.valueOf(v.stringValue()))), BYTE, STRING), - impl(nullMissingHandling( - (v) -> new ExprByteValue(v.byteValue())), BYTE, DOUBLE), - impl(nullMissingHandling( - (v) -> new ExprByteValue(v.booleanValue() ? 1 : 0)), BYTE, BOOLEAN) - ); - } - - private static DefaultFunctionResolver castToShort() { - return FunctionDSL.define(BuiltinFunctionName.CAST_TO_SHORT.getName(), - impl(nullMissingHandling( - (v) -> new ExprShortValue(Short.valueOf(v.stringValue()))), SHORT, STRING), - impl(nullMissingHandling( - (v) -> new ExprShortValue(v.shortValue())), SHORT, DOUBLE), - impl(nullMissingHandling( - (v) -> new ExprShortValue(v.booleanValue() ? 1 : 0)), SHORT, BOOLEAN) - ); - } - - private static DefaultFunctionResolver castToInt() { - return FunctionDSL.define(BuiltinFunctionName.CAST_TO_INT.getName(), - impl(nullMissingHandling( - (v) -> new ExprIntegerValue(Integer.valueOf(v.stringValue()))), INTEGER, STRING), - impl(nullMissingHandling( - (v) -> new ExprIntegerValue(v.integerValue())), INTEGER, DOUBLE), - impl(nullMissingHandling( - (v) -> new ExprIntegerValue(v.booleanValue() ? 1 : 0)), INTEGER, BOOLEAN) - ); - } - - private static DefaultFunctionResolver castToLong() { - return FunctionDSL.define(BuiltinFunctionName.CAST_TO_LONG.getName(), - impl(nullMissingHandling( - (v) -> new ExprLongValue(Long.valueOf(v.stringValue()))), LONG, STRING), - impl(nullMissingHandling( - (v) -> new ExprLongValue(v.longValue())), LONG, DOUBLE), - impl(nullMissingHandling( - (v) -> new ExprLongValue(v.booleanValue() ? 1L : 0L)), LONG, BOOLEAN) - ); - } - - private static DefaultFunctionResolver castToFloat() { - return FunctionDSL.define(BuiltinFunctionName.CAST_TO_FLOAT.getName(), - impl(nullMissingHandling( - (v) -> new ExprFloatValue(Float.valueOf(v.stringValue()))), FLOAT, STRING), - impl(nullMissingHandling( - (v) -> new ExprFloatValue(v.floatValue())), FLOAT, DOUBLE), - impl(nullMissingHandling( - (v) -> new ExprFloatValue(v.booleanValue() ? 1f : 0f)), FLOAT, BOOLEAN) - ); - } - - private static DefaultFunctionResolver castToDouble() { - return FunctionDSL.define(BuiltinFunctionName.CAST_TO_DOUBLE.getName(), - impl(nullMissingHandling( - (v) -> new ExprDoubleValue(Double.valueOf(v.stringValue()))), DOUBLE, STRING), - impl(nullMissingHandling( - (v) -> new ExprDoubleValue(v.doubleValue())), DOUBLE, DOUBLE), - impl(nullMissingHandling( - (v) -> new ExprDoubleValue(v.booleanValue() ? 1D : 0D)), DOUBLE, BOOLEAN) - ); - } - - private static DefaultFunctionResolver castToBoolean() { - return FunctionDSL.define(BuiltinFunctionName.CAST_TO_BOOLEAN.getName(), - impl(nullMissingHandling( - (v) -> ExprBooleanValue.of(Boolean.valueOf(v.stringValue()))), BOOLEAN, STRING), - impl(nullMissingHandling( - (v) -> ExprBooleanValue.of(v.doubleValue() != 0)), BOOLEAN, DOUBLE), - impl(nullMissingHandling((v) -> v), BOOLEAN, BOOLEAN) - ); - } - - private static DefaultFunctionResolver castToDate() { - return FunctionDSL.define(BuiltinFunctionName.CAST_TO_DATE.getName(), - impl(nullMissingHandling( - (v) -> new ExprDateValue(v.stringValue())), DATE, STRING), - impl(nullMissingHandling( - (v) -> new ExprDateValue(v.dateValue())), DATE, DATETIME), - impl(nullMissingHandling( - (v) -> new ExprDateValue(v.dateValue())), DATE, TIMESTAMP), - impl(nullMissingHandling((v) -> v), DATE, DATE) - ); - } - - private static DefaultFunctionResolver castToTime() { - return FunctionDSL.define(BuiltinFunctionName.CAST_TO_TIME.getName(), - impl(nullMissingHandling( - (v) -> new ExprTimeValue(v.stringValue())), TIME, STRING), - impl(nullMissingHandling( - (v) -> new ExprTimeValue(v.timeValue())), TIME, DATETIME), - impl(nullMissingHandling( - (v) -> new ExprTimeValue(v.timeValue())), TIME, TIMESTAMP), - impl(nullMissingHandling((v) -> v), TIME, TIME) - ); - } - - // `DATE`/`TIME`/`DATETIME` -> `DATETIME`/TIMESTAMP` cast tested in BinaryPredicateOperatorTest - private static DefaultFunctionResolver castToTimestamp() { - return FunctionDSL.define(BuiltinFunctionName.CAST_TO_TIMESTAMP.getName(), - impl(nullMissingHandling( - (v) -> new ExprTimestampValue(v.stringValue())), TIMESTAMP, STRING), - impl(nullMissingHandling( - (v) -> new ExprTimestampValue(v.timestampValue())), TIMESTAMP, DATETIME), - impl(nullMissingHandling( - (v) -> new ExprTimestampValue(v.timestampValue())), TIMESTAMP, DATE), - implWithProperties(nullMissingHandlingWithProperties( - (fp, v) -> new ExprTimestampValue(((ExprTimeValue)v).timestampValue(fp))), - TIMESTAMP, TIME), - impl(nullMissingHandling((v) -> v), TIMESTAMP, TIMESTAMP) - ); - } - - private static DefaultFunctionResolver castToDatetime() { - return FunctionDSL.define(BuiltinFunctionName.CAST_TO_DATETIME.getName(), - impl(nullMissingHandling( - (v) -> new ExprDatetimeValue(v.stringValue())), DATETIME, STRING), - impl(nullMissingHandling( - (v) -> new ExprDatetimeValue(v.datetimeValue())), DATETIME, TIMESTAMP), - impl(nullMissingHandling( - (v) -> new ExprDatetimeValue(v.datetimeValue())), DATETIME, DATE), - implWithProperties(nullMissingHandlingWithProperties( - (fp, v) -> new ExprDatetimeValue(((ExprTimeValue)v).datetimeValue(fp))), - DATETIME, TIME), - impl(nullMissingHandling((v) -> v), DATETIME, DATETIME) - ); - } + /** + * Register Type Cast Operator. + */ + public static void register(BuiltinFunctionRepository repository) { + repository.register(castToString()); + repository.register(castToByte()); + repository.register(castToShort()); + repository.register(castToInt()); + repository.register(castToLong()); + repository.register(castToFloat()); + repository.register(castToDouble()); + repository.register(castToBoolean()); + repository.register(castToDate()); + repository.register(castToTime()); + repository.register(castToTimestamp()); + repository.register(castToDatetime()); + } + + private static DefaultFunctionResolver castToString() { + return FunctionDSL.define( + BuiltinFunctionName.CAST_TO_STRING.getName(), + Stream.concat( + Arrays.asList(BYTE, SHORT, INTEGER, LONG, FLOAT, DOUBLE, BOOLEAN, TIME, DATE, TIMESTAMP, DATETIME) + .stream() + .map(type -> impl(nullMissingHandling((v) -> new ExprStringValue(v.value().toString())), STRING, type)), + Stream.of(impl(nullMissingHandling((v) -> v), STRING, STRING)) + ).collect(Collectors.toList()) + ); + } + + private static DefaultFunctionResolver castToByte() { + return FunctionDSL.define( + BuiltinFunctionName.CAST_TO_BYTE.getName(), + impl(nullMissingHandling((v) -> new ExprByteValue(Byte.valueOf(v.stringValue()))), BYTE, STRING), + impl(nullMissingHandling((v) -> new ExprByteValue(v.byteValue())), BYTE, DOUBLE), + impl(nullMissingHandling((v) -> new ExprByteValue(v.booleanValue() ? 1 : 0)), BYTE, BOOLEAN) + ); + } + + private static DefaultFunctionResolver castToShort() { + return FunctionDSL.define( + BuiltinFunctionName.CAST_TO_SHORT.getName(), + impl(nullMissingHandling((v) -> new ExprShortValue(Short.valueOf(v.stringValue()))), SHORT, STRING), + impl(nullMissingHandling((v) -> new ExprShortValue(v.shortValue())), SHORT, DOUBLE), + impl(nullMissingHandling((v) -> new ExprShortValue(v.booleanValue() ? 1 : 0)), SHORT, BOOLEAN) + ); + } + + private static DefaultFunctionResolver castToInt() { + return FunctionDSL.define( + BuiltinFunctionName.CAST_TO_INT.getName(), + impl(nullMissingHandling((v) -> new ExprIntegerValue(Integer.valueOf(v.stringValue()))), INTEGER, STRING), + impl(nullMissingHandling((v) -> new ExprIntegerValue(v.integerValue())), INTEGER, DOUBLE), + impl(nullMissingHandling((v) -> new ExprIntegerValue(v.booleanValue() ? 1 : 0)), INTEGER, BOOLEAN) + ); + } + + private static DefaultFunctionResolver castToLong() { + return FunctionDSL.define( + BuiltinFunctionName.CAST_TO_LONG.getName(), + impl(nullMissingHandling((v) -> new ExprLongValue(Long.valueOf(v.stringValue()))), LONG, STRING), + impl(nullMissingHandling((v) -> new ExprLongValue(v.longValue())), LONG, DOUBLE), + impl(nullMissingHandling((v) -> new ExprLongValue(v.booleanValue() ? 1L : 0L)), LONG, BOOLEAN) + ); + } + + private static DefaultFunctionResolver castToFloat() { + return FunctionDSL.define( + BuiltinFunctionName.CAST_TO_FLOAT.getName(), + impl(nullMissingHandling((v) -> new ExprFloatValue(Float.valueOf(v.stringValue()))), FLOAT, STRING), + impl(nullMissingHandling((v) -> new ExprFloatValue(v.floatValue())), FLOAT, DOUBLE), + impl(nullMissingHandling((v) -> new ExprFloatValue(v.booleanValue() ? 1f : 0f)), FLOAT, BOOLEAN) + ); + } + + private static DefaultFunctionResolver castToDouble() { + return FunctionDSL.define( + BuiltinFunctionName.CAST_TO_DOUBLE.getName(), + impl(nullMissingHandling((v) -> new ExprDoubleValue(Double.valueOf(v.stringValue()))), DOUBLE, STRING), + impl(nullMissingHandling((v) -> new ExprDoubleValue(v.doubleValue())), DOUBLE, DOUBLE), + impl(nullMissingHandling((v) -> new ExprDoubleValue(v.booleanValue() ? 1D : 0D)), DOUBLE, BOOLEAN) + ); + } + + private static DefaultFunctionResolver castToBoolean() { + return FunctionDSL.define( + BuiltinFunctionName.CAST_TO_BOOLEAN.getName(), + impl(nullMissingHandling((v) -> ExprBooleanValue.of(Boolean.valueOf(v.stringValue()))), BOOLEAN, STRING), + impl(nullMissingHandling((v) -> ExprBooleanValue.of(v.doubleValue() != 0)), BOOLEAN, DOUBLE), + impl(nullMissingHandling((v) -> v), BOOLEAN, BOOLEAN) + ); + } + + private static DefaultFunctionResolver castToDate() { + return FunctionDSL.define( + BuiltinFunctionName.CAST_TO_DATE.getName(), + impl(nullMissingHandling((v) -> new ExprDateValue(v.stringValue())), DATE, STRING), + impl(nullMissingHandling((v) -> new ExprDateValue(v.dateValue())), DATE, DATETIME), + impl(nullMissingHandling((v) -> new ExprDateValue(v.dateValue())), DATE, TIMESTAMP), + impl(nullMissingHandling((v) -> v), DATE, DATE) + ); + } + + private static DefaultFunctionResolver castToTime() { + return FunctionDSL.define( + BuiltinFunctionName.CAST_TO_TIME.getName(), + impl(nullMissingHandling((v) -> new ExprTimeValue(v.stringValue())), TIME, STRING), + impl(nullMissingHandling((v) -> new ExprTimeValue(v.timeValue())), TIME, DATETIME), + impl(nullMissingHandling((v) -> new ExprTimeValue(v.timeValue())), TIME, TIMESTAMP), + impl(nullMissingHandling((v) -> v), TIME, TIME) + ); + } + + // `DATE`/`TIME`/`DATETIME` -> `DATETIME`/TIMESTAMP` cast tested in BinaryPredicateOperatorTest + private static DefaultFunctionResolver castToTimestamp() { + return FunctionDSL.define( + BuiltinFunctionName.CAST_TO_TIMESTAMP.getName(), + impl(nullMissingHandling((v) -> new ExprTimestampValue(v.stringValue())), TIMESTAMP, STRING), + impl(nullMissingHandling((v) -> new ExprTimestampValue(v.timestampValue())), TIMESTAMP, DATETIME), + impl(nullMissingHandling((v) -> new ExprTimestampValue(v.timestampValue())), TIMESTAMP, DATE), + implWithProperties( + nullMissingHandlingWithProperties((fp, v) -> new ExprTimestampValue(((ExprTimeValue) v).timestampValue(fp))), + TIMESTAMP, + TIME + ), + impl(nullMissingHandling((v) -> v), TIMESTAMP, TIMESTAMP) + ); + } + + private static DefaultFunctionResolver castToDatetime() { + return FunctionDSL.define( + BuiltinFunctionName.CAST_TO_DATETIME.getName(), + impl(nullMissingHandling((v) -> new ExprDatetimeValue(v.stringValue())), DATETIME, STRING), + impl(nullMissingHandling((v) -> new ExprDatetimeValue(v.datetimeValue())), DATETIME, TIMESTAMP), + impl(nullMissingHandling((v) -> new ExprDatetimeValue(v.datetimeValue())), DATETIME, DATE), + implWithProperties( + nullMissingHandlingWithProperties((fp, v) -> new ExprDatetimeValue(((ExprTimeValue) v).datetimeValue(fp))), + DATETIME, + TIME + ), + impl(nullMissingHandling((v) -> v), DATETIME, DATETIME) + ); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java index cc5b47bde1..355d09f4c9 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.operator.predicate; import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_FALSE; @@ -38,196 +37,217 @@ */ @UtilityClass public class BinaryPredicateOperator { - /** - * Register Binary Predicate Function. - * - * @param repository {@link BuiltinFunctionRepository}. - */ - public static void register(BuiltinFunctionRepository repository) { - repository.register(and()); - repository.register(or()); - repository.register(xor()); - repository.register(equal()); - repository.register(notEqual()); - repository.register(less()); - repository.register(lte()); - repository.register(greater()); - repository.register(gte()); - repository.register(like()); - repository.register(notLike()); - repository.register(regexp()); - } - - /** - * The and logic. - * A B A AND B - * TRUE TRUE TRUE - * TRUE FALSE FALSE - * TRUE NULL NULL - * TRUE MISSING MISSING - * FALSE FALSE FALSE - * FALSE NULL FALSE - * FALSE MISSING FALSE - * NULL NULL NULL - * NULL MISSING MISSING - * MISSING MISSING MISSING - */ - private static Table andTable = - new ImmutableTable.Builder() - .put(LITERAL_TRUE, LITERAL_TRUE, LITERAL_TRUE) - .put(LITERAL_TRUE, LITERAL_FALSE, LITERAL_FALSE) - .put(LITERAL_TRUE, LITERAL_NULL, LITERAL_NULL) - .put(LITERAL_TRUE, LITERAL_MISSING, LITERAL_MISSING) - .put(LITERAL_FALSE, LITERAL_FALSE, LITERAL_FALSE) - .put(LITERAL_FALSE, LITERAL_NULL, LITERAL_FALSE) - .put(LITERAL_FALSE, LITERAL_MISSING, LITERAL_FALSE) - .put(LITERAL_NULL, LITERAL_NULL, LITERAL_NULL) - .put(LITERAL_NULL, LITERAL_MISSING, LITERAL_MISSING) - .put(LITERAL_MISSING, LITERAL_MISSING, LITERAL_MISSING) - .build(); - - /** - * The or logic. - * A B A AND B - * TRUE TRUE TRUE - * TRUE FALSE TRUE - * TRUE NULL TRUE - * TRUE MISSING TRUE - * FALSE FALSE FALSE - * FALSE NULL NULL - * FALSE MISSING MISSING - * NULL NULL NULL - * NULL MISSING NULL - * MISSING MISSING MISSING - */ - private static Table orTable = - new ImmutableTable.Builder() - .put(LITERAL_TRUE, LITERAL_TRUE, LITERAL_TRUE) - .put(LITERAL_TRUE, LITERAL_FALSE, LITERAL_TRUE) - .put(LITERAL_TRUE, LITERAL_NULL, LITERAL_TRUE) - .put(LITERAL_TRUE, LITERAL_MISSING, LITERAL_TRUE) - .put(LITERAL_FALSE, LITERAL_FALSE, LITERAL_FALSE) - .put(LITERAL_FALSE, LITERAL_NULL, LITERAL_NULL) - .put(LITERAL_FALSE, LITERAL_MISSING, LITERAL_MISSING) - .put(LITERAL_NULL, LITERAL_NULL, LITERAL_NULL) - .put(LITERAL_NULL, LITERAL_MISSING, LITERAL_NULL) - .put(LITERAL_MISSING, LITERAL_MISSING, LITERAL_MISSING) - .build(); - - /** - * The xor logic. - * A B A AND B - * TRUE TRUE FALSE - * TRUE FALSE TRUE - * TRUE NULL TRUE - * TRUE MISSING TRUE - * FALSE FALSE FALSE - * FALSE NULL NULL - * FALSE MISSING MISSING - * NULL NULL NULL - * NULL MISSING NULL - * MISSING MISSING MISSING - */ - private static Table xorTable = - new ImmutableTable.Builder() - .put(LITERAL_TRUE, LITERAL_TRUE, LITERAL_FALSE) - .put(LITERAL_TRUE, LITERAL_FALSE, LITERAL_TRUE) - .put(LITERAL_TRUE, LITERAL_NULL, LITERAL_TRUE) - .put(LITERAL_TRUE, LITERAL_MISSING, LITERAL_TRUE) - .put(LITERAL_FALSE, LITERAL_FALSE, LITERAL_FALSE) - .put(LITERAL_FALSE, LITERAL_NULL, LITERAL_NULL) - .put(LITERAL_FALSE, LITERAL_MISSING, LITERAL_MISSING) - .put(LITERAL_NULL, LITERAL_NULL, LITERAL_NULL) - .put(LITERAL_NULL, LITERAL_MISSING, LITERAL_NULL) - .put(LITERAL_MISSING, LITERAL_MISSING, LITERAL_MISSING) - .build(); - - private static DefaultFunctionResolver and() { - return define(BuiltinFunctionName.AND.getName(), - impl((v1, v2) -> lookupTableFunction(v1, v2, andTable), BOOLEAN, BOOLEAN, BOOLEAN)); - } - - private static DefaultFunctionResolver or() { - return define(BuiltinFunctionName.OR.getName(), - impl((v1, v2) -> lookupTableFunction(v1, v2, orTable), BOOLEAN, BOOLEAN, BOOLEAN)); - } - - private static DefaultFunctionResolver xor() { - return define(BuiltinFunctionName.XOR.getName(), - impl((v1, v2) -> lookupTableFunction(v1, v2, xorTable), BOOLEAN, BOOLEAN, BOOLEAN)); - } - - private static DefaultFunctionResolver equal() { - return define(BuiltinFunctionName.EQUAL.getName(), ExprCoreType.coreTypes().stream() - .map(type -> impl(nullMissingHandling( - (v1, v2) -> ExprBooleanValue.of(v1.equals(v2))), - BOOLEAN, type, type)) - .collect(Collectors.toList())); - } - - private static DefaultFunctionResolver notEqual() { - return define(BuiltinFunctionName.NOTEQUAL.getName(), ExprCoreType.coreTypes().stream() - .map(type -> impl(nullMissingHandling( - (v1, v2) -> ExprBooleanValue.of(!v1.equals(v2))), - BOOLEAN, type, type)) - .collect(Collectors.toList())); - } - - private static DefaultFunctionResolver less() { - return define(BuiltinFunctionName.LESS.getName(), ExprCoreType.coreTypes().stream() - .map(type -> impl(nullMissingHandling( - (v1, v2) -> ExprBooleanValue.of(v1.compareTo(v2) < 0)), - BOOLEAN,type, type)) - .collect(Collectors.toList())); - } - - private static DefaultFunctionResolver lte() { - return define(BuiltinFunctionName.LTE.getName(), ExprCoreType.coreTypes().stream() - .map(type -> impl(nullMissingHandling( - (v1, v2) -> ExprBooleanValue.of(v1.compareTo(v2) <= 0)), - BOOLEAN, type, type)) - .collect(Collectors.toList())); - } - - private static DefaultFunctionResolver greater() { - return define(BuiltinFunctionName.GREATER.getName(), ExprCoreType.coreTypes().stream() - .map(type -> impl(nullMissingHandling( - (v1, v2) -> ExprBooleanValue.of(v1.compareTo(v2) > 0)), - BOOLEAN, type, type)) - .collect(Collectors.toList())); - } - - private static DefaultFunctionResolver gte() { - return define(BuiltinFunctionName.GTE.getName(), ExprCoreType.coreTypes().stream() - .map(type -> impl(nullMissingHandling( - (v1, v2) -> ExprBooleanValue.of(v1.compareTo(v2) >= 0)), - BOOLEAN, type, type)) - .collect(Collectors.toList())); - } - - private static DefaultFunctionResolver like() { - return define(BuiltinFunctionName.LIKE.getName(), - impl(nullMissingHandling(OperatorUtils::matches), BOOLEAN, STRING, STRING)); - } - - private static DefaultFunctionResolver regexp() { - return define(BuiltinFunctionName.REGEXP.getName(), - impl(nullMissingHandling(OperatorUtils::matchesRegexp), INTEGER, STRING, STRING)); - } - - private static DefaultFunctionResolver notLike() { - return define(BuiltinFunctionName.NOT_LIKE.getName(), - impl(nullMissingHandling( - (v1, v2) -> UnaryPredicateOperator.not(OperatorUtils.matches(v1, v2))), - BOOLEAN, STRING, STRING)); - } - - private static ExprValue lookupTableFunction(ExprValue arg1, ExprValue arg2, - Table table) { - if (table.contains(arg1, arg2)) { - return table.get(arg1, arg2); - } else { - return table.get(arg2, arg1); + /** + * Register Binary Predicate Function. + * + * @param repository {@link BuiltinFunctionRepository}. + */ + public static void register(BuiltinFunctionRepository repository) { + repository.register(and()); + repository.register(or()); + repository.register(xor()); + repository.register(equal()); + repository.register(notEqual()); + repository.register(less()); + repository.register(lte()); + repository.register(greater()); + repository.register(gte()); + repository.register(like()); + repository.register(notLike()); + repository.register(regexp()); + } + + /** + * The and logic. + * A B A AND B + * TRUE TRUE TRUE + * TRUE FALSE FALSE + * TRUE NULL NULL + * TRUE MISSING MISSING + * FALSE FALSE FALSE + * FALSE NULL FALSE + * FALSE MISSING FALSE + * NULL NULL NULL + * NULL MISSING MISSING + * MISSING MISSING MISSING + */ + private static Table andTable = new ImmutableTable.Builder().put( + LITERAL_TRUE, + LITERAL_TRUE, + LITERAL_TRUE + ) + .put(LITERAL_TRUE, LITERAL_FALSE, LITERAL_FALSE) + .put(LITERAL_TRUE, LITERAL_NULL, LITERAL_NULL) + .put(LITERAL_TRUE, LITERAL_MISSING, LITERAL_MISSING) + .put(LITERAL_FALSE, LITERAL_FALSE, LITERAL_FALSE) + .put(LITERAL_FALSE, LITERAL_NULL, LITERAL_FALSE) + .put(LITERAL_FALSE, LITERAL_MISSING, LITERAL_FALSE) + .put(LITERAL_NULL, LITERAL_NULL, LITERAL_NULL) + .put(LITERAL_NULL, LITERAL_MISSING, LITERAL_MISSING) + .put(LITERAL_MISSING, LITERAL_MISSING, LITERAL_MISSING) + .build(); + + /** + * The or logic. + * A B A AND B + * TRUE TRUE TRUE + * TRUE FALSE TRUE + * TRUE NULL TRUE + * TRUE MISSING TRUE + * FALSE FALSE FALSE + * FALSE NULL NULL + * FALSE MISSING MISSING + * NULL NULL NULL + * NULL MISSING NULL + * MISSING MISSING MISSING + */ + private static Table orTable = new ImmutableTable.Builder().put( + LITERAL_TRUE, + LITERAL_TRUE, + LITERAL_TRUE + ) + .put(LITERAL_TRUE, LITERAL_FALSE, LITERAL_TRUE) + .put(LITERAL_TRUE, LITERAL_NULL, LITERAL_TRUE) + .put(LITERAL_TRUE, LITERAL_MISSING, LITERAL_TRUE) + .put(LITERAL_FALSE, LITERAL_FALSE, LITERAL_FALSE) + .put(LITERAL_FALSE, LITERAL_NULL, LITERAL_NULL) + .put(LITERAL_FALSE, LITERAL_MISSING, LITERAL_MISSING) + .put(LITERAL_NULL, LITERAL_NULL, LITERAL_NULL) + .put(LITERAL_NULL, LITERAL_MISSING, LITERAL_NULL) + .put(LITERAL_MISSING, LITERAL_MISSING, LITERAL_MISSING) + .build(); + + /** + * The xor logic. + * A B A AND B + * TRUE TRUE FALSE + * TRUE FALSE TRUE + * TRUE NULL TRUE + * TRUE MISSING TRUE + * FALSE FALSE FALSE + * FALSE NULL NULL + * FALSE MISSING MISSING + * NULL NULL NULL + * NULL MISSING NULL + * MISSING MISSING MISSING + */ + private static Table xorTable = new ImmutableTable.Builder().put( + LITERAL_TRUE, + LITERAL_TRUE, + LITERAL_FALSE + ) + .put(LITERAL_TRUE, LITERAL_FALSE, LITERAL_TRUE) + .put(LITERAL_TRUE, LITERAL_NULL, LITERAL_TRUE) + .put(LITERAL_TRUE, LITERAL_MISSING, LITERAL_TRUE) + .put(LITERAL_FALSE, LITERAL_FALSE, LITERAL_FALSE) + .put(LITERAL_FALSE, LITERAL_NULL, LITERAL_NULL) + .put(LITERAL_FALSE, LITERAL_MISSING, LITERAL_MISSING) + .put(LITERAL_NULL, LITERAL_NULL, LITERAL_NULL) + .put(LITERAL_NULL, LITERAL_MISSING, LITERAL_NULL) + .put(LITERAL_MISSING, LITERAL_MISSING, LITERAL_MISSING) + .build(); + + private static DefaultFunctionResolver and() { + return define( + BuiltinFunctionName.AND.getName(), + impl((v1, v2) -> lookupTableFunction(v1, v2, andTable), BOOLEAN, BOOLEAN, BOOLEAN) + ); + } + + private static DefaultFunctionResolver or() { + return define(BuiltinFunctionName.OR.getName(), impl((v1, v2) -> lookupTableFunction(v1, v2, orTable), BOOLEAN, BOOLEAN, BOOLEAN)); + } + + private static DefaultFunctionResolver xor() { + return define( + BuiltinFunctionName.XOR.getName(), + impl((v1, v2) -> lookupTableFunction(v1, v2, xorTable), BOOLEAN, BOOLEAN, BOOLEAN) + ); + } + + private static DefaultFunctionResolver equal() { + return define( + BuiltinFunctionName.EQUAL.getName(), + ExprCoreType.coreTypes() + .stream() + .map(type -> impl(nullMissingHandling((v1, v2) -> ExprBooleanValue.of(v1.equals(v2))), BOOLEAN, type, type)) + .collect(Collectors.toList()) + ); + } + + private static DefaultFunctionResolver notEqual() { + return define( + BuiltinFunctionName.NOTEQUAL.getName(), + ExprCoreType.coreTypes() + .stream() + .map(type -> impl(nullMissingHandling((v1, v2) -> ExprBooleanValue.of(!v1.equals(v2))), BOOLEAN, type, type)) + .collect(Collectors.toList()) + ); + } + + private static DefaultFunctionResolver less() { + return define( + BuiltinFunctionName.LESS.getName(), + ExprCoreType.coreTypes() + .stream() + .map(type -> impl(nullMissingHandling((v1, v2) -> ExprBooleanValue.of(v1.compareTo(v2) < 0)), BOOLEAN, type, type)) + .collect(Collectors.toList()) + ); + } + + private static DefaultFunctionResolver lte() { + return define( + BuiltinFunctionName.LTE.getName(), + ExprCoreType.coreTypes() + .stream() + .map(type -> impl(nullMissingHandling((v1, v2) -> ExprBooleanValue.of(v1.compareTo(v2) <= 0)), BOOLEAN, type, type)) + .collect(Collectors.toList()) + ); + } + + private static DefaultFunctionResolver greater() { + return define( + BuiltinFunctionName.GREATER.getName(), + ExprCoreType.coreTypes() + .stream() + .map(type -> impl(nullMissingHandling((v1, v2) -> ExprBooleanValue.of(v1.compareTo(v2) > 0)), BOOLEAN, type, type)) + .collect(Collectors.toList()) + ); + } + + private static DefaultFunctionResolver gte() { + return define( + BuiltinFunctionName.GTE.getName(), + ExprCoreType.coreTypes() + .stream() + .map(type -> impl(nullMissingHandling((v1, v2) -> ExprBooleanValue.of(v1.compareTo(v2) >= 0)), BOOLEAN, type, type)) + .collect(Collectors.toList()) + ); + } + + private static DefaultFunctionResolver like() { + return define(BuiltinFunctionName.LIKE.getName(), impl(nullMissingHandling(OperatorUtils::matches), BOOLEAN, STRING, STRING)); + } + + private static DefaultFunctionResolver regexp() { + return define( + BuiltinFunctionName.REGEXP.getName(), + impl(nullMissingHandling(OperatorUtils::matchesRegexp), INTEGER, STRING, STRING) + ); + } + + private static DefaultFunctionResolver notLike() { + return define( + BuiltinFunctionName.NOT_LIKE.getName(), + impl(nullMissingHandling((v1, v2) -> UnaryPredicateOperator.not(OperatorUtils.matches(v1, v2))), BOOLEAN, STRING, STRING) + ); + } + + private static ExprValue lookupTableFunction(ExprValue arg1, ExprValue arg2, Table table) { + if (table.contains(arg1, arg2)) { + return table.get(arg1, arg2); + } else { + return table.get(arg2, arg1); + } } - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/UnaryPredicateOperator.java b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/UnaryPredicateOperator.java index 7d79d9d923..e74124fb0a 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/UnaryPredicateOperator.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/UnaryPredicateOperator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.operator.predicate; import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_NULL; @@ -33,118 +32,112 @@ */ @UtilityClass public class UnaryPredicateOperator { - /** - * Register Unary Predicate Function. - */ - public static void register(BuiltinFunctionRepository repository) { - repository.register(not()); - repository.register(isNotNull()); - repository.register(ifNull()); - repository.register(nullIf()); - repository.register(isNull(BuiltinFunctionName.IS_NULL)); - repository.register(isNull(BuiltinFunctionName.ISNULL)); - repository.register(ifFunction()); - } - - private static DefaultFunctionResolver not() { - return FunctionDSL.define(BuiltinFunctionName.NOT.getName(), FunctionDSL - .impl(UnaryPredicateOperator::not, BOOLEAN, BOOLEAN)); - } - - /** - * The not logic. - * A NOT A - * TRUE FALSE - * FALSE TRUE - * NULL NULL - * MISSING MISSING - */ - public ExprValue not(ExprValue v) { - if (v.isMissing() || v.isNull()) { - return v; - } else { - return ExprBooleanValue.of(!v.booleanValue()); + /** + * Register Unary Predicate Function. + */ + public static void register(BuiltinFunctionRepository repository) { + repository.register(not()); + repository.register(isNotNull()); + repository.register(ifNull()); + repository.register(nullIf()); + repository.register(isNull(BuiltinFunctionName.IS_NULL)); + repository.register(isNull(BuiltinFunctionName.ISNULL)); + repository.register(ifFunction()); + } + + private static DefaultFunctionResolver not() { + return FunctionDSL.define(BuiltinFunctionName.NOT.getName(), FunctionDSL.impl(UnaryPredicateOperator::not, BOOLEAN, BOOLEAN)); + } + + /** + * The not logic. + * A NOT A + * TRUE FALSE + * FALSE TRUE + * NULL NULL + * MISSING MISSING + */ + public ExprValue not(ExprValue v) { + if (v.isMissing() || v.isNull()) { + return v; + } else { + return ExprBooleanValue.of(!v.booleanValue()); + } + } + + private static DefaultFunctionResolver isNull(BuiltinFunctionName funcName) { + return FunctionDSL.define( + funcName.getName(), + Arrays.stream(ExprCoreType.values()) + .map(type -> FunctionDSL.impl((v) -> ExprBooleanValue.of(v.isNull()), BOOLEAN, type)) + .collect(Collectors.toList()) + ); + } + + private static DefaultFunctionResolver isNotNull() { + return FunctionDSL.define( + BuiltinFunctionName.IS_NOT_NULL.getName(), + Arrays.stream(ExprCoreType.values()) + .map(type -> FunctionDSL.impl((v) -> ExprBooleanValue.of(!v.isNull()), BOOLEAN, type)) + .collect(Collectors.toList()) + ); + } + + private static DefaultFunctionResolver ifFunction() { + FunctionName functionName = BuiltinFunctionName.IF.getName(); + List typeList = ExprCoreType.coreTypes(); + + List>> functionsOne = + typeList.stream().map(v -> impl((UnaryPredicateOperator::exprIf), v, BOOLEAN, v, v)).collect(Collectors.toList()); + + DefaultFunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne); + return functionResolver; + } + + private static DefaultFunctionResolver ifNull() { + FunctionName functionName = BuiltinFunctionName.IFNULL.getName(); + List typeList = ExprCoreType.coreTypes(); + + List>> functionsOne = + typeList.stream().map(v -> impl((UnaryPredicateOperator::exprIfNull), v, v, v)).collect(Collectors.toList()); + + DefaultFunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne); + return functionResolver; + } + + private static DefaultFunctionResolver nullIf() { + FunctionName functionName = BuiltinFunctionName.NULLIF.getName(); + List typeList = ExprCoreType.coreTypes(); + + DefaultFunctionResolver functionResolver = FunctionDSL.define( + functionName, + typeList.stream().map(v -> impl((UnaryPredicateOperator::exprNullIf), v, v, v)).collect(Collectors.toList()) + ); + return functionResolver; + } + + /** v2 if v1 is null. + * + * @param v1 varable 1 + * @param v2 varable 2 + * @return v2 if v1 is null + */ + public static ExprValue exprIfNull(ExprValue v1, ExprValue v2) { + return (v1.isNull() || v1.isMissing()) ? v2 : v1; + } + + /** return null if v1 equls to v2. + * + * @param v1 varable 1 + * @param v2 varable 2 + * @return null if v1 equls to v2 + */ + public static ExprValue exprNullIf(ExprValue v1, ExprValue v2) { + return v1.equals(v2) ? LITERAL_NULL : v1; + } + + public static ExprValue exprIf(ExprValue v1, ExprValue v2, ExprValue v3) { + return !v1.isNull() && !v1.isMissing() && LITERAL_TRUE.equals(v1) ? v2 : v3; } - } - - private static DefaultFunctionResolver isNull(BuiltinFunctionName funcName) { - return FunctionDSL - .define(funcName.getName(), Arrays.stream(ExprCoreType.values()) - .map(type -> FunctionDSL - .impl((v) -> ExprBooleanValue.of(v.isNull()), BOOLEAN, type)) - .collect( - Collectors.toList())); - } - - private static DefaultFunctionResolver isNotNull() { - return FunctionDSL - .define(BuiltinFunctionName.IS_NOT_NULL.getName(), Arrays.stream(ExprCoreType.values()) - .map(type -> FunctionDSL - .impl((v) -> ExprBooleanValue.of(!v.isNull()), BOOLEAN, type)) - .collect( - Collectors.toList())); - } - - private static DefaultFunctionResolver ifFunction() { - FunctionName functionName = BuiltinFunctionName.IF.getName(); - List typeList = ExprCoreType.coreTypes(); - - List>> functionsOne = typeList.stream().map(v -> - impl((UnaryPredicateOperator::exprIf), v, BOOLEAN, v, v)) - .collect(Collectors.toList()); - - DefaultFunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne); - return functionResolver; - } - - private static DefaultFunctionResolver ifNull() { - FunctionName functionName = BuiltinFunctionName.IFNULL.getName(); - List typeList = ExprCoreType.coreTypes(); - - List>> functionsOne = typeList.stream().map(v -> - impl((UnaryPredicateOperator::exprIfNull), v, v, v)) - .collect(Collectors.toList()); - - DefaultFunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne); - return functionResolver; - } - - private static DefaultFunctionResolver nullIf() { - FunctionName functionName = BuiltinFunctionName.NULLIF.getName(); - List typeList = ExprCoreType.coreTypes(); - - DefaultFunctionResolver functionResolver = - FunctionDSL.define(functionName, - typeList.stream().map(v -> - impl((UnaryPredicateOperator::exprNullIf), v, v, v)) - .collect(Collectors.toList())); - return functionResolver; - } - - /** v2 if v1 is null. - * - * @param v1 varable 1 - * @param v2 varable 2 - * @return v2 if v1 is null - */ - public static ExprValue exprIfNull(ExprValue v1, ExprValue v2) { - return (v1.isNull() || v1.isMissing()) ? v2 : v1; - } - - /** return null if v1 equls to v2. - * - * @param v1 varable 1 - * @param v2 varable 2 - * @return null if v1 equls to v2 - */ - public static ExprValue exprNullIf(ExprValue v1, ExprValue v2) { - return v1.equals(v2) ? LITERAL_NULL : v1; - } - - public static ExprValue exprIf(ExprValue v1, ExprValue v2, ExprValue v3) { - return !v1.isNull() && !v1.isMissing() && LITERAL_TRUE.equals(v1) ? v2 : v3; - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/parse/GrokExpression.java b/core/src/main/java/org/opensearch/sql/expression/parse/GrokExpression.java index 9797832f07..3bfad6a576 100644 --- a/core/src/main/java/org/opensearch/sql/expression/parse/GrokExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/parse/GrokExpression.java @@ -26,50 +26,52 @@ @EqualsAndHashCode(callSuper = true) @ToString public class GrokExpression extends ParseExpression { - private static final Logger log = LogManager.getLogger(GrokExpression.class); - private static final GrokCompiler grokCompiler = GrokCompiler.newInstance(); + private static final Logger log = LogManager.getLogger(GrokExpression.class); + private static final GrokCompiler grokCompiler = GrokCompiler.newInstance(); - static { - grokCompiler.registerDefaultPatterns(); - } + static { + grokCompiler.registerDefaultPatterns(); + } - @EqualsAndHashCode.Exclude - private final Grok grok; + @EqualsAndHashCode.Exclude + private final Grok grok; - /** - * GrokExpression. - * - * @param sourceField source text field - * @param pattern pattern used for parsing - * @param identifier derived field - */ - public GrokExpression(Expression sourceField, Expression pattern, Expression identifier) { - super("grok", sourceField, pattern, identifier); - this.grok = grokCompiler.compile(pattern.valueOf().stringValue()); - } + /** + * GrokExpression. + * + * @param sourceField source text field + * @param pattern pattern used for parsing + * @param identifier derived field + */ + public GrokExpression(Expression sourceField, Expression pattern, Expression identifier) { + super("grok", sourceField, pattern, identifier); + this.grok = grokCompiler.compile(pattern.valueOf().stringValue()); + } - @Override - ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException { - String rawString = value.stringValue(); - Match grokMatch = grok.match(rawString); - Map capture = grokMatch.capture(); - Object match = capture.get(identifierStr); - if (match != null) { - return new ExprStringValue(match.toString()); + @Override + ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException { + String rawString = value.stringValue(); + Match grokMatch = grok.match(rawString); + Map capture = grokMatch.capture(); + Object match = capture.get(identifierStr); + if (match != null) { + return new ExprStringValue(match.toString()); + } + log.debug("failed to extract pattern {} from input ***", grok.getOriginalGrokPattern()); + return new ExprStringValue(""); } - log.debug("failed to extract pattern {} from input ***", grok.getOriginalGrokPattern()); - return new ExprStringValue(""); - } - /** - * Get list of derived fields based on parse pattern. - * - * @param pattern pattern used for parsing - * @return list of names of the derived fields - */ - public static List getNamedGroupCandidates(String pattern) { - Grok grok = grokCompiler.compile(pattern); - return grok.namedGroups.stream().map(grok::getNamedRegexCollectionById) - .filter(group -> !group.equals("UNWANTED")).collect(Collectors.toUnmodifiableList()); - } + /** + * Get list of derived fields based on parse pattern. + * + * @param pattern pattern used for parsing + * @return list of names of the derived fields + */ + public static List getNamedGroupCandidates(String pattern) { + Grok grok = grokCompiler.compile(pattern); + return grok.namedGroups.stream() + .map(grok::getNamedRegexCollectionById) + .filter(group -> !group.equals("UNWANTED")) + .collect(Collectors.toUnmodifiableList()); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/parse/ParseExpression.java b/core/src/main/java/org/opensearch/sql/expression/parse/ParseExpression.java index 8d1ebcce08..b1d5a743b5 100644 --- a/core/src/main/java/org/opensearch/sql/expression/parse/ParseExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/parse/ParseExpression.java @@ -27,53 +27,51 @@ @EqualsAndHashCode(callSuper = false) @ToString public abstract class ParseExpression extends FunctionExpression { - @Getter - protected final Expression sourceField; - protected final Expression pattern; - @Getter - protected final Expression identifier; - protected final String identifierStr; + @Getter + protected final Expression sourceField; + protected final Expression pattern; + @Getter + protected final Expression identifier; + protected final String identifierStr; - /** - * ParseExpression. - * - * @param functionName name of function expression - * @param sourceField source text field - * @param pattern pattern used for parsing - * @param identifier derived field - */ - public ParseExpression(String functionName, Expression sourceField, Expression pattern, - Expression identifier) { - super(FunctionName.of(functionName), ImmutableList.of(sourceField, pattern, identifier)); - this.sourceField = sourceField; - this.pattern = pattern; - this.identifier = identifier; - this.identifierStr = identifier.valueOf().stringValue(); - } - - @Override - public ExprValue valueOf(Environment valueEnv) { - ExprValue value = valueEnv.resolve(sourceField); - if (value.isNull() || value.isMissing()) { - return ExprValueUtils.nullValue(); + /** + * ParseExpression. + * + * @param functionName name of function expression + * @param sourceField source text field + * @param pattern pattern used for parsing + * @param identifier derived field + */ + public ParseExpression(String functionName, Expression sourceField, Expression pattern, Expression identifier) { + super(FunctionName.of(functionName), ImmutableList.of(sourceField, pattern, identifier)); + this.sourceField = sourceField; + this.pattern = pattern; + this.identifier = identifier; + this.identifierStr = identifier.valueOf().stringValue(); } - try { - return parseValue(value); - } catch (ExpressionEvaluationException e) { - throw new SemanticCheckException( - String.format("failed to parse field \"%s\" with type [%s]", sourceField, value.type())); + + @Override + public ExprValue valueOf(Environment valueEnv) { + ExprValue value = valueEnv.resolve(sourceField); + if (value.isNull() || value.isMissing()) { + return ExprValueUtils.nullValue(); + } + try { + return parseValue(value); + } catch (ExpressionEvaluationException e) { + throw new SemanticCheckException(String.format("failed to parse field \"%s\" with type [%s]", sourceField, value.type())); + } } - } - @Override - public ExprType type() { - return ExprCoreType.STRING; - } + @Override + public ExprType type() { + return ExprCoreType.STRING; + } - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitParse(this, context); - } + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitParse(this, context); + } - abstract ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException; + abstract ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException; } diff --git a/core/src/main/java/org/opensearch/sql/expression/parse/PatternsExpression.java b/core/src/main/java/org/opensearch/sql/expression/parse/PatternsExpression.java index 67160dad58..865da8cc00 100644 --- a/core/src/main/java/org/opensearch/sql/expression/parse/PatternsExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/parse/PatternsExpression.java @@ -23,58 +23,58 @@ @EqualsAndHashCode(callSuper = true) @ToString public class PatternsExpression extends ParseExpression { - /** - * Default name of the derived field. - */ - public static final String DEFAULT_NEW_FIELD = "patterns_field"; + /** + * Default name of the derived field. + */ + public static final String DEFAULT_NEW_FIELD = "patterns_field"; - private static final ImmutableSet DEFAULT_IGNORED_CHARS = ImmutableSet.copyOf( - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".chars() - .mapToObj(c -> (char) c).toArray(Character[]::new)); - private final boolean useCustomPattern; - @EqualsAndHashCode.Exclude - private Pattern pattern; + private static final ImmutableSet DEFAULT_IGNORED_CHARS = ImmutableSet.copyOf( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".chars().mapToObj(c -> (char) c).toArray(Character[]::new) + ); + private final boolean useCustomPattern; + @EqualsAndHashCode.Exclude + private Pattern pattern; - /** - * PatternsExpression. - * - * @param sourceField source text field - * @param pattern pattern used for parsing - * @param identifier derived field - */ - public PatternsExpression(Expression sourceField, Expression pattern, Expression identifier) { - super("patterns", sourceField, pattern, identifier); - String patternStr = pattern.valueOf().stringValue(); - useCustomPattern = !patternStr.isEmpty(); - if (useCustomPattern) { - this.pattern = Pattern.compile(patternStr); + /** + * PatternsExpression. + * + * @param sourceField source text field + * @param pattern pattern used for parsing + * @param identifier derived field + */ + public PatternsExpression(Expression sourceField, Expression pattern, Expression identifier) { + super("patterns", sourceField, pattern, identifier); + String patternStr = pattern.valueOf().stringValue(); + useCustomPattern = !patternStr.isEmpty(); + if (useCustomPattern) { + this.pattern = Pattern.compile(patternStr); + } } - } - @Override - ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException { - String rawString = value.stringValue(); - if (useCustomPattern) { - return new ExprStringValue(pattern.matcher(rawString).replaceAll("")); - } + @Override + ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException { + String rawString = value.stringValue(); + if (useCustomPattern) { + return new ExprStringValue(pattern.matcher(rawString).replaceAll("")); + } - char[] chars = rawString.toCharArray(); - int pos = 0; - for (int i = 0; i < chars.length; i++) { - if (!DEFAULT_IGNORED_CHARS.contains(chars[i])) { - chars[pos++] = chars[i]; - } + char[] chars = rawString.toCharArray(); + int pos = 0; + for (int i = 0; i < chars.length; i++) { + if (!DEFAULT_IGNORED_CHARS.contains(chars[i])) { + chars[pos++] = chars[i]; + } + } + return new ExprStringValue(new String(chars, 0, pos)); } - return new ExprStringValue(new String(chars, 0, pos)); - } - /** - * Get list of derived fields. - * - * @param identifier identifier used to generate the field name - * @return list of names of the derived fields - */ - public static List getNamedGroupCandidates(String identifier) { - return ImmutableList.of(Objects.requireNonNullElse(identifier, DEFAULT_NEW_FIELD)); - } + /** + * Get list of derived fields. + * + * @param identifier identifier used to generate the field name + * @return list of names of the derived fields + */ + public static List getNamedGroupCandidates(String identifier) { + return ImmutableList.of(Objects.requireNonNullElse(identifier, DEFAULT_NEW_FIELD)); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/parse/RegexExpression.java b/core/src/main/java/org/opensearch/sql/expression/parse/RegexExpression.java index f3a3ff0b66..bc3a6102e7 100644 --- a/core/src/main/java/org/opensearch/sql/expression/parse/RegexExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/parse/RegexExpression.java @@ -25,47 +25,47 @@ @EqualsAndHashCode(callSuper = true) @ToString public class RegexExpression extends ParseExpression { - private static final Logger log = LogManager.getLogger(RegexExpression.class); - private static final Pattern GROUP_PATTERN = Pattern.compile("\\(\\?<([a-zA-Z][a-zA-Z0-9]*)>"); - @Getter - @EqualsAndHashCode.Exclude - private final Pattern regexPattern; + private static final Logger log = LogManager.getLogger(RegexExpression.class); + private static final Pattern GROUP_PATTERN = Pattern.compile("\\(\\?<([a-zA-Z][a-zA-Z0-9]*)>"); + @Getter + @EqualsAndHashCode.Exclude + private final Pattern regexPattern; - /** - * RegexExpression. - * - * @param sourceField source text field - * @param pattern pattern used for parsing - * @param identifier derived field - */ - public RegexExpression(Expression sourceField, Expression pattern, Expression identifier) { - super("regex", sourceField, pattern, identifier); - this.regexPattern = Pattern.compile(pattern.valueOf().stringValue()); - } + /** + * RegexExpression. + * + * @param sourceField source text field + * @param pattern pattern used for parsing + * @param identifier derived field + */ + public RegexExpression(Expression sourceField, Expression pattern, Expression identifier) { + super("regex", sourceField, pattern, identifier); + this.regexPattern = Pattern.compile(pattern.valueOf().stringValue()); + } - @Override - ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException { - String rawString = value.stringValue(); - Matcher matcher = regexPattern.matcher(rawString); - if (matcher.matches()) { - return new ExprStringValue(matcher.group(identifierStr)); + @Override + ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException { + String rawString = value.stringValue(); + Matcher matcher = regexPattern.matcher(rawString); + if (matcher.matches()) { + return new ExprStringValue(matcher.group(identifierStr)); + } + log.debug("failed to extract pattern {} from input ***", regexPattern.pattern()); + return new ExprStringValue(""); } - log.debug("failed to extract pattern {} from input ***", regexPattern.pattern()); - return new ExprStringValue(""); - } - /** - * Get list of derived fields based on parse pattern. - * - * @param pattern pattern used for parsing - * @return list of names of the derived fields - */ - public static List getNamedGroupCandidates(String pattern) { - ImmutableList.Builder namedGroups = ImmutableList.builder(); - Matcher m = GROUP_PATTERN.matcher(pattern); - while (m.find()) { - namedGroups.add(m.group(1)); + /** + * Get list of derived fields based on parse pattern. + * + * @param pattern pattern used for parsing + * @return list of names of the derived fields + */ + public static List getNamedGroupCandidates(String pattern) { + ImmutableList.Builder namedGroups = ImmutableList.builder(); + Matcher m = GROUP_PATTERN.matcher(pattern); + while (m.find()) { + namedGroups.add(m.group(1)); + } + return namedGroups.build(); } - return namedGroups.build(); - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/span/SpanExpression.java b/core/src/main/java/org/opensearch/sql/expression/span/SpanExpression.java index aff114145e..0218551696 100644 --- a/core/src/main/java/org/opensearch/sql/expression/span/SpanExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/span/SpanExpression.java @@ -20,47 +20,47 @@ @ToString @EqualsAndHashCode public class SpanExpression implements Expression { - private final Expression field; - private final Expression value; - private final SpanUnit unit; + private final Expression field; + private final Expression value; + private final SpanUnit unit; - /** - * Construct a span expression by field and span interval expression. - */ - public SpanExpression(Expression field, Expression value, SpanUnit unit) { - this.field = field; - this.value = value; - this.unit = unit; - } + /** + * Construct a span expression by field and span interval expression. + */ + public SpanExpression(Expression field, Expression value, SpanUnit unit) { + this.field = field; + this.value = value; + this.unit = unit; + } - @Override - public ExprValue valueOf(Environment valueEnv) { - Rounding rounding = Rounding.createRounding(this); //TODO: will integrate with WindowAssigner - return rounding.round(field.valueOf(valueEnv)); - } + @Override + public ExprValue valueOf(Environment valueEnv) { + Rounding rounding = Rounding.createRounding(this); // TODO: will integrate with WindowAssigner + return rounding.round(field.valueOf(valueEnv)); + } - /** - * Return type follows the following table. - * FIELD VALUE RETURN_TYPE - * int/long integer int/long (field type) - * int/long double double - * float/double integer float/double (field type) - * float/double double float/double (field type) - * other any field type - */ - @Override - public ExprType type() { - if (field.type().isCompatible(value.type())) { - return field.type(); - } else if (value.type().isCompatible(field.type())) { - return value.type(); - } else { - return field.type(); + /** + * Return type follows the following table. + * FIELD VALUE RETURN_TYPE + * int/long integer int/long (field type) + * int/long double double + * float/double integer float/double (field type) + * float/double double float/double (field type) + * other any field type + */ + @Override + public ExprType type() { + if (field.type().isCompatible(value.type())) { + return field.type(); + } else if (value.type().isCompatible(field.type())) { + return value.type(); + } else { + return field.type(); + } } - } - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitNode(this, context); - } + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return visitor.visitNode(this, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/system/SystemFunctions.java b/core/src/main/java/org/opensearch/sql/expression/system/SystemFunctions.java index e12bcd0a58..860583b8bf 100644 --- a/core/src/main/java/org/opensearch/sql/expression/system/SystemFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/system/SystemFunctions.java @@ -24,38 +24,38 @@ @UtilityClass public class SystemFunctions { - /** - * Register TypeOf Operator. - */ - public static void register(BuiltinFunctionRepository repository) { - repository.register(typeof()); - } + /** + * Register TypeOf Operator. + */ + public static void register(BuiltinFunctionRepository repository) { + repository.register(typeof()); + } - // Auxiliary function useful for debugging - private static FunctionResolver typeof() { - return new FunctionResolver() { - @Override - public Pair resolve( - FunctionSignature unresolvedSignature) { - return Pair.of(unresolvedSignature, - (functionProperties, arguments) -> - new FunctionExpression(BuiltinFunctionName.TYPEOF.getName(), arguments) { - @Override - public ExprValue valueOf(Environment valueEnv) { - return new ExprStringValue(getArguments().get(0).type().legacyTypeName()); - } + // Auxiliary function useful for debugging + private static FunctionResolver typeof() { + return new FunctionResolver() { + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + return Pair.of( + unresolvedSignature, + (functionProperties, arguments) -> new FunctionExpression(BuiltinFunctionName.TYPEOF.getName(), arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + return new ExprStringValue(getArguments().get(0).type().legacyTypeName()); + } - @Override - public ExprType type() { - return STRING; - } - }); - } + @Override + public ExprType type() { + return STRING; + } + } + ); + } - @Override - public FunctionName getFunctionName() { - return BuiltinFunctionName.TYPEOF.getName(); - } - }; - } + @Override + public FunctionName getFunctionName() { + return BuiltinFunctionName.TYPEOF.getName(); + } + }; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java index e56c85a0c8..04076f9279 100644 --- a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.text; import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; @@ -41,342 +40,366 @@ */ @UtilityClass public class TextFunction { - private static String EMPTY_STRING = ""; - - /** - * Register String Functions. - * - * @param repository {@link BuiltinFunctionRepository}. - */ - public void register(BuiltinFunctionRepository repository) { - repository.register(ascii()); - repository.register(concat()); - repository.register(concat_ws()); - repository.register(left()); - repository.register(length()); - repository.register(locate()); - repository.register(lower()); - repository.register(ltrim()); - repository.register(position()); - repository.register(replace()); - repository.register(reverse()); - repository.register(right()); - repository.register(rtrim()); - repository.register(strcmp()); - repository.register(substr()); - repository.register(substring()); - repository.register(trim()); - repository.register(upper()); - } - - /** - * Gets substring starting at given point, for optional given length. - * Form of this function using keywords instead of comma delimited variables is not supported. - * Supports following signatures: - * (STRING, INTEGER)/(STRING, INTEGER, INTEGER) -> STRING - */ - private DefaultFunctionResolver substringSubstr(FunctionName functionName) { - return define(functionName, - impl(nullMissingHandling(TextFunction::exprSubstrStart), - STRING, STRING, INTEGER), - impl(nullMissingHandling(TextFunction::exprSubstrStartLength), - STRING, STRING, INTEGER, INTEGER)); - } - - private DefaultFunctionResolver substring() { - return substringSubstr(BuiltinFunctionName.SUBSTRING.getName()); - } - - private DefaultFunctionResolver substr() { - return substringSubstr(BuiltinFunctionName.SUBSTR.getName()); - } - - /** - * Removes leading whitespace from string. - * Supports following signatures: - * STRING -> STRING - */ - private DefaultFunctionResolver ltrim() { - return define(BuiltinFunctionName.LTRIM.getName(), - impl(nullMissingHandling((v) -> new ExprStringValue(v.stringValue().stripLeading())), - STRING, STRING)); - } - - /** - * Removes trailing whitespace from string. - * Supports following signatures: - * STRING -> STRING - */ - private DefaultFunctionResolver rtrim() { - return define(BuiltinFunctionName.RTRIM.getName(), - impl(nullMissingHandling((v) -> new ExprStringValue(v.stringValue().stripTrailing())), - STRING, STRING)); - } - - /** - * Removes leading and trailing whitespace from string. - * Has option to specify a String to trim instead of whitespace but this is not yet supported. - * Supporting String specification requires finding keywords inside TRIM command. - * Supports following signatures: - * STRING -> STRING - */ - private DefaultFunctionResolver trim() { - return define(BuiltinFunctionName.TRIM.getName(), - impl(nullMissingHandling((v) -> new ExprStringValue(v.stringValue().trim())), - STRING, STRING)); - } - - /** - * Converts String to lowercase. - * Supports following signatures: - * STRING -> STRING - */ - private DefaultFunctionResolver lower() { - return define(BuiltinFunctionName.LOWER.getName(), - impl(nullMissingHandling((v) -> new ExprStringValue((v.stringValue().toLowerCase()))), - STRING, STRING) - ); - } - - /** - * Converts String to uppercase. - * Supports following signatures: - * STRING -> STRING - */ - private DefaultFunctionResolver upper() { - return define(BuiltinFunctionName.UPPER.getName(), - impl(nullMissingHandling((v) -> new ExprStringValue((v.stringValue().toUpperCase()))), - STRING, STRING) - ); - } - - /** - * Concatenates a list of Strings. - * Supports following signatures: - * (STRING, STRING, ...., STRING) -> STRING - */ - private DefaultFunctionResolver concat() { - FunctionName concatFuncName = BuiltinFunctionName.CONCAT.getName(); - return define(concatFuncName, funcName -> - Pair.of( - new FunctionSignature(concatFuncName, Collections.singletonList(ARRAY)), - (funcProp, args) -> new FunctionExpression(funcName, args) { - @Override - public ExprValue valueOf(Environment valueEnv) { - List exprValues = args.stream() - .map(arg -> arg.valueOf(valueEnv)).collect(Collectors.toList()); - if (exprValues.stream().anyMatch(ExprValue::isMissing)) { - return ExprValueUtils.missingValue(); - } - if (exprValues.stream().anyMatch(ExprValue::isNull)) { - return ExprValueUtils.nullValue(); - } - return new ExprStringValue(exprValues.stream() - .map(ExprValue::stringValue) - .collect(Collectors.joining())); - } - - @Override - public ExprType type() { + private static String EMPTY_STRING = ""; + + /** + * Register String Functions. + * + * @param repository {@link BuiltinFunctionRepository}. + */ + public void register(BuiltinFunctionRepository repository) { + repository.register(ascii()); + repository.register(concat()); + repository.register(concat_ws()); + repository.register(left()); + repository.register(length()); + repository.register(locate()); + repository.register(lower()); + repository.register(ltrim()); + repository.register(position()); + repository.register(replace()); + repository.register(reverse()); + repository.register(right()); + repository.register(rtrim()); + repository.register(strcmp()); + repository.register(substr()); + repository.register(substring()); + repository.register(trim()); + repository.register(upper()); + } + + /** + * Gets substring starting at given point, for optional given length. + * Form of this function using keywords instead of comma delimited variables is not supported. + * Supports following signatures: + * (STRING, INTEGER)/(STRING, INTEGER, INTEGER) -> STRING + */ + private DefaultFunctionResolver substringSubstr(FunctionName functionName) { + return define( + functionName, + impl(nullMissingHandling(TextFunction::exprSubstrStart), STRING, STRING, INTEGER), + impl(nullMissingHandling(TextFunction::exprSubstrStartLength), STRING, STRING, INTEGER, INTEGER) + ); + } + + private DefaultFunctionResolver substring() { + return substringSubstr(BuiltinFunctionName.SUBSTRING.getName()); + } + + private DefaultFunctionResolver substr() { + return substringSubstr(BuiltinFunctionName.SUBSTR.getName()); + } + + /** + * Removes leading whitespace from string. + * Supports following signatures: + * STRING -> STRING + */ + private DefaultFunctionResolver ltrim() { + return define( + BuiltinFunctionName.LTRIM.getName(), + impl(nullMissingHandling((v) -> new ExprStringValue(v.stringValue().stripLeading())), STRING, STRING) + ); + } + + /** + * Removes trailing whitespace from string. + * Supports following signatures: + * STRING -> STRING + */ + private DefaultFunctionResolver rtrim() { + return define( + BuiltinFunctionName.RTRIM.getName(), + impl(nullMissingHandling((v) -> new ExprStringValue(v.stringValue().stripTrailing())), STRING, STRING) + ); + } + + /** + * Removes leading and trailing whitespace from string. + * Has option to specify a String to trim instead of whitespace but this is not yet supported. + * Supporting String specification requires finding keywords inside TRIM command. + * Supports following signatures: + * STRING -> STRING + */ + private DefaultFunctionResolver trim() { + return define( + BuiltinFunctionName.TRIM.getName(), + impl(nullMissingHandling((v) -> new ExprStringValue(v.stringValue().trim())), STRING, STRING) + ); + } + + /** + * Converts String to lowercase. + * Supports following signatures: + * STRING -> STRING + */ + private DefaultFunctionResolver lower() { + return define( + BuiltinFunctionName.LOWER.getName(), + impl(nullMissingHandling((v) -> new ExprStringValue((v.stringValue().toLowerCase()))), STRING, STRING) + ); + } + + /** + * Converts String to uppercase. + * Supports following signatures: + * STRING -> STRING + */ + private DefaultFunctionResolver upper() { + return define( + BuiltinFunctionName.UPPER.getName(), + impl(nullMissingHandling((v) -> new ExprStringValue((v.stringValue().toUpperCase()))), STRING, STRING) + ); + } + + /** + * Concatenates a list of Strings. + * Supports following signatures: + * (STRING, STRING, ...., STRING) -> STRING + */ + private DefaultFunctionResolver concat() { + FunctionName concatFuncName = BuiltinFunctionName.CONCAT.getName(); + return define( + concatFuncName, + funcName -> Pair.of( + new FunctionSignature(concatFuncName, Collections.singletonList(ARRAY)), + (funcProp, args) -> new FunctionExpression(funcName, args) { + @Override + public ExprValue valueOf(Environment valueEnv) { + List exprValues = args.stream().map(arg -> arg.valueOf(valueEnv)).collect(Collectors.toList()); + if (exprValues.stream().anyMatch(ExprValue::isMissing)) { + return ExprValueUtils.missingValue(); + } + if (exprValues.stream().anyMatch(ExprValue::isNull)) { + return ExprValueUtils.nullValue(); + } + return new ExprStringValue(exprValues.stream().map(ExprValue::stringValue).collect(Collectors.joining())); + } + + @Override + public ExprType type() { return STRING; - } } - )); - } - - /** - * TODO: https://github.com/opendistro-for-elasticsearch/sql/issues/710 - * Extend to accept variable argument amounts. - * Concatenates a list of Strings with a separator string. - * Supports following signatures: - * (STRING, STRING, STRING) -> STRING - */ - private DefaultFunctionResolver concat_ws() { - return define(BuiltinFunctionName.CONCAT_WS.getName(), - impl(nullMissingHandling((sep, str1, str2) -> - new ExprStringValue(str1.stringValue() + sep.stringValue() + str2.stringValue())), - STRING, STRING, STRING, STRING)); - } - - /** - * Calculates length of String in bytes. - * Supports following signatures: - * STRING -> INTEGER - */ - private DefaultFunctionResolver length() { - return define(BuiltinFunctionName.LENGTH.getName(), - impl(nullMissingHandling((str) -> - new ExprIntegerValue(str.stringValue().getBytes().length)), INTEGER, STRING)); - } - - /** - * Does String comparison of two Strings and returns Integer value. - * Supports following signatures: - * (STRING, STRING) -> INTEGER - */ - private DefaultFunctionResolver strcmp() { - return define(BuiltinFunctionName.STRCMP.getName(), - impl(nullMissingHandling((str1, str2) -> - new ExprIntegerValue(Integer.compare( - str1.stringValue().compareTo(str2.stringValue()), 0))), - INTEGER, STRING, STRING)); - } - - /** - * Returns the rightmost len characters from the string str, or NULL if any argument is NULL. - * Supports following signatures: - * (STRING, INTEGER) -> STRING - */ - private DefaultFunctionResolver right() { - return define(BuiltinFunctionName.RIGHT.getName(), - impl(nullMissingHandling(TextFunction::exprRight), STRING, STRING, INTEGER)); - } - - /** - * Returns the leftmost len characters from the string str, or NULL if any argument is NULL. - * Supports following signature: - * (STRING, INTEGER) -> STRING - */ - private DefaultFunctionResolver left() { - return define(BuiltinFunctionName.LEFT.getName(), - impl(nullMissingHandling(TextFunction::exprLeft), STRING, STRING, INTEGER)); - } - - /** - * Returns the numeric value of the leftmost character of the string str. - * Returns 0 if str is the empty string. Returns NULL if str is NULL. - * ASCII() works for 8-bit characters. - * Supports following signature: - * STRING -> INTEGER - */ - private DefaultFunctionResolver ascii() { - return define(BuiltinFunctionName.ASCII.getName(), - impl(nullMissingHandling(TextFunction::exprAscii), INTEGER, STRING)); - } - - /** - * LOCATE(substr, str) returns the position of the first occurrence of substring substr - * in string str. LOCATE(substr, str, pos) returns the position of the first occurrence - * of substring substr in string str, starting at position pos. - * Returns 0 if substr is not in str. - * Returns NULL if any argument is NULL. - * Supports following signature: - * (STRING, STRING) -> INTEGER - * (STRING, STRING, INTEGER) -> INTEGER - */ - private DefaultFunctionResolver locate() { - return define(BuiltinFunctionName.LOCATE.getName(), - impl(nullMissingHandling( - (SerializableBiFunction) - TextFunction::exprLocate), INTEGER, STRING, STRING), - impl(nullMissingHandling( - (SerializableTriFunction) - TextFunction::exprLocate), INTEGER, STRING, STRING, INTEGER)); - } - - /** - * Returns the position of the first occurrence of a substring in a string starting from 1. - * Returns 0 if substring is not in string. - * Returns NULL if any argument is NULL. - * Supports following signature: - * (STRING IN STRING) -> INTEGER - */ - private DefaultFunctionResolver position() { - return define(BuiltinFunctionName.POSITION.getName(), - impl(nullMissingHandling( - (SerializableBiFunction) - TextFunction::exprLocate), INTEGER, STRING, STRING)); - } - - /** - * REPLACE(str, from_str, to_str) returns the string str with all occurrences of - * the string from_str replaced by the string to_str. - * REPLACE() performs a case-sensitive match when searching for from_str. - * Supports following signature: - * (STRING, STRING, STRING) -> STRING - */ - private DefaultFunctionResolver replace() { - return define(BuiltinFunctionName.REPLACE.getName(), - impl(nullMissingHandling(TextFunction::exprReplace), STRING, STRING, STRING, STRING)); - } - - /** - * REVERSE(str) returns reversed string of the string supplied as an argument - * Returns NULL if the argument is NULL. - * Supports the following signature: - * (STRING) -> STRING - */ - private DefaultFunctionResolver reverse() { - return define(BuiltinFunctionName.REVERSE.getName(), - impl(nullMissingHandling(TextFunction::exprReverse), STRING, STRING)); - } - - private static ExprValue exprSubstrStart(ExprValue exprValue, ExprValue start) { - int startIdx = start.integerValue(); - if (startIdx == 0) { - return new ExprStringValue(EMPTY_STRING); + } + ) + ); } - String str = exprValue.stringValue(); - return exprSubStr(str, startIdx, str.length()); - } - - private static ExprValue exprSubstrStartLength( - ExprValue exprValue, ExprValue start, ExprValue length) { - int startIdx = start.integerValue(); - int len = length.integerValue(); - if ((startIdx == 0) || (len == 0)) { - return new ExprStringValue(EMPTY_STRING); + + /** + * TODO: https://github.com/opendistro-for-elasticsearch/sql/issues/710 + * Extend to accept variable argument amounts. + * Concatenates a list of Strings with a separator string. + * Supports following signatures: + * (STRING, STRING, STRING) -> STRING + */ + private DefaultFunctionResolver concat_ws() { + return define( + BuiltinFunctionName.CONCAT_WS.getName(), + impl( + nullMissingHandling((sep, str1, str2) -> new ExprStringValue(str1.stringValue() + sep.stringValue() + str2.stringValue())), + STRING, + STRING, + STRING, + STRING + ) + ); } - String str = exprValue.stringValue(); - return exprSubStr(str, startIdx, len); - } - - private static ExprValue exprSubStr(String str, int start, int len) { - // Correct negative start - start = (start > 0) ? (start - 1) : (str.length() + start); - - if (start > str.length()) { - return new ExprStringValue(EMPTY_STRING); - } else if ((start + len) > str.length()) { - return new ExprStringValue(str.substring(start)); + + /** + * Calculates length of String in bytes. + * Supports following signatures: + * STRING -> INTEGER + */ + private DefaultFunctionResolver length() { + return define( + BuiltinFunctionName.LENGTH.getName(), + impl(nullMissingHandling((str) -> new ExprIntegerValue(str.stringValue().getBytes().length)), INTEGER, STRING) + ); } - return new ExprStringValue(str.substring(start, start + len)); - } - private static ExprValue exprRight(ExprValue str, ExprValue len) { - if (len.integerValue() <= 0) { - return new ExprStringValue(""); + /** + * Does String comparison of two Strings and returns Integer value. + * Supports following signatures: + * (STRING, STRING) -> INTEGER + */ + private DefaultFunctionResolver strcmp() { + return define( + BuiltinFunctionName.STRCMP.getName(), + impl( + nullMissingHandling( + (str1, str2) -> new ExprIntegerValue(Integer.compare(str1.stringValue().compareTo(str2.stringValue()), 0)) + ), + INTEGER, + STRING, + STRING + ) + ); + } + + /** + * Returns the rightmost len characters from the string str, or NULL if any argument is NULL. + * Supports following signatures: + * (STRING, INTEGER) -> STRING + */ + private DefaultFunctionResolver right() { + return define(BuiltinFunctionName.RIGHT.getName(), impl(nullMissingHandling(TextFunction::exprRight), STRING, STRING, INTEGER)); + } + + /** + * Returns the leftmost len characters from the string str, or NULL if any argument is NULL. + * Supports following signature: + * (STRING, INTEGER) -> STRING + */ + private DefaultFunctionResolver left() { + return define(BuiltinFunctionName.LEFT.getName(), impl(nullMissingHandling(TextFunction::exprLeft), STRING, STRING, INTEGER)); + } + + /** + * Returns the numeric value of the leftmost character of the string str. + * Returns 0 if str is the empty string. Returns NULL if str is NULL. + * ASCII() works for 8-bit characters. + * Supports following signature: + * STRING -> INTEGER + */ + private DefaultFunctionResolver ascii() { + return define(BuiltinFunctionName.ASCII.getName(), impl(nullMissingHandling(TextFunction::exprAscii), INTEGER, STRING)); } - String stringValue = str.stringValue(); - int left = Math.max(stringValue.length() - len.integerValue(), 0); - return new ExprStringValue(str.stringValue().substring(left)); - } - - private static ExprValue exprLeft(ExprValue expr, ExprValue length) { - String stringValue = expr.stringValue(); - int right = length.integerValue(); - return new ExprStringValue(stringValue.substring(0, Math.min(right, stringValue.length()))); - } - - private static ExprValue exprAscii(ExprValue expr) { - return new ExprIntegerValue((int) expr.stringValue().charAt(0)); - } - - private static ExprValue exprLocate(ExprValue subStr, ExprValue str) { - return new ExprIntegerValue(str.stringValue().indexOf(subStr.stringValue()) + 1); - } - - private static ExprValue exprLocate(ExprValue subStr, ExprValue str, ExprValue pos) { - return new ExprIntegerValue( - str.stringValue().indexOf(subStr.stringValue(), pos.integerValue() - 1) + 1); - } - - private static ExprValue exprReplace(ExprValue str, ExprValue from, ExprValue to) { - return new ExprStringValue(str.stringValue().replaceAll(from.stringValue(), to.stringValue())); - } - - private static ExprValue exprReverse(ExprValue str) { - return new ExprStringValue(new StringBuilder(str.stringValue()).reverse().toString()); - } -} + /** + * LOCATE(substr, str) returns the position of the first occurrence of substring substr + * in string str. LOCATE(substr, str, pos) returns the position of the first occurrence + * of substring substr in string str, starting at position pos. + * Returns 0 if substr is not in str. + * Returns NULL if any argument is NULL. + * Supports following signature: + * (STRING, STRING) -> INTEGER + * (STRING, STRING, INTEGER) -> INTEGER + */ + private DefaultFunctionResolver locate() { + return define( + BuiltinFunctionName.LOCATE.getName(), + impl( + nullMissingHandling((SerializableBiFunction) TextFunction::exprLocate), + INTEGER, + STRING, + STRING + ), + impl( + nullMissingHandling((SerializableTriFunction) TextFunction::exprLocate), + INTEGER, + STRING, + STRING, + INTEGER + ) + ); + } + + /** + * Returns the position of the first occurrence of a substring in a string starting from 1. + * Returns 0 if substring is not in string. + * Returns NULL if any argument is NULL. + * Supports following signature: + * (STRING IN STRING) -> INTEGER + */ + private DefaultFunctionResolver position() { + return define( + BuiltinFunctionName.POSITION.getName(), + impl( + nullMissingHandling((SerializableBiFunction) TextFunction::exprLocate), + INTEGER, + STRING, + STRING + ) + ); + } + + /** + * REPLACE(str, from_str, to_str) returns the string str with all occurrences of + * the string from_str replaced by the string to_str. + * REPLACE() performs a case-sensitive match when searching for from_str. + * Supports following signature: + * (STRING, STRING, STRING) -> STRING + */ + private DefaultFunctionResolver replace() { + return define( + BuiltinFunctionName.REPLACE.getName(), + impl(nullMissingHandling(TextFunction::exprReplace), STRING, STRING, STRING, STRING) + ); + } + + /** + * REVERSE(str) returns reversed string of the string supplied as an argument + * Returns NULL if the argument is NULL. + * Supports the following signature: + * (STRING) -> STRING + */ + private DefaultFunctionResolver reverse() { + return define(BuiltinFunctionName.REVERSE.getName(), impl(nullMissingHandling(TextFunction::exprReverse), STRING, STRING)); + } + + private static ExprValue exprSubstrStart(ExprValue exprValue, ExprValue start) { + int startIdx = start.integerValue(); + if (startIdx == 0) { + return new ExprStringValue(EMPTY_STRING); + } + String str = exprValue.stringValue(); + return exprSubStr(str, startIdx, str.length()); + } + + private static ExprValue exprSubstrStartLength(ExprValue exprValue, ExprValue start, ExprValue length) { + int startIdx = start.integerValue(); + int len = length.integerValue(); + if ((startIdx == 0) || (len == 0)) { + return new ExprStringValue(EMPTY_STRING); + } + String str = exprValue.stringValue(); + return exprSubStr(str, startIdx, len); + } + + private static ExprValue exprSubStr(String str, int start, int len) { + // Correct negative start + start = (start > 0) ? (start - 1) : (str.length() + start); + + if (start > str.length()) { + return new ExprStringValue(EMPTY_STRING); + } else if ((start + len) > str.length()) { + return new ExprStringValue(str.substring(start)); + } + return new ExprStringValue(str.substring(start, start + len)); + } + + private static ExprValue exprRight(ExprValue str, ExprValue len) { + if (len.integerValue() <= 0) { + return new ExprStringValue(""); + } + String stringValue = str.stringValue(); + int left = Math.max(stringValue.length() - len.integerValue(), 0); + return new ExprStringValue(str.stringValue().substring(left)); + } + + private static ExprValue exprLeft(ExprValue expr, ExprValue length) { + String stringValue = expr.stringValue(); + int right = length.integerValue(); + return new ExprStringValue(stringValue.substring(0, Math.min(right, stringValue.length()))); + } + + private static ExprValue exprAscii(ExprValue expr) { + return new ExprIntegerValue((int) expr.stringValue().charAt(0)); + } + + private static ExprValue exprLocate(ExprValue subStr, ExprValue str) { + return new ExprIntegerValue(str.stringValue().indexOf(subStr.stringValue()) + 1); + } + + private static ExprValue exprLocate(ExprValue subStr, ExprValue str, ExprValue pos) { + return new ExprIntegerValue(str.stringValue().indexOf(subStr.stringValue(), pos.integerValue() - 1) + 1); + } + + private static ExprValue exprReplace(ExprValue str, ExprValue from, ExprValue to) { + return new ExprStringValue(str.stringValue().replaceAll(from.stringValue(), to.stringValue())); + } + + private static ExprValue exprReverse(ExprValue str) { + return new ExprStringValue(new StringBuilder(str.stringValue()).reverse().toString()); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/window/WindowDefinition.java b/core/src/main/java/org/opensearch/sql/expression/window/WindowDefinition.java index 24751633de..74bdba4e47 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/WindowDefinition.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/WindowDefinition.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.window; import static org.opensearch.sql.ast.tree.Sort.SortOption; @@ -22,18 +21,18 @@ @Data public class WindowDefinition { - private final List partitionByList; - private final List> sortList; - - /** - * Return all items in partition by and sort list. - * @return all sort items - */ - public List> getAllSortItems() { - List> allSorts = new ArrayList<>(); - partitionByList.forEach(expr -> allSorts.add(ImmutablePair.of(DEFAULT_ASC, expr))); - allSorts.addAll(sortList); - return allSorts; - } + private final List partitionByList; + private final List> sortList; + + /** + * Return all items in partition by and sort list. + * @return all sort items + */ + public List> getAllSortItems() { + List> allSorts = new ArrayList<>(); + partitionByList.forEach(expr -> allSorts.add(ImmutablePair.of(DEFAULT_ASC, expr))); + allSorts.addAll(sortList); + return allSorts; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctionExpression.java b/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctionExpression.java index a15919bf03..f9d3f4a422 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctionExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctionExpression.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.window; import org.opensearch.sql.expression.Expression; @@ -14,16 +13,16 @@ */ public interface WindowFunctionExpression extends Expression { - /** - * Create specific window frame based on window definition and what's current window function. - * For now two types of cumulative window frame is returned: - * 1. Ranking window functions: ignore frame definition and always operates on - * previous and current row. - * 2. Aggregate window functions: frame partition into peers and sliding window is not supported. - * - * @param definition window definition - * @return window frame - */ - WindowFrame createWindowFrame(WindowDefinition definition); + /** + * Create specific window frame based on window definition and what's current window function. + * For now two types of cumulative window frame is returned: + * 1. Ranking window functions: ignore frame definition and always operates on + * previous and current row. + * 2. Aggregate window functions: frame partition into peers and sliding window is not supported. + * + * @param definition window definition + * @return window frame + */ + WindowFrame createWindowFrame(WindowDefinition definition); } diff --git a/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java b/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java index 9a9e0c4c86..cb6f3be12e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.window; import static java.util.Collections.emptyList; @@ -28,34 +27,32 @@ @UtilityClass public class WindowFunctions { - /** - * Register all window functions to function repository. - * - * @param repository function repository - */ - public void register(BuiltinFunctionRepository repository) { - repository.register(rowNumber()); - repository.register(rank()); - repository.register(denseRank()); - } - - private DefaultFunctionResolver rowNumber() { - return rankingFunction(BuiltinFunctionName.ROW_NUMBER.getName(), RowNumberFunction::new); - } - - private DefaultFunctionResolver rank() { - return rankingFunction(BuiltinFunctionName.RANK.getName(), RankFunction::new); - } - - private DefaultFunctionResolver denseRank() { - return rankingFunction(BuiltinFunctionName.DENSE_RANK.getName(), DenseRankFunction::new); - } - - private DefaultFunctionResolver rankingFunction(FunctionName functionName, - Supplier constructor) { - FunctionSignature functionSignature = new FunctionSignature(functionName, emptyList()); - FunctionBuilder functionBuilder = (functionProperties, arguments) -> constructor.get(); - return new DefaultFunctionResolver(functionName, - ImmutableMap.of(functionSignature, functionBuilder)); - } + /** + * Register all window functions to function repository. + * + * @param repository function repository + */ + public void register(BuiltinFunctionRepository repository) { + repository.register(rowNumber()); + repository.register(rank()); + repository.register(denseRank()); + } + + private DefaultFunctionResolver rowNumber() { + return rankingFunction(BuiltinFunctionName.ROW_NUMBER.getName(), RowNumberFunction::new); + } + + private DefaultFunctionResolver rank() { + return rankingFunction(BuiltinFunctionName.RANK.getName(), RankFunction::new); + } + + private DefaultFunctionResolver denseRank() { + return rankingFunction(BuiltinFunctionName.DENSE_RANK.getName(), DenseRankFunction::new); + } + + private DefaultFunctionResolver rankingFunction(FunctionName functionName, Supplier constructor) { + FunctionSignature functionSignature = new FunctionSignature(functionName, emptyList()); + FunctionBuilder functionBuilder = (functionProperties, arguments) -> constructor.get(); + return new DefaultFunctionResolver(functionName, ImmutableMap.of(functionSignature, functionBuilder)); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/window/aggregation/AggregateWindowFunction.java b/core/src/main/java/org/opensearch/sql/expression/window/aggregation/AggregateWindowFunction.java index 604f65e6ff..9fc2a2f539 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/aggregation/AggregateWindowFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/aggregation/AggregateWindowFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.window.aggregation; import java.util.List; @@ -28,41 +27,41 @@ @RequiredArgsConstructor public class AggregateWindowFunction implements WindowFunctionExpression { - private final Aggregator aggregator; - private AggregationState state; - - @Override - public WindowFrame createWindowFrame(WindowDefinition definition) { - return new PeerRowsWindowFrame(definition); - } + private final Aggregator aggregator; + private AggregationState state; - @Override - public ExprValue valueOf(Environment valueEnv) { - PeerRowsWindowFrame frame = (PeerRowsWindowFrame) valueEnv; - if (frame.isNewPartition()) { - state = aggregator.create(); + @Override + public WindowFrame createWindowFrame(WindowDefinition definition) { + return new PeerRowsWindowFrame(definition); } - List peers = frame.next(); - for (ExprValue peer : peers) { - state = aggregator.iterate(peer.bindingTuples(), state); + @Override + public ExprValue valueOf(Environment valueEnv) { + PeerRowsWindowFrame frame = (PeerRowsWindowFrame) valueEnv; + if (frame.isNewPartition()) { + state = aggregator.create(); + } + + List peers = frame.next(); + for (ExprValue peer : peers) { + state = aggregator.iterate(peer.bindingTuples(), state); + } + return state.result(); } - return state.result(); - } - @Override - public ExprType type() { - return aggregator.type(); - } + @Override + public ExprType type() { + return aggregator.type(); + } - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return aggregator.accept(visitor, context); - } + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return aggregator.accept(visitor, context); + } - @Override - public String toString() { - return aggregator.toString(); - } + @Override + public String toString() { + return aggregator.toString(); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/window/frame/CurrentRowWindowFrame.java b/core/src/main/java/org/opensearch/sql/expression/window/frame/CurrentRowWindowFrame.java index 06b19a1488..382006ddde 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/frame/CurrentRowWindowFrame.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/frame/CurrentRowWindowFrame.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.window.frame; import com.google.common.collect.PeekingIterator; @@ -31,59 +30,57 @@ @ToString public class CurrentRowWindowFrame implements WindowFrame { - @Getter - private final WindowDefinition windowDefinition; + @Getter + private final WindowDefinition windowDefinition; - private ExprValue previous; - private ExprValue current; + private ExprValue previous; + private ExprValue current; - @Override - public boolean isNewPartition() { - Objects.requireNonNull(current); + @Override + public boolean isNewPartition() { + Objects.requireNonNull(current); - if (previous == null) { - return true; - } + if (previous == null) { + return true; + } - List preValues = resolve(windowDefinition.getPartitionByList(), previous); - List curValues = resolve(windowDefinition.getPartitionByList(), current); - return !preValues.equals(curValues); - } + List preValues = resolve(windowDefinition.getPartitionByList(), previous); + List curValues = resolve(windowDefinition.getPartitionByList(), current); + return !preValues.equals(curValues); + } - @Override - public void load(PeekingIterator it) { - previous = current; - current = it.next(); - } + @Override + public void load(PeekingIterator it) { + previous = current; + current = it.next(); + } - @Override - public ExprValue current() { - return current; - } + @Override + public ExprValue current() { + return current; + } - public ExprValue previous() { - return previous; - } + public ExprValue previous() { + return previous; + } - private List resolve(List expressions, ExprValue row) { - Environment valueEnv = row.bindingTuples(); - return expressions.stream() - .map(expr -> expr.valueOf(valueEnv)) - .collect(Collectors.toList()); - } + private List resolve(List expressions, ExprValue row) { + Environment valueEnv = row.bindingTuples(); + return expressions.stream().map(expr -> expr.valueOf(valueEnv)).collect(Collectors.toList()); + } - /** - * Current row window frame won't pre-fetch any row ahead. - * So always return false as nothing "cached" in frame. - */ - @Override - public boolean hasNext() { - return false; - } + /** + * Current row window frame won't pre-fetch any row ahead. + * So always return false as nothing "cached" in frame. + */ + @Override + public boolean hasNext() { + return false; + } - @Override - public List next() { - return Collections.emptyList(); - } + @Override + public List next() { + return Collections.emptyList(); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/window/frame/PeerRowsWindowFrame.java b/core/src/main/java/org/opensearch/sql/expression/window/frame/PeerRowsWindowFrame.java index a3e8de40c1..9a9f0fb3e6 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/frame/PeerRowsWindowFrame.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/frame/PeerRowsWindowFrame.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.window.frame; import com.google.common.collect.PeekingIterator; @@ -26,122 +25,116 @@ @RequiredArgsConstructor public class PeerRowsWindowFrame implements WindowFrame { - private final WindowDefinition windowDefinition; - - /** - * All peer rows (peer means rows in a partition that share same sort key - * based on sort list in window definition. - */ - private final List peers = new ArrayList<>(); - - /** - * Which row in the peer is currently being enriched by window function. - */ - private int position; - - /** - * Does row at current position represents a new partition. - */ - private boolean isNewPartition = true; - - /** - * If any more pre-fetched rows not returned to window operator yet. - */ - @Override - public boolean hasNext() { - return position < peers.size(); - } - - /** - * Move position and clear new partition flag. - * Note that because all peer rows have same result from window function, - * this is only returned at first time to change window function state. - * Afterwards, empty list is returned to avoid changes until next peer loaded. - * - * @return all rows for the peer - */ - @Override - public List next() { - isNewPartition = false; - if (position++ == 0) { - return peers; + private final WindowDefinition windowDefinition; + + /** + * All peer rows (peer means rows in a partition that share same sort key + * based on sort list in window definition. + */ + private final List peers = new ArrayList<>(); + + /** + * Which row in the peer is currently being enriched by window function. + */ + private int position; + + /** + * Does row at current position represents a new partition. + */ + private boolean isNewPartition = true; + + /** + * If any more pre-fetched rows not returned to window operator yet. + */ + @Override + public boolean hasNext() { + return position < peers.size(); + } + + /** + * Move position and clear new partition flag. + * Note that because all peer rows have same result from window function, + * this is only returned at first time to change window function state. + * Afterwards, empty list is returned to avoid changes until next peer loaded. + * + * @return all rows for the peer + */ + @Override + public List next() { + isNewPartition = false; + if (position++ == 0) { + return peers; + } + return Collections.emptyList(); + } + + /** + * Current row at the position. Because rows are pre-fetched here, + * window operator needs to get them from here too. + * @return row at current position that being enriched by window function + */ + @Override + public ExprValue current() { + return peers.get(position); + } + + /** + * Preload all peer rows if last peer rows done. Note that when no more data in peeking iterator, + * there must be rows in frame (hasNext()=true), so no need to check it.hasNext() in this method. + * Load until: + * 1. Different peer found (row with different sort key) + * 2. Or new partition (row with different partition key) + * 3. Or no more rows + * @param it rows iterator + */ + @Override + public void load(PeekingIterator it) { + if (hasNext()) { + return; + } + + // Reset state: reset new partition before clearing peers + isNewPartition = !isSamePartition(it.peek()); + position = 0; + peers.clear(); + + while (it.hasNext()) { + ExprValue next = it.peek(); + if (peers.isEmpty()) { + peers.add(it.next()); + } else if (isSamePartition(next) && isPeer(next)) { + peers.add(it.next()); + } else { + break; + } + } } - return Collections.emptyList(); - } - - /** - * Current row at the position. Because rows are pre-fetched here, - * window operator needs to get them from here too. - * @return row at current position that being enriched by window function - */ - @Override - public ExprValue current() { - return peers.get(position); - } - - /** - * Preload all peer rows if last peer rows done. Note that when no more data in peeking iterator, - * there must be rows in frame (hasNext()=true), so no need to check it.hasNext() in this method. - * Load until: - * 1. Different peer found (row with different sort key) - * 2. Or new partition (row with different partition key) - * 3. Or no more rows - * @param it rows iterator - */ - @Override - public void load(PeekingIterator it) { - if (hasNext()) { - return; + + @Override + public boolean isNewPartition() { + return isNewPartition; } - // Reset state: reset new partition before clearing peers - isNewPartition = !isSamePartition(it.peek()); - position = 0; - peers.clear(); - - while (it.hasNext()) { - ExprValue next = it.peek(); - if (peers.isEmpty()) { - peers.add(it.next()); - } else if (isSamePartition(next) && isPeer(next)) { - peers.add(it.next()); - } else { - break; - } + private boolean isPeer(ExprValue next) { + List sortFields = windowDefinition.getSortList().stream().map(Pair::getRight).collect(Collectors.toList()); + + ExprValue last = peers.get(peers.size() - 1); + return resolve(sortFields, last).equals(resolve(sortFields, next)); } - } - - @Override - public boolean isNewPartition() { - return isNewPartition; - } - - private boolean isPeer(ExprValue next) { - List sortFields = - windowDefinition.getSortList() - .stream() - .map(Pair::getRight) - .collect(Collectors.toList()); - - ExprValue last = peers.get(peers.size() - 1); - return resolve(sortFields, last).equals(resolve(sortFields, next)); - } - - private boolean isSamePartition(ExprValue next) { - if (peers.isEmpty()) { - return false; + + private boolean isSamePartition(ExprValue next) { + if (peers.isEmpty()) { + return false; + } + + List partitionByList = windowDefinition.getPartitionByList(); + ExprValue last = peers.get(peers.size() - 1); + return resolve(partitionByList, last).equals(resolve(partitionByList, next)); } - List partitionByList = windowDefinition.getPartitionByList(); - ExprValue last = peers.get(peers.size() - 1); - return resolve(partitionByList, last).equals(resolve(partitionByList, next)); - } - - private List resolve(List expressions, ExprValue row) { - Environment valueEnv = row.bindingTuples(); - return expressions.stream() - .map(expr -> expr.valueOf(valueEnv)) - .collect(Collectors.toList()); - } + private List resolve(List expressions, ExprValue row) { + Environment valueEnv = row.bindingTuples(); + return expressions.stream().map(expr -> expr.valueOf(valueEnv)).collect(Collectors.toList()); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/window/frame/WindowFrame.java b/core/src/main/java/org/opensearch/sql/expression/window/frame/WindowFrame.java index 323656547f..d971544fb0 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/frame/WindowFrame.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/frame/WindowFrame.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.window.frame; import com.google.common.collect.PeekingIterator; @@ -24,27 +23,27 @@ */ public interface WindowFrame extends Environment, Iterator> { - @Override - default ExprValue resolve(Expression var) { - return var.valueOf(current().bindingTuples()); - } - - /** - * Check is current row the beginning of a new partition according to window definition. - * @return true if a new partition begins here, otherwise false. - */ - boolean isNewPartition(); - - /** - * Load one or more rows as window function calculation needed. - * @param iterator peeking iterator that can peek next element without moving iterator - */ - void load(PeekingIterator iterator); - - /** - * Get current data row for giving window operator chance to get rows preloaded into frame. - * @return data row - */ - ExprValue current(); + @Override + default ExprValue resolve(Expression var) { + return var.valueOf(current().bindingTuples()); + } + + /** + * Check is current row the beginning of a new partition according to window definition. + * @return true if a new partition begins here, otherwise false. + */ + boolean isNewPartition(); + + /** + * Load one or more rows as window function calculation needed. + * @param iterator peeking iterator that can peek next element without moving iterator + */ + void load(PeekingIterator iterator); + + /** + * Get current data row for giving window operator chance to get rows preloaded into frame. + * @return data row + */ + ExprValue current(); } diff --git a/core/src/main/java/org/opensearch/sql/expression/window/ranking/DenseRankFunction.java b/core/src/main/java/org/opensearch/sql/expression/window/ranking/DenseRankFunction.java index ba6e88d98d..a60a8a24b1 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/ranking/DenseRankFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/ranking/DenseRankFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.window.ranking; import org.opensearch.sql.expression.function.BuiltinFunctionName; @@ -15,20 +14,20 @@ */ public class DenseRankFunction extends RankingWindowFunction { - public DenseRankFunction() { - super(BuiltinFunctionName.DENSE_RANK.getName()); - } + public DenseRankFunction() { + super(BuiltinFunctionName.DENSE_RANK.getName()); + } - @Override - protected int rank(CurrentRowWindowFrame frame) { - if (frame.isNewPartition()) { - rank = 1; - } else { - if (isSortFieldValueDifferent(frame)) { - rank++; - } + @Override + protected int rank(CurrentRowWindowFrame frame) { + if (frame.isNewPartition()) { + rank = 1; + } else { + if (isSortFieldValueDifferent(frame)) { + rank++; + } + } + return rank; } - return rank; - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/window/ranking/RankFunction.java b/core/src/main/java/org/opensearch/sql/expression/window/ranking/RankFunction.java index c1f33e6137..9358c386b3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/ranking/RankFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/ranking/RankFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.window.ranking; import org.opensearch.sql.expression.function.BuiltinFunctionName; @@ -16,27 +15,27 @@ */ public class RankFunction extends RankingWindowFunction { - /** - * Total number of rows have seen in current partition. - */ - private int total; + /** + * Total number of rows have seen in current partition. + */ + private int total; - public RankFunction() { - super(BuiltinFunctionName.RANK.getName()); - } + public RankFunction() { + super(BuiltinFunctionName.RANK.getName()); + } - @Override - protected int rank(CurrentRowWindowFrame frame) { - if (frame.isNewPartition()) { - total = 1; - rank = 1; - } else { - total++; - if (isSortFieldValueDifferent(frame)) { - rank = total; - } + @Override + protected int rank(CurrentRowWindowFrame frame) { + if (frame.isNewPartition()) { + total = 1; + rank = 1; + } else { + total++; + if (isSortFieldValueDifferent(frame)) { + rank = total; + } + } + return rank; } - return rank; - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/window/ranking/RankingWindowFunction.java b/core/src/main/java/org/opensearch/sql/expression/window/ranking/RankingWindowFunction.java index 07a4b42dbd..ca0632bd11 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/ranking/RankingWindowFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/ranking/RankingWindowFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.window.ranking; import static java.util.Collections.emptyList; @@ -29,74 +28,67 @@ * Ranking window function base class that captures same info across different ranking functions, * such as same return type (integer), same argument list (no arg). */ -public abstract class RankingWindowFunction extends FunctionExpression - implements WindowFunctionExpression { - - /** - * Current rank number assigned. - */ - protected int rank; - - public RankingWindowFunction(FunctionName functionName) { - super(functionName, emptyList()); - } - - @Override - public ExprType type() { - return ExprCoreType.INTEGER; - } - - @Override - public WindowFrame createWindowFrame(WindowDefinition definition) { - return new CurrentRowWindowFrame(definition); - } - - @Override - public ExprValue valueOf(Environment valueEnv) { - return new ExprIntegerValue(rank((CurrentRowWindowFrame) valueEnv)); - } - - /** - * Rank logic that sub-class needs to implement. - * @param frame window frame - * @return rank number - */ - protected abstract int rank(CurrentRowWindowFrame frame); - - /** - * Check sort field to see if current value is different from previous. - * @param frame window frame - * @return true if different, false if same or no sort list defined - */ - protected boolean isSortFieldValueDifferent(CurrentRowWindowFrame frame) { - if (isSortItemsNotDefined(frame)) { - return false; +public abstract class RankingWindowFunction extends FunctionExpression implements WindowFunctionExpression { + + /** + * Current rank number assigned. + */ + protected int rank; + + public RankingWindowFunction(FunctionName functionName) { + super(functionName, emptyList()); + } + + @Override + public ExprType type() { + return ExprCoreType.INTEGER; + } + + @Override + public WindowFrame createWindowFrame(WindowDefinition definition) { + return new CurrentRowWindowFrame(definition); + } + + @Override + public ExprValue valueOf(Environment valueEnv) { + return new ExprIntegerValue(rank((CurrentRowWindowFrame) valueEnv)); + } + + /** + * Rank logic that sub-class needs to implement. + * @param frame window frame + * @return rank number + */ + protected abstract int rank(CurrentRowWindowFrame frame); + + /** + * Check sort field to see if current value is different from previous. + * @param frame window frame + * @return true if different, false if same or no sort list defined + */ + protected boolean isSortFieldValueDifferent(CurrentRowWindowFrame frame) { + if (isSortItemsNotDefined(frame)) { + return false; + } + + List sortItems = frame.getWindowDefinition().getSortList().stream().map(Pair::getRight).collect(Collectors.toList()); + + List previous = resolve(frame, sortItems, frame.previous()); + List current = resolve(frame, sortItems, frame.current()); + return !current.equals(previous); + } + + private boolean isSortItemsNotDefined(CurrentRowWindowFrame frame) { + return frame.getWindowDefinition().getSortList().isEmpty(); + } + + private List resolve(WindowFrame frame, List expressions, ExprValue row) { + BindingTuple valueEnv = row.bindingTuples(); + return expressions.stream().map(expr -> expr.valueOf(valueEnv)).collect(Collectors.toList()); } - List sortItems = frame.getWindowDefinition() - .getSortList() - .stream() - .map(Pair::getRight) - .collect(Collectors.toList()); - - List previous = resolve(frame, sortItems, frame.previous()); - List current = resolve(frame, sortItems, frame.current()); - return !current.equals(previous); - } - - private boolean isSortItemsNotDefined(CurrentRowWindowFrame frame) { - return frame.getWindowDefinition().getSortList().isEmpty(); - } - - private List resolve(WindowFrame frame, List expressions, ExprValue row) { - BindingTuple valueEnv = row.bindingTuples(); - return expressions.stream() - .map(expr -> expr.valueOf(valueEnv)) - .collect(Collectors.toList()); - } - - @Override - public String toString() { - return getFunctionName() + "()"; - } + @Override + public String toString() { + return getFunctionName() + "()"; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/window/ranking/RowNumberFunction.java b/core/src/main/java/org/opensearch/sql/expression/window/ranking/RowNumberFunction.java index 067dfa569d..6677bd7e87 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/ranking/RowNumberFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/ranking/RowNumberFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.expression.window.ranking; import org.opensearch.sql.expression.function.BuiltinFunctionName; @@ -14,16 +13,16 @@ */ public class RowNumberFunction extends RankingWindowFunction { - public RowNumberFunction() { - super(BuiltinFunctionName.ROW_NUMBER.getName()); - } + public RowNumberFunction() { + super(BuiltinFunctionName.ROW_NUMBER.getName()); + } - @Override - protected int rank(CurrentRowWindowFrame frame) { - if (frame.isNewPartition()) { - rank = 1; + @Override + protected int rank(CurrentRowWindowFrame frame) { + if (frame.isNewPartition()) { + rank = 1; + } + return rank++; } - return rank++; - } } diff --git a/formatterConfig.xml b/formatterConfig.xml new file mode 100644 index 0000000000..b0e1ecccb9 --- /dev/null +++ b/formatterConfig.xml @@ -0,0 +1,362 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file