diff --git a/.github/workflows/draft-release-notes-workflow.yml b/.github/workflows/draft-release-notes-workflow.yml index b0b92441b1..660a8a1a51 100644 --- a/.github/workflows/draft-release-notes-workflow.yml +++ b/.github/workflows/draft-release-notes-workflow.yml @@ -16,6 +16,6 @@ jobs: with: config-name: draft-release-notes-config.yml tag: (None) - version: 2.2.0.0 + version: 2.3.0.0 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/sql-odbc-release-workflow.yml b/.github/workflows/sql-odbc-release-workflow.yml index 00920fffd7..0d08865378 100644 --- a/.github/workflows/sql-odbc-release-workflow.yml +++ b/.github/workflows/sql-odbc-release-workflow.yml @@ -12,7 +12,7 @@ env: ODBC_BUILD_PATH: "./build/odbc/build" AWS_SDK_INSTALL_PATH: "./build/aws-sdk/install" PLUGIN_NAME: opensearch-sql-odbc - OD_VERSION: 2.2.0.0 + OD_VERSION: 2.3.0.0 jobs: build-mac: diff --git a/.github/workflows/sql-test-and-build-workflow.yml b/.github/workflows/sql-test-and-build-workflow.yml index 70d1c3a3e5..fcc63433a8 100644 --- a/.github/workflows/sql-test-and-build-workflow.yml +++ b/.github/workflows/sql-test-and-build-workflow.yml @@ -22,6 +22,9 @@ jobs: - name: Build with Gradle run: ./gradlew build assemble + - name: Run backward compatibility tests + run: ./scripts/bwctest.sh + - name: Create Artifact Path run: | mkdir -p opensearch-sql-builds diff --git a/.github/workflows/sql-workbench-release-workflow.yml b/.github/workflows/sql-workbench-release-workflow.yml index ef23bff98a..840428e538 100644 --- a/.github/workflows/sql-workbench-release-workflow.yml +++ b/.github/workflows/sql-workbench-release-workflow.yml @@ -8,7 +8,7 @@ on: env: PLUGIN_NAME: query-workbench-dashboards OPENSEARCH_VERSION: 'main' - OPENSEARCH_PLUGIN_VERSION: 2.2.0.0 + OPENSEARCH_PLUGIN_VERSION: 2.3.0.0 jobs: diff --git a/.github/workflows/sql-workbench-test-and-build-workflow.yml b/.github/workflows/sql-workbench-test-and-build-workflow.yml index c0ae593c1d..d4da17bf7f 100644 --- a/.github/workflows/sql-workbench-test-and-build-workflow.yml +++ b/.github/workflows/sql-workbench-test-and-build-workflow.yml @@ -5,7 +5,7 @@ on: [pull_request, push] env: PLUGIN_NAME: query-workbench-dashboards OPENSEARCH_VERSION: 'main' - OPENSEARCH_PLUGIN_VERSION: 2.2.0.0 + OPENSEARCH_PLUGIN_VERSION: 2.3.0.0 jobs: diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 734a390acb..ba4ce45209 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -9,4 +9,6 @@ | Chen Dai | [dai-chen](https://github.com/dai-chen) | Amazon | | Chloe Zhang | [chloe-zh](https://github.com/chloe-zh) | Amazon | | Nick Knize | [nknize](https://github.com/nknize) | Amazon | -| Charlotte Henkle | [CEHENKLE](https://github.com/CEHENKLE) | Amazon | \ No newline at end of file +| Charlotte Henkle | [CEHENKLE](https://github.com/CEHENKLE) | Amazon | +| Max Ksyunz | [MaxKsyunz](https://github.com/MaxKsyunz) | BitQuill | +| Yury Fridlyand | [Yury-Fridlyand](https://github.com/Yury-Fridlyand) | BitQuill | \ No newline at end of file diff --git a/README.md b/README.md index d3c79cc97f..0c220838b5 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ - [Code Summary](#code-summary) - [Highlights](#highlights) - [Documentation](#documentation) +- [OpenSearch Forum](#forum) - [Contributing](#contributing) - [Attribution](#attribution) - [Code of Conduct](#code-of-conduct) @@ -127,6 +128,10 @@ Recently we have been actively improving our query engine primarily for better c Please refer to the [SQL Language Reference Manual](./docs/user/index.rst), [Piped Processing Language (PPL) Reference Manual](./docs/user/ppl/index.rst) and [Technical Documentation](https://opensearch.org/docs/latest/search-plugins/sql/index/) for detailed information on installing and configuring plugin. +## Forum + +For additional help with the plugin, including questions about opening an issue, visit the OpenSearch [Forum](https://forum.opensearch.org/c/plugins/sql/8). + ## Contributing See [developer guide](DEVELOPER_GUIDE.rst) and [how to contribute to this project](CONTRIBUTING.md). diff --git a/bi-connectors/PowerBIConnector/src/OpenSearchProject.query.pq b/bi-connectors/PowerBIConnector/src/OpenSearchProject.query.pq index 19c84006bf..45026f96a3 100644 --- a/bi-connectors/PowerBIConnector/src/OpenSearchProject.query.pq +++ b/bi-connectors/PowerBIConnector/src/OpenSearchProject.query.pq @@ -7,13 +7,14 @@ shared MyExtension.UnitTest = Host = "localhost", Port = 9200, UseSSL = false, + HostnameVerification = false, facts = { Fact("Connection Test", 7, let - Source = OpenSearch.Contents(Host, Port, UseSSL), + Source = OpenSearchProject.Contents(Host, Port, UseSSL, HostnameVerification), no_of_columns = Table.ColumnCount(Source) in no_of_columns @@ -22,7 +23,7 @@ shared MyExtension.UnitTest = #table(type table [bool0 = logical], { {null}, {false}, {true} }), let - Source = OpenSearch.Contents(Host, Port, UseSSL), + Source = OpenSearchProject.Contents(Host, Port, UseSSL, HostnameVerification), calcs_null_null = Source{[Item="calcs",Schema=null,Catalog=null]}[Data], grouped = Table.Group(calcs_null_null, {"bool0"}, {}) in diff --git a/build.gradle b/build.gradle index 855ec748bc..c96655a5c1 100644 --- a/build.gradle +++ b/build.gradle @@ -6,7 +6,7 @@ buildscript { ext { - opensearch_version = System.getProperty("opensearch.version", "2.2.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "2.3.0-SNAPSHOT") spring_version = "5.3.22" jackson_version = "2.13.3" isSnapshot = "true" == System.getProperty("build.snapshot", "true") diff --git a/core/build.gradle b/core/build.gradle index 1fa3e19e26..2926eb0614 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -45,6 +45,9 @@ dependencies { api group: 'org.apache.commons', name: 'commons-lang3', version: '3.10' api group: 'com.facebook.presto', name: 'presto-matching', version: '0.240' api group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1' + api "com.fasterxml.jackson.core:jackson-core:${jackson_version}" + api "com.fasterxml.jackson.core:jackson-databind:${jackson_version}" + api "com.fasterxml.jackson.core:jackson-annotations:${jackson_version}" api project(':common') testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') @@ -70,7 +73,7 @@ jacocoTestReport { afterEvaluate { classDirectories.setFrom(files(classDirectories.files.collect { fileTree(dir: it, - exclude: ['**/ast/**']) + exclude: ['**/ast/**', '**/catalog/model/**']) })) } } @@ -80,7 +83,9 @@ jacocoTestCoverageVerification { rule { element = 'CLASS' excludes = [ - 'org.opensearch.sql.utils.MLCommonsConstants' + 'org.opensearch.sql.utils.MLCommonsConstants', + 'org.opensearch.sql.utils.Constants', + 'org.opensearch.sql.catalog.model.*' ] limit { counter = 'LINE' diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index dc12bdab73..eea1c0786b 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -55,6 +55,7 @@ import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.data.model.ExprMissingValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.exception.SemanticCheckException; @@ -81,7 +82,6 @@ import org.opensearch.sql.planner.logical.LogicalRename; import org.opensearch.sql.planner.logical.LogicalSort; import org.opensearch.sql.planner.logical.LogicalValues; -import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; import org.opensearch.sql.utils.ParseUtils; @@ -97,16 +97,16 @@ public class Analyzer extends AbstractNodeVisitor private final NamedExpressionAnalyzer namedExpressionAnalyzer; - private final StorageEngine storageEngine; + private final CatalogService catalogService; /** * Constructor. */ public Analyzer( ExpressionAnalyzer expressionAnalyzer, - StorageEngine storageEngine) { + CatalogService catalogService) { this.expressionAnalyzer = expressionAnalyzer; - this.storageEngine = storageEngine; + this.catalogService = catalogService; this.selectExpressionAnalyzer = new SelectExpressionAnalyzer(expressionAnalyzer); this.namedExpressionAnalyzer = new NamedExpressionAnalyzer(expressionAnalyzer); } @@ -119,16 +119,33 @@ public LogicalPlan analyze(UnresolvedPlan unresolved, AnalysisContext context) { public LogicalPlan visitRelation(Relation node, AnalysisContext context) { context.push(); TypeEnvironment curEnv = context.peek(); - Table table = storageEngine.getTable(node.getTableName()); + String catalogName = getCatalogName(node); + String tableName = getTableName(node); + if (catalogName != null && !catalogService.getCatalogs().contains(catalogName)) { + tableName = catalogName + "." + tableName; + catalogName = null; + } + Table table = catalogService + .getStorageEngine(catalogName) + .getTable(tableName); table.getFieldTypes().forEach((k, v) -> curEnv.define(new Symbol(Namespace.FIELD_NAME, k), v)); // Put index name or its alias in index namespace on type environment so qualifier // can be removed when analyzing qualified name. The value (expr type) here doesn't matter. curEnv.define(new Symbol(Namespace.INDEX_NAME, node.getTableNameOrAlias()), STRUCT); - return new LogicalRelation(node.getTableName()); + return new LogicalRelation(tableName, table); + } + + private String getTableName(Relation node) { + return node.getTableName(); } + private String getCatalogName(Relation node) { + return node.getCatalogName(); + } + + @Override public LogicalPlan visitRelationSubquery(RelationSubquery node, AnalysisContext context) { LogicalPlan subquery = analyze(node.getChild().get(0), context); diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 510482c645..99d8aaa882 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -7,6 +7,7 @@ package org.opensearch.sql.ast.dsl; import java.util.Arrays; +import java.util.Collections; import java.util.List; import lombok.experimental.UtilityClass; import org.apache.commons.lang3.tuple.Pair; @@ -71,6 +72,10 @@ public UnresolvedPlan relation(String tableName) { return new Relation(qualifiedName(tableName)); } + public UnresolvedPlan relation(QualifiedName tableName) { + return new Relation(tableName); + } + public UnresolvedPlan relation(String tableName, String alias) { return new Relation(qualifiedName(tableName), alias); } @@ -114,7 +119,7 @@ public static UnresolvedPlan rename(UnresolvedPlan input, Map... maps) { /** * Initialize Values node by rows of literals. * @param values rows in which each row is a list of literal values - * @return Values node + * @return Values node */ @SafeVarargs public UnresolvedPlan values(List... values) { @@ -413,7 +418,8 @@ public static List defaultTopArgs() { } public static RareTopN rareTopN(UnresolvedPlan input, CommandType commandType, - List noOfResults, List groupList, Field... fields) { + List noOfResults, List groupList, + Field... fields) { return new RareTopN(input, commandType, noOfResults, Arrays.asList(fields), groupList) .attach(input); } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Relation.java b/core/src/main/java/org/opensearch/sql/ast/tree/Relation.java index 462639ddad..c85c928089 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Relation.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Relation.java @@ -15,6 +15,7 @@ import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; /** @@ -46,9 +47,40 @@ public Relation(UnresolvedExpression tableName, String alias) { /** * Get original table name. Unwrap and get name if table name expression * is actually an Alias. - * @return table name + * In case of federated queries we are assuming single table. + * + * @return table name */ public String getTableName() { + if (tableName.size() == 1 && ((QualifiedName) tableName.get(0)).first().isPresent()) { + return ((QualifiedName) tableName.get(0)).rest().toString(); + } + return tableName.stream() + .map(UnresolvedExpression::toString) + .collect(Collectors.joining(COMMA)); + } + + /** + * Get Catalog Name if present. Since in the initial phase we would be supporting single table + * federation queries, we are making an assumption of one table. + * + * @return catalog name + */ + public String getCatalogName() { + if (tableName.size() == 1) { + if (tableName.get(0) instanceof QualifiedName) { + return ((QualifiedName) tableName.get(0)).first().orElse(null); + } + } + return null; + } + + /** + * Return full qualified table name with catalog. + * + * @return fully qualified table name with catalog. + */ + public String getFullyQualifiedTableNameWithCatalog() { return tableName.stream() .map(UnresolvedExpression::toString) .collect(Collectors.joining(COMMA)); @@ -56,7 +88,8 @@ public String getTableName() { /** * Get original table name or its alias if present in Alias. - * @return table name or its alias + * + * @return table name or its alias */ public String getTableNameOrAlias() { return (alias == null) ? getTableName() : alias; diff --git a/core/src/main/java/org/opensearch/sql/catalog/CatalogService.java b/core/src/main/java/org/opensearch/sql/catalog/CatalogService.java new file mode 100644 index 0000000000..67512f98d7 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/catalog/CatalogService.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.catalog; + +import java.util.Set; +import org.opensearch.sql.storage.StorageEngine; + +/** + * Catalog Service defines api for + * providing and managing storage engines and execution engines + * for all the catalogs. + * The storage and execution indirectly make connections to the underlying datastore catalog. + */ +public interface CatalogService { + + StorageEngine getStorageEngine(String catalog); + + Set getCatalogs(); + + void registerOpenSearchStorageEngine(StorageEngine storageEngine); + +} diff --git a/core/src/main/java/org/opensearch/sql/catalog/model/AbstractAuthenticationData.java b/core/src/main/java/org/opensearch/sql/catalog/model/AbstractAuthenticationData.java new file mode 100644 index 0000000000..e6a0dfa538 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/catalog/model/AbstractAuthenticationData.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.catalog.model; + +import com.fasterxml.jackson.annotation.JsonFormat; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import lombok.Getter; +import lombok.Setter; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.EXISTING_PROPERTY, + property = "type", + defaultImpl = AbstractAuthenticationData.class, + visible = true) +@JsonSubTypes({ + @JsonSubTypes.Type(value = BasicAuthenticationData.class, name = "basicauth"), +}) +@Getter +@Setter +public abstract class AbstractAuthenticationData { + + @JsonFormat(with = JsonFormat.Feature.ACCEPT_CASE_INSENSITIVE_PROPERTIES) + private AuthenticationType type; + +} diff --git a/core/src/main/java/org/opensearch/sql/catalog/model/AuthenticationType.java b/core/src/main/java/org/opensearch/sql/catalog/model/AuthenticationType.java new file mode 100644 index 0000000000..3e602c7f62 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/catalog/model/AuthenticationType.java @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.catalog.model; + +public enum AuthenticationType { + BASICAUTH,NO +} diff --git a/core/src/main/java/org/opensearch/sql/catalog/model/BasicAuthenticationData.java b/core/src/main/java/org/opensearch/sql/catalog/model/BasicAuthenticationData.java new file mode 100644 index 0000000000..5ac8a72085 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/catalog/model/BasicAuthenticationData.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.catalog.model; + + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +@JsonIgnoreProperties(ignoreUnknown = true) +public class BasicAuthenticationData extends AbstractAuthenticationData { + + @JsonProperty(required = true) + private String username; + + @JsonProperty(required = true) + private String password; + +} diff --git a/core/src/main/java/org/opensearch/sql/catalog/model/CatalogMetadata.java b/core/src/main/java/org/opensearch/sql/catalog/model/CatalogMetadata.java new file mode 100644 index 0000000000..46c1894f6c --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/catalog/model/CatalogMetadata.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.catalog.model; + +import com.fasterxml.jackson.annotation.JsonFormat; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Getter; +import lombok.Setter; + +@JsonIgnoreProperties(ignoreUnknown = true) +@Getter +@Setter +public class CatalogMetadata { + + @JsonProperty(required = true) + private String name; + + @JsonProperty(required = true) + private String uri; + + @JsonProperty(required = true) + @JsonFormat(with = JsonFormat.Feature.ACCEPT_CASE_INSENSITIVE_PROPERTIES) + private ConnectorType connector; + + private AbstractAuthenticationData authentication; + +} diff --git a/core/src/main/java/org/opensearch/sql/catalog/model/ConnectorType.java b/core/src/main/java/org/opensearch/sql/catalog/model/ConnectorType.java new file mode 100644 index 0000000000..b84c68adbf --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/catalog/model/ConnectorType.java @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.catalog.model; + +public enum ConnectorType { + PROMETHEUS,OPENSEARCH +} diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index 20e91aa6cd..172e1ee778 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -27,9 +27,9 @@ import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionBuilder; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.expression.function.FunctionSignature; /** @@ -44,6 +44,7 @@ public class AggregatorFunction { /** * Register Aggregation Function. + * * @param repository {@link BuiltinFunctionRepository}. */ public static void register(BuiltinFunctionRepository repository) { @@ -58,9 +59,9 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(stddevPop()); } - private static FunctionResolver avg() { + private static DefaultFunctionResolver avg() { FunctionName functionName = BuiltinFunctionName.AVG.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), @@ -69,18 +70,18 @@ private static FunctionResolver avg() { ); } - private static FunctionResolver count() { + private static DefaultFunctionResolver count() { FunctionName functionName = BuiltinFunctionName.COUNT.getName(); - FunctionResolver functionResolver = new FunctionResolver(functionName, + DefaultFunctionResolver functionResolver = new DefaultFunctionResolver(functionName, ExprCoreType.coreTypes().stream().collect(Collectors.toMap( type -> new FunctionSignature(functionName, Collections.singletonList(type)), type -> arguments -> new CountAggregator(arguments, INTEGER)))); return functionResolver; } - private static FunctionResolver sum() { + private static DefaultFunctionResolver sum() { FunctionName functionName = BuiltinFunctionName.SUM.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), @@ -95,9 +96,9 @@ private static FunctionResolver sum() { ); } - private static FunctionResolver min() { + private static DefaultFunctionResolver min() { FunctionName functionName = BuiltinFunctionName.MIN.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), @@ -121,9 +122,9 @@ private static FunctionResolver min() { .build()); } - private static FunctionResolver max() { + private static DefaultFunctionResolver max() { FunctionName functionName = BuiltinFunctionName.MAX.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), @@ -148,9 +149,9 @@ private static FunctionResolver max() { ); } - private static FunctionResolver varSamp() { + private static DefaultFunctionResolver varSamp() { FunctionName functionName = BuiltinFunctionName.VARSAMP.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), @@ -159,9 +160,9 @@ private static FunctionResolver varSamp() { ); } - private static FunctionResolver varPop() { + private static DefaultFunctionResolver varPop() { FunctionName functionName = BuiltinFunctionName.VARPOP.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), @@ -170,9 +171,9 @@ private static FunctionResolver varPop() { ); } - private static FunctionResolver stddevSamp() { + private static DefaultFunctionResolver stddevSamp() { FunctionName functionName = BuiltinFunctionName.STDDEV_SAMP.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), @@ -181,9 +182,9 @@ private static FunctionResolver stddevSamp() { ); } - private static FunctionResolver stddevPop() { + private static DefaultFunctionResolver stddevPop() { FunctionName functionName = BuiltinFunctionName.STDDEV_POP.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), diff --git a/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java b/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java index 0fccacd136..469f7e2011 100644 --- a/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java @@ -37,6 +37,7 @@ import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.expression.function.FunctionResolver; @@ -94,7 +95,7 @@ public void register(BuiltinFunctionRepository repository) { * (STRING/DATETIME/TIMESTAMP, LONG) -> DATETIME */ - private FunctionResolver add_date(FunctionName functionName) { + private DefaultFunctionResolver add_date(FunctionName functionName) { return define(functionName, impl(nullMissingHandling(DateTimeFunction::exprAddDateInterval), DATETIME, STRING, INTERVAL), @@ -110,7 +111,7 @@ private FunctionResolver add_date(FunctionName functionName) { ); } - private FunctionResolver adddate() { + private DefaultFunctionResolver adddate() { return add_date(BuiltinFunctionName.ADDDATE.getName()); } @@ -119,7 +120,7 @@ private FunctionResolver adddate() { * Also to construct a date type. The supported signatures: * STRING/DATE/DATETIME/TIMESTAMP -> DATE */ - private FunctionResolver date() { + private DefaultFunctionResolver date() { return define(BuiltinFunctionName.DATE.getName(), impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, STRING), impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, DATE), @@ -127,7 +128,7 @@ private FunctionResolver date() { impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, TIMESTAMP)); } - private FunctionResolver date_add() { + private DefaultFunctionResolver date_add() { return add_date(BuiltinFunctionName.DATE_ADD.getName()); } @@ -138,7 +139,7 @@ private FunctionResolver date_add() { * (DATE, LONG) -> DATE * (STRING/DATETIME/TIMESTAMP, LONG) -> DATETIME */ - private FunctionResolver sub_date(FunctionName functionName) { + private DefaultFunctionResolver sub_date(FunctionName functionName) { return define(functionName, impl(nullMissingHandling(DateTimeFunction::exprSubDateInterval), DATETIME, STRING, INTERVAL), @@ -154,14 +155,14 @@ private FunctionResolver sub_date(FunctionName functionName) { ); } - private FunctionResolver date_sub() { + private DefaultFunctionResolver date_sub() { return sub_date(BuiltinFunctionName.DATE_SUB.getName()); } /** * DAY(STRING/DATE/DATETIME/TIMESTAMP). return the day of the month (1-31). */ - private FunctionResolver day() { + private DefaultFunctionResolver day() { return define(BuiltinFunctionName.DAY.getName(), impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATETIME), @@ -175,7 +176,7 @@ private FunctionResolver day() { * return the name of the weekday for date, including Monday, Tuesday, Wednesday, * Thursday, Friday, Saturday and Sunday. */ - private FunctionResolver dayName() { + private DefaultFunctionResolver dayName() { return define(BuiltinFunctionName.DAYNAME.getName(), impl(nullMissingHandling(DateTimeFunction::exprDayName), STRING, DATE), impl(nullMissingHandling(DateTimeFunction::exprDayName), STRING, DATETIME), @@ -187,7 +188,7 @@ private FunctionResolver dayName() { /** * DAYOFMONTH(STRING/DATE/DATETIME/TIMESTAMP). return the day of the month (1-31). */ - private FunctionResolver dayOfMonth() { + private DefaultFunctionResolver dayOfMonth() { return define(BuiltinFunctionName.DAYOFMONTH.getName(), impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATETIME), @@ -200,7 +201,7 @@ private FunctionResolver dayOfMonth() { * DAYOFWEEK(STRING/DATE/DATETIME/TIMESTAMP). * return the weekday index for date (1 = Sunday, 2 = Monday, …, 7 = Saturday). */ - private FunctionResolver dayOfWeek() { + private DefaultFunctionResolver dayOfWeek() { return define(BuiltinFunctionName.DAYOFWEEK.getName(), impl(nullMissingHandling(DateTimeFunction::exprDayOfWeek), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprDayOfWeek), INTEGER, DATETIME), @@ -213,7 +214,7 @@ private FunctionResolver dayOfWeek() { * DAYOFYEAR(STRING/DATE/DATETIME/TIMESTAMP). * return the day of the year for date (1-366). */ - private FunctionResolver dayOfYear() { + private DefaultFunctionResolver dayOfYear() { return define(BuiltinFunctionName.DAYOFYEAR.getName(), impl(nullMissingHandling(DateTimeFunction::exprDayOfYear), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprDayOfYear), INTEGER, DATETIME), @@ -225,7 +226,7 @@ private FunctionResolver dayOfYear() { /** * FROM_DAYS(LONG). return the date value given the day number N. */ - private FunctionResolver from_days() { + private DefaultFunctionResolver from_days() { return define(BuiltinFunctionName.FROM_DAYS.getName(), impl(nullMissingHandling(DateTimeFunction::exprFromDays), DATE, LONG)); } @@ -233,7 +234,7 @@ private FunctionResolver from_days() { /** * HOUR(STRING/TIME/DATETIME/TIMESTAMP). return the hour value for time. */ - private FunctionResolver hour() { + private DefaultFunctionResolver hour() { return define(BuiltinFunctionName.HOUR.getName(), impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, STRING), impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, TIME), @@ -255,7 +256,7 @@ private FunctionResolver maketime() { /** * MICROSECOND(STRING/TIME/DATETIME/TIMESTAMP). return the microsecond value for time. */ - private FunctionResolver microsecond() { + private DefaultFunctionResolver microsecond() { return define(BuiltinFunctionName.MICROSECOND.getName(), impl(nullMissingHandling(DateTimeFunction::exprMicrosecond), INTEGER, STRING), impl(nullMissingHandling(DateTimeFunction::exprMicrosecond), INTEGER, TIME), @@ -267,7 +268,7 @@ private FunctionResolver microsecond() { /** * MINUTE(STRING/TIME/DATETIME/TIMESTAMP). return the minute value for time. */ - private FunctionResolver minute() { + private DefaultFunctionResolver minute() { return define(BuiltinFunctionName.MINUTE.getName(), impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, STRING), impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, TIME), @@ -279,7 +280,7 @@ private FunctionResolver minute() { /** * MONTH(STRING/DATE/DATETIME/TIMESTAMP). return the month for date (1-12). */ - private FunctionResolver month() { + private DefaultFunctionResolver month() { return define(BuiltinFunctionName.MONTH.getName(), impl(nullMissingHandling(DateTimeFunction::exprMonth), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprMonth), INTEGER, DATETIME), @@ -291,7 +292,7 @@ private FunctionResolver month() { /** * MONTHNAME(STRING/DATE/DATETIME/TIMESTAMP). return the full name of the month for date. */ - private FunctionResolver monthName() { + private DefaultFunctionResolver monthName() { return define(BuiltinFunctionName.MONTHNAME.getName(), impl(nullMissingHandling(DateTimeFunction::exprMonthName), STRING, DATE), impl(nullMissingHandling(DateTimeFunction::exprMonthName), STRING, DATETIME), @@ -303,7 +304,7 @@ private FunctionResolver monthName() { /** * QUARTER(STRING/DATE/DATETIME/TIMESTAMP). return the month for date (1-4). */ - private FunctionResolver quarter() { + private DefaultFunctionResolver quarter() { return define(BuiltinFunctionName.QUARTER.getName(), impl(nullMissingHandling(DateTimeFunction::exprQuarter), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprQuarter), INTEGER, DATETIME), @@ -315,7 +316,7 @@ private FunctionResolver quarter() { /** * SECOND(STRING/TIME/DATETIME/TIMESTAMP). return the second value for time. */ - private FunctionResolver second() { + private DefaultFunctionResolver second() { return define(BuiltinFunctionName.SECOND.getName(), impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, STRING), impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, TIME), @@ -324,7 +325,7 @@ private FunctionResolver second() { ); } - private FunctionResolver subdate() { + private DefaultFunctionResolver subdate() { return sub_date(BuiltinFunctionName.SUBDATE.getName()); } @@ -333,7 +334,7 @@ private FunctionResolver subdate() { * Also to construct a time type. The supported signatures: * STRING/DATE/DATETIME/TIME/TIMESTAMP -> TIME */ - private FunctionResolver time() { + private DefaultFunctionResolver time() { return define(BuiltinFunctionName.TIME.getName(), impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, STRING), impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, DATE), @@ -345,7 +346,7 @@ private FunctionResolver time() { /** * TIME_TO_SEC(STRING/TIME/DATETIME/TIMESTAMP). return the time argument, converted to seconds. */ - private FunctionResolver time_to_sec() { + private DefaultFunctionResolver time_to_sec() { return define(BuiltinFunctionName.TIME_TO_SEC.getName(), impl(nullMissingHandling(DateTimeFunction::exprTimeToSec), LONG, STRING), impl(nullMissingHandling(DateTimeFunction::exprTimeToSec), LONG, TIME), @@ -359,7 +360,7 @@ private FunctionResolver time_to_sec() { * Also to construct a date type. The supported signatures: * STRING/DATE/DATETIME/TIMESTAMP -> DATE */ - private FunctionResolver timestamp() { + private DefaultFunctionResolver timestamp() { return define(BuiltinFunctionName.TIMESTAMP.getName(), impl(nullMissingHandling(DateTimeFunction::exprTimestamp), TIMESTAMP, STRING), impl(nullMissingHandling(DateTimeFunction::exprTimestamp), TIMESTAMP, DATE), @@ -370,7 +371,7 @@ private FunctionResolver timestamp() { /** * TO_DAYS(STRING/DATE/DATETIME/TIMESTAMP). return the day number of the given date. */ - private FunctionResolver to_days() { + private DefaultFunctionResolver to_days() { return define(BuiltinFunctionName.TO_DAYS.getName(), impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, STRING), impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, TIMESTAMP), @@ -381,7 +382,7 @@ private FunctionResolver to_days() { /** * WEEK(DATE[,mode]). return the week number for date. */ - private FunctionResolver week() { + private DefaultFunctionResolver week() { return define(BuiltinFunctionName.WEEK.getName(), impl(nullMissingHandling(DateTimeFunction::exprWeekWithoutMode), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprWeekWithoutMode), INTEGER, DATETIME), @@ -397,7 +398,7 @@ private FunctionResolver week() { /** * YEAR(STRING/DATE/DATETIME/TIMESTAMP). return the year for date (1000-9999). */ - private FunctionResolver year() { + private DefaultFunctionResolver year() { return define(BuiltinFunctionName.YEAR.getName(), impl(nullMissingHandling(DateTimeFunction::exprYear), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprYear), INTEGER, DATETIME), @@ -414,7 +415,7 @@ private FunctionResolver year() { * (DATETIME, STRING) -> STRING * (TIMESTAMP, STRING) -> STRING */ - private FunctionResolver date_format() { + private DefaultFunctionResolver date_format() { return define(BuiltinFunctionName.DATE_FORMAT.getName(), impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedDate), STRING, STRING, STRING), @@ -711,6 +712,7 @@ private ExprValue exprToDays(ExprValue date) { /** * Week for date implementation for ExprValue. + * * @param date ExprValue of Date/Datetime/Timestamp/String type. * @param mode ExprValue of Integer type. */ @@ -722,6 +724,7 @@ private ExprValue exprWeek(ExprValue date, ExprValue mode) { /** * Week for date implementation for ExprValue. * When mode is not specified default value mode 0 is used for default_week_format. + * * @param date ExprValue of Date/Datetime/Timestamp/String type. * @return ExprValue. */ diff --git a/core/src/main/java/org/opensearch/sql/expression/datetime/IntervalClause.java b/core/src/main/java/org/opensearch/sql/expression/datetime/IntervalClause.java index f4746ebe7a..c5076431cc 100644 --- a/core/src/main/java/org/opensearch/sql/expression/datetime/IntervalClause.java +++ b/core/src/main/java/org/opensearch/sql/expression/datetime/IntervalClause.java @@ -25,7 +25,7 @@ import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; -import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; @UtilityClass public class IntervalClause { @@ -44,7 +44,7 @@ public void register(BuiltinFunctionRepository repository) { repository.register(interval()); } - private FunctionResolver interval() { + private DefaultFunctionResolver interval() { return define(BuiltinFunctionName.INTERVAL.getName(), impl(nullMissingHandling(IntervalClause::interval), INTERVAL, INTEGER, STRING), impl(nullMissingHandling(IntervalClause::interval), INTERVAL, LONG, STRING)); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index 1f4c885723..545e710f65 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -29,9 +29,9 @@ public class BuiltinFunctionRepository { private final Map functionResolverMap; /** - * Register {@link FunctionResolver} to the Builtin Function Repository. + * Register {@link DefaultFunctionResolver} to the Builtin Function Repository. * - * @param resolver {@link FunctionResolver} to be registered + * @param resolver {@link DefaultFunctionResolver} to be registered */ public void register(FunctionResolver resolver) { functionResolverMap.put(resolver.getFunctionName(), resolver); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java new file mode 100644 index 0000000000..7081179162 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.util.AbstractMap; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.Builder; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Singular; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.exception.ExpressionEvaluationException; + +/** + * The Function Resolver hold the overload {@link FunctionBuilder} implementation. + * is composed by {@link FunctionName} which identified the function name + * and a map of {@link FunctionSignature} and {@link FunctionBuilder} + * to represent the overloaded implementation + */ +@Builder +@RequiredArgsConstructor +public class DefaultFunctionResolver implements FunctionResolver { + @Getter + private final FunctionName functionName; + @Singular("functionBundle") + private final Map functionBundle; + + /** + * Resolve the {@link FunctionBuilder} by using input {@link FunctionSignature}. + * If the {@link FunctionBuilder} exactly match the input {@link FunctionSignature}, return it. + * If applying the widening rule, found the most match one, return it. + * If nothing found, throw {@link ExpressionEvaluationException} + * + * @return function signature and its builder + */ + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + PriorityQueue> functionMatchQueue = new PriorityQueue<>( + Map.Entry.comparingByKey()); + + for (FunctionSignature functionSignature : functionBundle.keySet()) { + functionMatchQueue.add( + new AbstractMap.SimpleEntry<>(unresolvedSignature.match(functionSignature), + functionSignature)); + } + Map.Entry bestMatchEntry = functionMatchQueue.peek(); + if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())) { + throw new ExpressionEvaluationException( + String.format("%s function expected %s, but get %s", functionName, + formatFunctions(functionBundle.keySet()), + unresolvedSignature.formatTypes() + )); + } else { + FunctionSignature resolvedSignature = bestMatchEntry.getValue(); + return Pair.of(resolvedSignature, functionBundle.get(resolvedSignature)); + } + } + + private String formatFunctions(Set functionSignatures) { + return functionSignatures.stream().map(FunctionSignature::formatTypes) + .collect(Collectors.joining(",", "{", "}")); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java index dcd65d6b87..1fad333ead 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java @@ -32,9 +32,9 @@ public class FunctionDSL { * @param functions a list of function implementation. * @return FunctionResolver. */ - public static FunctionResolver define(FunctionName functionName, - SerializableFunction>... functions) { + public static DefaultFunctionResolver define(FunctionName functionName, + SerializableFunction>... functions) { return define(functionName, Arrays.asList(functions)); } @@ -45,11 +45,11 @@ public static FunctionResolver define(FunctionName functionName, * @param functions a list of function implementation. * @return FunctionResolver. */ - public static FunctionResolver define(FunctionName functionName, - List>> functions) { + public static DefaultFunctionResolver define(FunctionName functionName, List< + SerializableFunction>> functions) { - FunctionResolver.FunctionResolverBuilder builder = FunctionResolver.builder(); + DefaultFunctionResolver.DefaultFunctionResolverBuilder builder + = DefaultFunctionResolver.builder(); builder.functionName(functionName); for (Function> func : functions) { Pair functionBuilder = func.apply(functionName); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java index 06d0fb673c..1635b6f846 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java @@ -5,64 +5,14 @@ package org.opensearch.sql.expression.function; -import java.util.AbstractMap; -import java.util.Map; -import java.util.PriorityQueue; -import java.util.Set; -import java.util.stream.Collectors; -import lombok.Builder; -import lombok.Getter; -import lombok.RequiredArgsConstructor; -import lombok.Singular; import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.exception.ExpressionEvaluationException; /** - * The Function Resolver hold the overload {@link FunctionBuilder} implementation. - * is composed by {@link FunctionName} which identified the function name - * and a map of {@link FunctionSignature} and {@link FunctionBuilder} - * to represent the overloaded implementation + * An interface for any class that can provide a {@ref FunctionBuilder} + * given a {@ref FunctionSignature}. */ -@Builder -@RequiredArgsConstructor -public class FunctionResolver { - @Getter - private final FunctionName functionName; - @Singular("functionBundle") - private final Map functionBundle; +public interface FunctionResolver { + Pair resolve(FunctionSignature unresolvedSignature); - /** - * Resolve the {@link FunctionBuilder} by using input {@link FunctionSignature}. - * If the {@link FunctionBuilder} exactly match the input {@link FunctionSignature}, return it. - * If applying the widening rule, found the most match one, return it. - * If nothing found, throw {@link ExpressionEvaluationException} - * - * @return function signature and its builder - */ - public Pair resolve(FunctionSignature unresolvedSignature) { - PriorityQueue> functionMatchQueue = new PriorityQueue<>( - Map.Entry.comparingByKey()); - - for (FunctionSignature functionSignature : functionBundle.keySet()) { - functionMatchQueue.add( - new AbstractMap.SimpleEntry<>(unresolvedSignature.match(functionSignature), - functionSignature)); - } - Map.Entry bestMatchEntry = functionMatchQueue.peek(); - if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())) { - throw new ExpressionEvaluationException( - String.format("%s function expected %s, but get %s", functionName, - formatFunctions(functionBundle.keySet()), - unresolvedSignature.formatTypes() - )); - } else { - FunctionSignature resolvedSignature = bestMatchEntry.getValue(); - return Pair.of(resolvedSignature, functionBundle.get(resolvedSignature)); - } - } - - private String formatFunctions(Set functionSignatures) { - return functionSignatures.stream().map(FunctionSignature::formatTypes) - .collect(Collectors.joining(",", "{", "}")); - } + FunctionName getFunctionName(); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java index c3e5cc5594..bb3eb7008b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java @@ -9,13 +9,9 @@ import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; import com.google.common.collect.ImmutableMap; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import lombok.experimental.UtilityClass; -import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; @@ -27,16 +23,6 @@ @UtilityClass public class OpenSearchFunctions { - - public static final int MATCH_MAX_NUM_PARAMETERS = 14; - public static final int MATCH_BOOL_PREFIX_MAX_NUM_PARAMETERS = 9; - public static final int MATCH_PHRASE_MAX_NUM_PARAMETERS = 5; - public static final int MIN_NUM_PARAMETERS = 2; - public static final int MULTI_MATCH_MAX_NUM_PARAMETERS = 17; - public static final int SIMPLE_QUERY_STRING_MAX_NUM_PARAMETERS = 14; - public static final int QUERY_STRING_MAX_NUM_PARAMETERS = 25; - public static final int MATCH_PHRASE_PREFIX_MAX_NUM_PARAMETERS = 7; - /** * Add functions specific to OpenSearch to repository. */ @@ -58,67 +44,54 @@ private static FunctionResolver highlight() { FunctionName functionName = BuiltinFunctionName.HIGHLIGHT.getName(); FunctionSignature functionSignature = new FunctionSignature(functionName, List.of(STRING)); FunctionBuilder functionBuilder = arguments -> new HighlightExpression(arguments.get(0)); - return new FunctionResolver(functionName, ImmutableMap.of(functionSignature, functionBuilder)); + return new DefaultFunctionResolver(functionName, + ImmutableMap.of(functionSignature, functionBuilder)); } private static FunctionResolver match_bool_prefix() { FunctionName name = BuiltinFunctionName.MATCH_BOOL_PREFIX.getName(); - return getRelevanceFunctionResolver(name, MATCH_BOOL_PREFIX_MAX_NUM_PARAMETERS, STRING); + return new RelevanceFunctionResolver(name, STRING); } private static FunctionResolver match() { FunctionName funcName = BuiltinFunctionName.MATCH.getName(); - return getRelevanceFunctionResolver(funcName, MATCH_MAX_NUM_PARAMETERS, STRING); + return new RelevanceFunctionResolver(funcName, STRING); } private static FunctionResolver match_phrase_prefix() { FunctionName funcName = BuiltinFunctionName.MATCH_PHRASE_PREFIX.getName(); - return getRelevanceFunctionResolver(funcName, MATCH_PHRASE_PREFIX_MAX_NUM_PARAMETERS, STRING); + return new RelevanceFunctionResolver(funcName, STRING); } private static FunctionResolver match_phrase(BuiltinFunctionName matchPhrase) { FunctionName funcName = matchPhrase.getName(); - return getRelevanceFunctionResolver(funcName, MATCH_PHRASE_MAX_NUM_PARAMETERS, STRING); + return new RelevanceFunctionResolver(funcName, STRING); } private static FunctionResolver multi_match() { FunctionName funcName = BuiltinFunctionName.MULTI_MATCH.getName(); - return getRelevanceFunctionResolver(funcName, MULTI_MATCH_MAX_NUM_PARAMETERS, STRUCT); + return new RelevanceFunctionResolver(funcName, STRUCT); } private static FunctionResolver simple_query_string() { FunctionName funcName = BuiltinFunctionName.SIMPLE_QUERY_STRING.getName(); - return getRelevanceFunctionResolver(funcName, SIMPLE_QUERY_STRING_MAX_NUM_PARAMETERS, STRUCT); + return new RelevanceFunctionResolver(funcName, STRUCT); } private static FunctionResolver query_string() { FunctionName funcName = BuiltinFunctionName.QUERY_STRING.getName(); - return getRelevanceFunctionResolver(funcName, QUERY_STRING_MAX_NUM_PARAMETERS, STRUCT); - } - - private static FunctionResolver getRelevanceFunctionResolver( - FunctionName funcName, int maxNumParameters, ExprCoreType firstArgType) { - return new FunctionResolver(funcName, - getRelevanceFunctionSignatureMap(funcName, maxNumParameters, firstArgType)); - } - - private static Map getRelevanceFunctionSignatureMap( - FunctionName funcName, int maxNumParameters, ExprCoreType firstArgType) { - FunctionBuilder buildFunction = args -> new OpenSearchFunction(funcName, args); - var signatureMapBuilder = ImmutableMap.builder(); - for (int numParameters = MIN_NUM_PARAMETERS; - numParameters <= maxNumParameters; numParameters++) { - List args = new ArrayList<>(Collections.nCopies(numParameters - 1, STRING)); - args.add(0, firstArgType); - signatureMapBuilder.put(new FunctionSignature(funcName, args), buildFunction); - } - return signatureMapBuilder.build(); + return new RelevanceFunctionResolver(funcName, STRUCT); } - private static class OpenSearchFunction extends FunctionExpression { + public static class OpenSearchFunction extends FunctionExpression { private final FunctionName functionName; private final List arguments; + /** + * Required argument constructor. + * @param functionName name of the function + * @param arguments a list of expressions + */ public OpenSearchFunction(FunctionName functionName, List arguments) { super(functionName, arguments); this.functionName = functionName; diff --git a/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java new file mode 100644 index 0000000000..e781db8c84 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.util.List; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.SemanticCheckException; + +@RequiredArgsConstructor +public class RelevanceFunctionResolver + implements FunctionResolver { + + @Getter + private final FunctionName functionName; + + @Getter + private final ExprType declaredFirstParamType; + + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + if (!unresolvedSignature.getFunctionName().equals(functionName)) { + throw new SemanticCheckException(String.format("Expected '%s' but got '%s'", + functionName.getFunctionName(), unresolvedSignature.getFunctionName().getFunctionName())); + } + List paramTypes = unresolvedSignature.getParamTypeList(); + ExprType providedFirstParamType = paramTypes.get(0); + + // Check if the first parameter is of the specified type. + if (!declaredFirstParamType.equals(providedFirstParamType)) { + throw new SemanticCheckException( + getWrongParameterErrorMessage(0, providedFirstParamType, declaredFirstParamType)); + } + + // Check if all but the first parameter are of type STRING. + for (int i = 1; i < paramTypes.size(); i++) { + ExprType paramType = paramTypes.get(i); + if (!ExprCoreType.STRING.equals(paramType)) { + throw new SemanticCheckException( + getWrongParameterErrorMessage(i, paramType, ExprCoreType.STRING)); + } + } + + FunctionBuilder buildFunction = + args -> new OpenSearchFunctions.OpenSearchFunction(functionName, args); + return Pair.of(unresolvedSignature, buildFunction); + } + + /** Returns a helpful error message when expected parameter type does not match the + * specified parameter type. + * + * @param i 0-based index of the parameter in a function signature. + * @param paramType the type of the ith parameter at run-time. + * @param expectedType the expected type of the ith parameter + * @return A user-friendly error message that informs of the type difference. + */ + private String getWrongParameterErrorMessage(int i, ExprType paramType, ExprType expectedType) { + return String.format("Expected type %s instead of %s for parameter #%d", + expectedType.typeName(), paramType.typeName(), i + 1); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java index 81356e789b..c4b106bbf4 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java @@ -23,8 +23,8 @@ import org.opensearch.sql.data.model.ExprShortValue; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionDSL; -import org.opensearch.sql.expression.function.FunctionResolver; /** * The definition of arithmetic function @@ -49,7 +49,7 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(modules()); } - private static FunctionResolver add() { + private static DefaultFunctionResolver add() { return FunctionDSL.define(BuiltinFunctionName.ADD.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( @@ -79,7 +79,7 @@ private static FunctionResolver add() { ); } - private static FunctionResolver subtract() { + private static DefaultFunctionResolver subtract() { return FunctionDSL.define(BuiltinFunctionName.SUBTRACT.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( @@ -109,7 +109,7 @@ private static FunctionResolver subtract() { ); } - private static FunctionResolver multiply() { + private static DefaultFunctionResolver multiply() { return FunctionDSL.define(BuiltinFunctionName.MULTIPLY.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( @@ -139,7 +139,7 @@ private static FunctionResolver multiply() { ); } - private static FunctionResolver divide() { + private static DefaultFunctionResolver divide() { return FunctionDSL.define(BuiltinFunctionName.DIVIDE.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( @@ -174,7 +174,7 @@ private static FunctionResolver divide() { } - private static FunctionResolver modules() { + private static DefaultFunctionResolver modules() { return FunctionDSL.define(BuiltinFunctionName.MODULES.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java index d310b42904..0ce48af48c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java @@ -36,10 +36,10 @@ import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionBuilder; import org.opensearch.sql.expression.function.FunctionDSL; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.expression.function.FunctionSignature; import org.opensearch.sql.expression.function.SerializableFunction; @@ -88,7 +88,7 @@ public static void register(BuiltinFunctionRepository repository) { * Definition of abs() function. The supported signature of abs() function are INT -> INT LONG -> * LONG FLOAT -> FLOAT DOUBLE -> DOUBLE */ - private static FunctionResolver abs() { + private static DefaultFunctionResolver abs() { return FunctionDSL.define(BuiltinFunctionName.ABS.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling(v -> new ExprByteValue(Math.abs(v.byteValue()))), @@ -115,7 +115,7 @@ private static FunctionResolver abs() { * Definition of ceil(x)/ceiling(x) function. Calculate the next highest integer that x rounds up * to The supported signature of ceil/ceiling function is DOUBLE -> INTEGER */ - private static FunctionResolver ceil() { + private static DefaultFunctionResolver ceil() { return FunctionDSL.define(BuiltinFunctionName.CEIL.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling(v -> new ExprIntegerValue(Math.ceil(v.doubleValue()))), @@ -123,7 +123,7 @@ private static FunctionResolver ceil() { ); } - private static FunctionResolver ceiling() { + private static DefaultFunctionResolver ceiling() { return FunctionDSL.define(BuiltinFunctionName.CEILING.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling(v -> new ExprIntegerValue(Math.ceil(v.doubleValue()))), @@ -138,7 +138,7 @@ private static FunctionResolver ceiling() { * (STRING, INTEGER, INTEGER) -> STRING * (INTEGER, INTEGER, INTEGER) -> STRING */ - private static FunctionResolver conv() { + private static DefaultFunctionResolver conv() { return FunctionDSL.define(BuiltinFunctionName.CONV.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling((x, a, b) -> new ExprStringValue( @@ -161,7 +161,7 @@ private static FunctionResolver conv() { * The supported signature of crc32 function is * STRING -> LONG */ - private static FunctionResolver crc32() { + private static DefaultFunctionResolver crc32() { return FunctionDSL.define(BuiltinFunctionName.CRC32.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling(v -> { @@ -178,7 +178,7 @@ private static FunctionResolver crc32() { * Get the Euler's number. * () -> DOUBLE */ - private static FunctionResolver euler() { + private static DefaultFunctionResolver euler() { return FunctionDSL.define(BuiltinFunctionName.E.getName(), FunctionDSL.impl(() -> new ExprDoubleValue(Math.E), DOUBLE) ); @@ -188,7 +188,7 @@ private static FunctionResolver euler() { * Definition of exp(x) function. Calculate exponent function e to the x The supported signature * of exp function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver exp() { + private static DefaultFunctionResolver exp() { return FunctionDSL.define(BuiltinFunctionName.EXP.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -200,7 +200,7 @@ private static FunctionResolver exp() { * Definition of floor(x) function. Calculate the next nearest whole integer that x rounds down to * The supported signature of floor function is DOUBLE -> INTEGER */ - private static FunctionResolver floor() { + private static DefaultFunctionResolver floor() { return FunctionDSL.define(BuiltinFunctionName.FLOOR.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling(v -> new ExprIntegerValue(Math.floor(v.doubleValue()))), @@ -212,7 +212,7 @@ private static FunctionResolver floor() { * Definition of ln(x) function. Calculate the natural logarithm of x The supported signature of * ln function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver ln() { + private static DefaultFunctionResolver ln() { return FunctionDSL.define(BuiltinFunctionName.LN.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -225,7 +225,7 @@ private static FunctionResolver ln() { * supported signature of log function is (b: INTEGER/LONG/FLOAT/DOUBLE, x: * INTEGER/LONG/FLOAT/DOUBLE]) -> DOUBLE */ - private static FunctionResolver log() { + private static DefaultFunctionResolver log() { ImmutableList.Builder>> builder = new ImmutableList.Builder<>(); @@ -253,7 +253,7 @@ private static FunctionResolver log() { * Definition of log10(x) function. Calculate base-10 logarithm of x The supported signature of * log function is SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver log10() { + private static DefaultFunctionResolver log10() { return FunctionDSL.define(BuiltinFunctionName.LOG10.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -265,7 +265,7 @@ private static FunctionResolver log10() { * Definition of log2(x) function. Calculate base-2 logarithm of x The supported signature of log * function is SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver log2() { + private static DefaultFunctionResolver log2() { return FunctionDSL.define(BuiltinFunctionName.LOG2.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -280,7 +280,7 @@ private static FunctionResolver log2() { * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) * -> wider type between types of x and y */ - private static FunctionResolver mod() { + private static DefaultFunctionResolver mod() { return FunctionDSL.define(BuiltinFunctionName.MOD.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( @@ -321,7 +321,7 @@ private static FunctionResolver mod() { * Get the value of pi. * () -> DOUBLE */ - private static FunctionResolver pi() { + private static DefaultFunctionResolver pi() { return FunctionDSL.define(BuiltinFunctionName.PI.getName(), FunctionDSL.impl(() -> new ExprDoubleValue(Math.PI), DOUBLE) ); @@ -336,11 +336,11 @@ private static FunctionResolver pi() { * (FLOAT, FLOAT) -> DOUBLE * (DOUBLE, DOUBLE) -> DOUBLE */ - private static FunctionResolver pow() { + private static DefaultFunctionResolver pow() { return FunctionDSL.define(BuiltinFunctionName.POW.getName(), powerFunctionImpl()); } - private static FunctionResolver power() { + private static DefaultFunctionResolver power() { return FunctionDSL.define(BuiltinFunctionName.POWER.getName(), powerFunctionImpl()); } @@ -378,7 +378,7 @@ FunctionBuilder>>> powerFunctionImpl() { * The supported signature of rand function is * ([INTEGER]) -> FLOAT */ - private static FunctionResolver rand() { + private static DefaultFunctionResolver rand() { return FunctionDSL.define(BuiltinFunctionName.RAND.getName(), FunctionDSL.impl(() -> new ExprFloatValue(new Random().nextFloat()), FLOAT), FunctionDSL.impl( @@ -396,7 +396,7 @@ private static FunctionResolver rand() { * (x: FLOAT [, y: INTEGER]) -> FLOAT * (x: DOUBLE [, y: INTEGER]) -> DOUBLE */ - private static FunctionResolver round() { + private static DefaultFunctionResolver round() { return FunctionDSL.define(BuiltinFunctionName.ROUND.getName(), // rand(x) FunctionDSL.impl( @@ -448,7 +448,7 @@ private static FunctionResolver round() { * The supported signature is * SHORT/INTEGER/LONG/FLOAT/DOUBLE -> INTEGER */ - private static FunctionResolver sign() { + private static DefaultFunctionResolver sign() { return FunctionDSL.define(BuiltinFunctionName.SIGN.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -462,7 +462,7 @@ private static FunctionResolver sign() { * The supported signature is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver sqrt() { + private static DefaultFunctionResolver sqrt() { return FunctionDSL.define(BuiltinFunctionName.SQRT.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -480,7 +480,7 @@ private static FunctionResolver sqrt() { * (x: FLOAT, y: INTEGER) -> DOUBLE * (x: DOUBLE, y: INTEGER) -> DOUBLE */ - private static FunctionResolver truncate() { + private static DefaultFunctionResolver truncate() { return FunctionDSL.define(BuiltinFunctionName.TRUNCATE.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( @@ -515,7 +515,7 @@ private static FunctionResolver truncate() { * The supported signature of acos function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver acos() { + private static DefaultFunctionResolver acos() { return FunctionDSL.define(BuiltinFunctionName.ACOS.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -531,7 +531,7 @@ private static FunctionResolver acos() { * The supported signature of asin function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver asin() { + private static DefaultFunctionResolver asin() { return FunctionDSL.define(BuiltinFunctionName.ASIN.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -548,7 +548,7 @@ private static FunctionResolver asin() { * The supported signature of atan function is * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) -> DOUBLE */ - private static FunctionResolver atan() { + private static DefaultFunctionResolver atan() { ImmutableList.Builder>> builder = new ImmutableList.Builder<>(); @@ -571,7 +571,7 @@ private static FunctionResolver atan() { * The supported signature of atan2 function is * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) -> DOUBLE */ - private static FunctionResolver atan2() { + private static DefaultFunctionResolver atan2() { ImmutableList.Builder>> builder = new ImmutableList.Builder<>(); @@ -590,7 +590,7 @@ private static FunctionResolver atan2() { * The supported signature of cos function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver cos() { + private static DefaultFunctionResolver cos() { return FunctionDSL.define(BuiltinFunctionName.COS.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -604,7 +604,7 @@ private static FunctionResolver cos() { * The supported signature of cot function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver cot() { + private static DefaultFunctionResolver cot() { return FunctionDSL.define(BuiltinFunctionName.COT.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -625,7 +625,7 @@ private static FunctionResolver cot() { * The supported signature of degrees function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver degrees() { + private static DefaultFunctionResolver degrees() { return FunctionDSL.define(BuiltinFunctionName.DEGREES.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -639,7 +639,7 @@ private static FunctionResolver degrees() { * The supported signature of radians function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver radians() { + private static DefaultFunctionResolver radians() { return FunctionDSL.define(BuiltinFunctionName.RADIANS.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -653,7 +653,7 @@ private static FunctionResolver radians() { * The supported signature of sin function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver sin() { + private static DefaultFunctionResolver sin() { return FunctionDSL.define(BuiltinFunctionName.SIN.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -667,7 +667,7 @@ private static FunctionResolver sin() { * The supported signature of tan function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver tan() { + private static DefaultFunctionResolver tan() { return FunctionDSL.define(BuiltinFunctionName.TAN.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java b/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java index 171563e0a3..23508406ac 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java @@ -39,8 +39,8 @@ import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionDSL; -import org.opensearch.sql.expression.function.FunctionResolver; @UtilityClass public class TypeCastOperator { @@ -63,7 +63,7 @@ public static void register(BuiltinFunctionRepository repository) { } - private static FunctionResolver castToString() { + private static DefaultFunctionResolver castToString() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_STRING.getName(), Stream.concat( Arrays.asList(BYTE, SHORT, INTEGER, LONG, FLOAT, DOUBLE, BOOLEAN, TIME, DATE, @@ -76,7 +76,7 @@ private static FunctionResolver castToString() { ); } - private static FunctionResolver castToByte() { + private static DefaultFunctionResolver castToByte() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_BYTE.getName(), impl(nullMissingHandling( (v) -> new ExprByteValue(Byte.valueOf(v.stringValue()))), BYTE, STRING), @@ -87,7 +87,7 @@ private static FunctionResolver castToByte() { ); } - private static FunctionResolver castToShort() { + private static DefaultFunctionResolver castToShort() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_SHORT.getName(), impl(nullMissingHandling( (v) -> new ExprShortValue(Short.valueOf(v.stringValue()))), SHORT, STRING), @@ -98,7 +98,7 @@ private static FunctionResolver castToShort() { ); } - private static FunctionResolver castToInt() { + private static DefaultFunctionResolver castToInt() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_INT.getName(), impl(nullMissingHandling( (v) -> new ExprIntegerValue(Integer.valueOf(v.stringValue()))), INTEGER, STRING), @@ -109,7 +109,7 @@ private static FunctionResolver castToInt() { ); } - private static FunctionResolver castToLong() { + private static DefaultFunctionResolver castToLong() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_LONG.getName(), impl(nullMissingHandling( (v) -> new ExprLongValue(Long.valueOf(v.stringValue()))), LONG, STRING), @@ -120,7 +120,7 @@ private static FunctionResolver castToLong() { ); } - private static FunctionResolver castToFloat() { + private static DefaultFunctionResolver castToFloat() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_FLOAT.getName(), impl(nullMissingHandling( (v) -> new ExprFloatValue(Float.valueOf(v.stringValue()))), FLOAT, STRING), @@ -131,7 +131,7 @@ private static FunctionResolver castToFloat() { ); } - private static FunctionResolver castToDouble() { + private static DefaultFunctionResolver castToDouble() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_DOUBLE.getName(), impl(nullMissingHandling( (v) -> new ExprDoubleValue(Double.valueOf(v.stringValue()))), DOUBLE, STRING), @@ -142,7 +142,7 @@ private static FunctionResolver castToDouble() { ); } - private static FunctionResolver castToBoolean() { + private static DefaultFunctionResolver castToBoolean() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_BOOLEAN.getName(), impl(nullMissingHandling( (v) -> ExprBooleanValue.of(Boolean.valueOf(v.stringValue()))), BOOLEAN, STRING), @@ -152,7 +152,7 @@ private static FunctionResolver castToBoolean() { ); } - private static FunctionResolver castToDate() { + private static DefaultFunctionResolver castToDate() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_DATE.getName(), impl(nullMissingHandling( (v) -> new ExprDateValue(v.stringValue())), DATE, STRING), @@ -164,7 +164,7 @@ private static FunctionResolver castToDate() { ); } - private static FunctionResolver castToTime() { + private static DefaultFunctionResolver castToTime() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_TIME.getName(), impl(nullMissingHandling( (v) -> new ExprTimeValue(v.stringValue())), TIME, STRING), @@ -176,7 +176,7 @@ private static FunctionResolver castToTime() { ); } - private static FunctionResolver castToTimestamp() { + private static DefaultFunctionResolver castToTimestamp() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_TIMESTAMP.getName(), impl(nullMissingHandling( (v) -> new ExprTimestampValue(v.stringValue())), TIMESTAMP, STRING), @@ -186,7 +186,7 @@ private static FunctionResolver castToTimestamp() { ); } - private static FunctionResolver castToDatetime() { + private static DefaultFunctionResolver castToDatetime() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_DATETIME.getName(), impl(nullMissingHandling( (v) -> new ExprDatetimeValue(v.stringValue())), DATETIME, STRING), diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java index 4caed12cae..99399249c2 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java @@ -23,8 +23,8 @@ import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionDSL; -import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.utils.OperatorUtils; /** @@ -140,25 +140,25 @@ public static void register(BuiltinFunctionRepository repository) { .put(LITERAL_MISSING, LITERAL_MISSING, LITERAL_MISSING) .build(); - private static FunctionResolver and() { + private static DefaultFunctionResolver and() { return FunctionDSL.define(BuiltinFunctionName.AND.getName(), FunctionDSL .impl((v1, v2) -> lookupTableFunction(v1, v2, andTable), BOOLEAN, BOOLEAN, BOOLEAN)); } - private static FunctionResolver or() { + private static DefaultFunctionResolver or() { return FunctionDSL.define(BuiltinFunctionName.OR.getName(), FunctionDSL .impl((v1, v2) -> lookupTableFunction(v1, v2, orTable), BOOLEAN, BOOLEAN, BOOLEAN)); } - private static FunctionResolver xor() { + private static DefaultFunctionResolver xor() { return FunctionDSL.define(BuiltinFunctionName.XOR.getName(), FunctionDSL .impl((v1, v2) -> lookupTableFunction(v1, v2, xorTable), BOOLEAN, BOOLEAN, BOOLEAN)); } - private static FunctionResolver equal() { + private static DefaultFunctionResolver equal() { return FunctionDSL.define(BuiltinFunctionName.EQUAL.getName(), ExprCoreType.coreTypes().stream() .map(type -> FunctionDSL.impl( @@ -168,7 +168,7 @@ private static FunctionResolver equal() { Collectors.toList())); } - private static FunctionResolver notEqual() { + private static DefaultFunctionResolver notEqual() { return FunctionDSL .define(BuiltinFunctionName.NOTEQUAL.getName(), ExprCoreType.coreTypes().stream() .map(type -> FunctionDSL @@ -182,7 +182,7 @@ private static FunctionResolver notEqual() { Collectors.toList())); } - private static FunctionResolver less() { + private static DefaultFunctionResolver less() { return FunctionDSL .define(BuiltinFunctionName.LESS.getName(), ExprCoreType.coreTypes().stream() .map(type -> FunctionDSL @@ -194,7 +194,7 @@ private static FunctionResolver less() { Collectors.toList())); } - private static FunctionResolver lte() { + private static DefaultFunctionResolver lte() { return FunctionDSL .define(BuiltinFunctionName.LTE.getName(), ExprCoreType.coreTypes().stream() .map(type -> FunctionDSL @@ -208,7 +208,7 @@ private static FunctionResolver lte() { Collectors.toList())); } - private static FunctionResolver greater() { + private static DefaultFunctionResolver greater() { return FunctionDSL .define(BuiltinFunctionName.GREATER.getName(), ExprCoreType.coreTypes().stream() .map(type -> FunctionDSL @@ -219,7 +219,7 @@ private static FunctionResolver greater() { Collectors.toList())); } - private static FunctionResolver gte() { + private static DefaultFunctionResolver gte() { return FunctionDSL .define(BuiltinFunctionName.GTE.getName(), ExprCoreType.coreTypes().stream() .map(type -> FunctionDSL @@ -232,19 +232,19 @@ private static FunctionResolver gte() { Collectors.toList())); } - private static FunctionResolver like() { + private static DefaultFunctionResolver like() { return FunctionDSL.define(BuiltinFunctionName.LIKE.getName(), FunctionDSL .impl(FunctionDSL.nullMissingHandling(OperatorUtils::matches), BOOLEAN, STRING, STRING)); } - private static FunctionResolver regexp() { + private static DefaultFunctionResolver regexp() { return FunctionDSL.define(BuiltinFunctionName.REGEXP.getName(), FunctionDSL .impl(FunctionDSL.nullMissingHandling(OperatorUtils::matchesRegexp), INTEGER, STRING, STRING)); } - private static FunctionResolver notLike() { + private static DefaultFunctionResolver notLike() { return FunctionDSL.define(BuiltinFunctionName.NOT_LIKE.getName(), FunctionDSL .impl(FunctionDSL.nullMissingHandling( (v1, v2) -> UnaryPredicateOperator.not(OperatorUtils.matches(v1, v2))), diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/UnaryPredicateOperator.java b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/UnaryPredicateOperator.java index ca228a6a7e..7d79d9d923 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/UnaryPredicateOperator.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/UnaryPredicateOperator.java @@ -20,10 +20,10 @@ import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionBuilder; import org.opensearch.sql.expression.function.FunctionDSL; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.expression.function.FunctionSignature; import org.opensearch.sql.expression.function.SerializableFunction; @@ -46,7 +46,7 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(ifFunction()); } - private static FunctionResolver not() { + private static DefaultFunctionResolver not() { return FunctionDSL.define(BuiltinFunctionName.NOT.getName(), FunctionDSL .impl(UnaryPredicateOperator::not, BOOLEAN, BOOLEAN)); } @@ -67,7 +67,7 @@ public ExprValue not(ExprValue v) { } } - private static FunctionResolver isNull(BuiltinFunctionName funcName) { + private static DefaultFunctionResolver isNull(BuiltinFunctionName funcName) { return FunctionDSL .define(funcName.getName(), Arrays.stream(ExprCoreType.values()) .map(type -> FunctionDSL @@ -76,7 +76,7 @@ private static FunctionResolver isNull(BuiltinFunctionName funcName) { Collectors.toList())); } - private static FunctionResolver isNotNull() { + private static DefaultFunctionResolver isNotNull() { return FunctionDSL .define(BuiltinFunctionName.IS_NOT_NULL.getName(), Arrays.stream(ExprCoreType.values()) .map(type -> FunctionDSL @@ -85,7 +85,7 @@ private static FunctionResolver isNotNull() { Collectors.toList())); } - private static FunctionResolver ifFunction() { + private static DefaultFunctionResolver ifFunction() { FunctionName functionName = BuiltinFunctionName.IF.getName(); List typeList = ExprCoreType.coreTypes(); @@ -94,11 +94,11 @@ private static FunctionResolver ifFunction() { impl((UnaryPredicateOperator::exprIf), v, BOOLEAN, v, v)) .collect(Collectors.toList()); - FunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne); + DefaultFunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne); return functionResolver; } - private static FunctionResolver ifNull() { + private static DefaultFunctionResolver ifNull() { FunctionName functionName = BuiltinFunctionName.IFNULL.getName(); List typeList = ExprCoreType.coreTypes(); @@ -107,15 +107,15 @@ private static FunctionResolver ifNull() { impl((UnaryPredicateOperator::exprIfNull), v, v, v)) .collect(Collectors.toList()); - FunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne); + DefaultFunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne); return functionResolver; } - private static FunctionResolver nullIf() { + private static DefaultFunctionResolver nullIf() { FunctionName functionName = BuiltinFunctionName.NULLIF.getName(); List typeList = ExprCoreType.coreTypes(); - FunctionResolver functionResolver = + DefaultFunctionResolver functionResolver = FunctionDSL.define(functionName, typeList.stream().map(v -> impl((UnaryPredicateOperator::exprNullIf), v, v, v)) @@ -124,6 +124,7 @@ private static FunctionResolver nullIf() { } /** v2 if v1 is null. + * * @param v1 varable 1 * @param v2 varable 2 * @return v2 if v1 is null @@ -133,6 +134,7 @@ public static ExprValue exprIfNull(ExprValue v1, ExprValue v2) { } /** return null if v1 equls to v2. + * * @param v1 varable 1 * @param v2 varable 2 * @return null if v1 equls to v2 diff --git a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java index 372540b4e9..8035728d19 100644 --- a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java @@ -18,8 +18,8 @@ import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.expression.function.SerializableBiFunction; import org.opensearch.sql.expression.function.SerializableTriFunction; @@ -63,7 +63,7 @@ public void register(BuiltinFunctionRepository repository) { * Supports following signatures: * (STRING, INTEGER)/(STRING, INTEGER, INTEGER) -> STRING */ - private FunctionResolver substringSubstr(FunctionName functionName) { + private DefaultFunctionResolver substringSubstr(FunctionName functionName) { return define(functionName, impl(nullMissingHandling(TextFunction::exprSubstrStart), STRING, STRING, INTEGER), @@ -71,11 +71,11 @@ private FunctionResolver substringSubstr(FunctionName functionName) { STRING, STRING, INTEGER, INTEGER)); } - private FunctionResolver substring() { + private DefaultFunctionResolver substring() { return substringSubstr(BuiltinFunctionName.SUBSTRING.getName()); } - private FunctionResolver substr() { + private DefaultFunctionResolver substr() { return substringSubstr(BuiltinFunctionName.SUBSTR.getName()); } @@ -84,7 +84,7 @@ private FunctionResolver substr() { * Supports following signatures: * STRING -> STRING */ - private FunctionResolver ltrim() { + private DefaultFunctionResolver ltrim() { return define(BuiltinFunctionName.LTRIM.getName(), impl(nullMissingHandling((v) -> new ExprStringValue(v.stringValue().stripLeading())), STRING, STRING)); @@ -95,7 +95,7 @@ private FunctionResolver ltrim() { * Supports following signatures: * STRING -> STRING */ - private FunctionResolver rtrim() { + private DefaultFunctionResolver rtrim() { return define(BuiltinFunctionName.RTRIM.getName(), impl(nullMissingHandling((v) -> new ExprStringValue(v.stringValue().stripTrailing())), STRING, STRING)); @@ -108,7 +108,7 @@ private FunctionResolver rtrim() { * Supports following signatures: * STRING -> STRING */ - private FunctionResolver trim() { + private DefaultFunctionResolver trim() { return define(BuiltinFunctionName.TRIM.getName(), impl(nullMissingHandling((v) -> new ExprStringValue(v.stringValue().trim())), STRING, STRING)); @@ -119,7 +119,7 @@ private FunctionResolver trim() { * Supports following signatures: * STRING -> STRING */ - private FunctionResolver lower() { + private DefaultFunctionResolver lower() { return define(BuiltinFunctionName.LOWER.getName(), impl(nullMissingHandling((v) -> new ExprStringValue((v.stringValue().toLowerCase()))), STRING, STRING) @@ -131,7 +131,7 @@ private FunctionResolver lower() { * Supports following signatures: * STRING -> STRING */ - private FunctionResolver upper() { + private DefaultFunctionResolver upper() { return define(BuiltinFunctionName.UPPER.getName(), impl(nullMissingHandling((v) -> new ExprStringValue((v.stringValue().toUpperCase()))), STRING, STRING) @@ -145,7 +145,7 @@ private FunctionResolver upper() { * Supports following signatures: * (STRING, STRING) -> STRING */ - private FunctionResolver concat() { + private DefaultFunctionResolver concat() { return define(BuiltinFunctionName.CONCAT.getName(), impl(nullMissingHandling((str1, str2) -> new ExprStringValue(str1.stringValue() + str2.stringValue())), STRING, STRING, STRING)); @@ -158,7 +158,7 @@ private FunctionResolver concat() { * Supports following signatures: * (STRING, STRING, STRING) -> STRING */ - private FunctionResolver concat_ws() { + private DefaultFunctionResolver concat_ws() { return define(BuiltinFunctionName.CONCAT_WS.getName(), impl(nullMissingHandling((sep, str1, str2) -> new ExprStringValue(str1.stringValue() + sep.stringValue() + str2.stringValue())), @@ -170,7 +170,7 @@ private FunctionResolver concat_ws() { * Supports following signatures: * STRING -> INTEGER */ - private FunctionResolver length() { + private DefaultFunctionResolver length() { return define(BuiltinFunctionName.LENGTH.getName(), impl(nullMissingHandling((str) -> new ExprIntegerValue(str.stringValue().getBytes().length)), INTEGER, STRING)); @@ -181,7 +181,7 @@ private FunctionResolver length() { * Supports following signatures: * (STRING, STRING) -> INTEGER */ - private FunctionResolver strcmp() { + private DefaultFunctionResolver strcmp() { return define(BuiltinFunctionName.STRCMP.getName(), impl(nullMissingHandling((str1, str2) -> new ExprIntegerValue(Integer.compare( @@ -194,7 +194,7 @@ private FunctionResolver strcmp() { * Supports following signatures: * (STRING, INTEGER) -> STRING */ - private FunctionResolver right() { + private DefaultFunctionResolver right() { return define(BuiltinFunctionName.RIGHT.getName(), impl(nullMissingHandling(TextFunction::exprRight), STRING, STRING, INTEGER)); } @@ -204,7 +204,7 @@ private FunctionResolver right() { * Supports following signature: * (STRING, INTEGER) -> STRING */ - private FunctionResolver left() { + private DefaultFunctionResolver left() { return define(BuiltinFunctionName.LEFT.getName(), impl(nullMissingHandling(TextFunction::exprLeft), STRING, STRING, INTEGER)); } @@ -216,7 +216,7 @@ private FunctionResolver left() { * Supports following signature: * STRING -> INTEGER */ - private FunctionResolver ascii() { + private DefaultFunctionResolver ascii() { return define(BuiltinFunctionName.ASCII.getName(), impl(nullMissingHandling(TextFunction::exprAscii), INTEGER, STRING)); } @@ -231,7 +231,7 @@ private FunctionResolver ascii() { * (STRING, STRING) -> INTEGER * (STRING, STRING, INTEGER) -> INTEGER */ - private FunctionResolver locate() { + private DefaultFunctionResolver locate() { return define(BuiltinFunctionName.LOCATE.getName(), impl(nullMissingHandling( (SerializableBiFunction) @@ -248,7 +248,7 @@ private FunctionResolver locate() { * Supports following signature: * (STRING, STRING, STRING) -> STRING */ - private FunctionResolver replace() { + private DefaultFunctionResolver replace() { return define(BuiltinFunctionName.REPLACE.getName(), impl(nullMissingHandling(TextFunction::exprReplace), STRING, STRING, STRING, STRING)); } diff --git a/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java b/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java index 2851dd9f6b..a3baf08ff3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java @@ -13,9 +13,9 @@ import lombok.experimental.UtilityClass; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionBuilder; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.expression.function.FunctionSignature; import org.opensearch.sql.expression.window.ranking.DenseRankFunction; import org.opensearch.sql.expression.window.ranking.RankFunction; @@ -30,6 +30,7 @@ public class WindowFunctions { /** * Register all window functions to function repository. + * * @param repository function repository */ public void register(BuiltinFunctionRepository repository) { @@ -38,23 +39,24 @@ public void register(BuiltinFunctionRepository repository) { repository.register(denseRank()); } - private FunctionResolver rowNumber() { + private DefaultFunctionResolver rowNumber() { return rankingFunction(BuiltinFunctionName.ROW_NUMBER.getName(), RowNumberFunction::new); } - private FunctionResolver rank() { + private DefaultFunctionResolver rank() { return rankingFunction(BuiltinFunctionName.RANK.getName(), RankFunction::new); } - private FunctionResolver denseRank() { + private DefaultFunctionResolver denseRank() { return rankingFunction(BuiltinFunctionName.DENSE_RANK.getName(), DenseRankFunction::new); } - private FunctionResolver rankingFunction(FunctionName functionName, - Supplier constructor) { + private DefaultFunctionResolver rankingFunction(FunctionName functionName, + Supplier constructor) { FunctionSignature functionSignature = new FunctionSignature(functionName, emptyList()); FunctionBuilder functionBuilder = arguments -> constructor.get(); - return new FunctionResolver(functionName, ImmutableMap.of(functionSignature, functionBuilder)); + return new DefaultFunctionResolver(functionName, + ImmutableMap.of(functionSignature, functionBuilder)); } } diff --git a/core/src/main/java/org/opensearch/sql/planner/Planner.java b/core/src/main/java/org/opensearch/sql/planner/Planner.java index 803b2d1931..8333425091 100644 --- a/core/src/main/java/org/opensearch/sql/planner/Planner.java +++ b/core/src/main/java/org/opensearch/sql/planner/Planner.java @@ -6,7 +6,6 @@ package org.opensearch.sql.planner; -import static com.google.common.base.Strings.isNullOrEmpty; import java.util.List; import lombok.RequiredArgsConstructor; @@ -15,7 +14,6 @@ import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; import org.opensearch.sql.planner.physical.PhysicalPlan; -import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; /** @@ -24,11 +22,6 @@ @RequiredArgsConstructor public class Planner { - /** - * Storage engine. - */ - private final StorageEngine storageEngine; - private final LogicalPlanOptimizer logicalOptimizer; /** @@ -40,32 +33,31 @@ public class Planner { * @return optimal physical plan */ public PhysicalPlan plan(LogicalPlan plan) { - String tableName = findTableName(plan); - if (isNullOrEmpty(tableName)) { + Table table = findTable(plan); + if (table == null) { return plan.accept(new DefaultImplementor<>(), null); } - - Table table = storageEngine.getTable(tableName); return table.implement( table.optimize(optimize(plan))); } - private String findTableName(LogicalPlan plan) { - return plan.accept(new LogicalPlanNodeVisitor() { + private Table findTable(LogicalPlan plan) { + return plan.accept(new LogicalPlanNodeVisitor() { @Override - public String visitNode(LogicalPlan node, Object context) { + public Table visitNode(LogicalPlan node, Object context) { List children = node.getChild(); if (children.isEmpty()) { - return ""; + return null; } return children.get(0).accept(this, context); } @Override - public String visitRelation(LogicalRelation node, Object context) { - return node.getRelationName(); + public Table visitRelation(LogicalRelation node, Object context) { + return node.getTable(); } + }, null); } diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index cdd3d3a103..005a5d84fd 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -13,6 +13,7 @@ import java.util.Map; import lombok.experimental.UtilityClass; import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.expression.Expression; @@ -21,6 +22,7 @@ import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.window.WindowDefinition; +import org.opensearch.sql.storage.Table; /** * Logical Plan DSL. @@ -37,8 +39,8 @@ public static LogicalPlan filter(LogicalPlan input, Expression expression) { return new LogicalFilter(input, expression); } - public static LogicalPlan relation(String tableName) { - return new LogicalRelation(tableName); + public static LogicalPlan relation(String tableName, Table table) { + return new LogicalRelation(tableName, table); } public static LogicalPlan rename( diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalRelation.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalRelation.java index cc1925b123..a49c3d5cbe 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalRelation.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalRelation.java @@ -10,6 +10,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; +import org.opensearch.sql.storage.Table; /** * Logical Relation represent the data source. @@ -17,15 +18,20 @@ @ToString @EqualsAndHashCode(callSuper = true) public class LogicalRelation extends LogicalPlan { + @Getter private final String relationName; + @Getter + private final Table table; + /** * Constructor of LogicalRelation. */ - public LogicalRelation(String relationName) { + public LogicalRelation(String relationName, Table table) { super(ImmutableList.of()); this.relationName = relationName; + this.table = table; } @Override diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/AggregationOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/AggregationOperator.java index 5e05286bbc..d71089d990 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/AggregationOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/AggregationOperator.java @@ -55,14 +55,19 @@ public AggregationOperator(PhysicalPlan input, List aggregatorL List groupByExprList) { this.input = input; this.aggregatorList = aggregatorList; + this.groupByExprList = groupByExprList; if (hasSpan(groupByExprList)) { + // span expression is always the first expression in group list if exist. this.span = groupByExprList.get(0); - this.groupByExprList = groupByExprList.subList(1, groupByExprList.size()); + this.collector = + Collector.Builder.build( + this.span, groupByExprList.subList(1, groupByExprList.size()), this.aggregatorList); + } else { this.span = null; - this.groupByExprList = groupByExprList; + this.collector = + Collector.Builder.build(this.span, this.groupByExprList, this.aggregatorList); } - this.collector = Collector.Builder.build(this.span, this.groupByExprList, this.aggregatorList); } @Override diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index d4d72dd1d7..ea3bd6f3db 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -22,7 +22,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; import static org.opensearch.sql.ast.dsl.AstDSL.relation; import static org.opensearch.sql.ast.dsl.AstDSL.span; -import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.ast.tree.Sort.NullOrder; import static org.opensearch.sql.ast.tree.Sort.SortOption; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; @@ -48,6 +47,7 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Kmeans; @@ -75,17 +75,51 @@ class AnalyzerTest extends AnalyzerTestBase { public void filter_relation() { assertAnalyzeEqual( LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), dsl.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(1)))), AstDSL.filter( AstDSL.relation("schema"), AstDSL.equalTo(AstDSL.field("integer_value"), AstDSL.intLiteral(1)))); } + @Test + public void filter_relation_with_catalog() { + assertAnalyzeEqual( + LogicalPlanDSL.filter( + LogicalPlanDSL.relation("http_total_requests", table), + dsl.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(1)))), + AstDSL.filter( + AstDSL.relation(AstDSL.qualifiedName("prometheus", "http_total_requests")), + AstDSL.equalTo(AstDSL.field("integer_value"), AstDSL.intLiteral(1)))); + } + + @Test + public void filter_relation_with_escaped_catalog() { + assertAnalyzeEqual( + LogicalPlanDSL.filter( + LogicalPlanDSL.relation("prometheus.http_total_requests", table), + dsl.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(1)))), + AstDSL.filter( + AstDSL.relation(AstDSL.qualifiedName("prometheus.http_total_requests")), + AstDSL.equalTo(AstDSL.field("integer_value"), AstDSL.intLiteral(1)))); + } + + @Test + public void filter_relation_with_non_existing_catalog() { + assertAnalyzeEqual( + LogicalPlanDSL.filter( + LogicalPlanDSL.relation("test.http_total_requests", table), + dsl.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(1)))), + AstDSL.filter( + AstDSL.relation(AstDSL.qualifiedName("test", "http_total_requests")), + AstDSL.equalTo(AstDSL.field("integer_value"), AstDSL.intLiteral(1)))); + } + @Test public void head_relation() { assertAnalyzeEqual( - LogicalPlanDSL.limit(LogicalPlanDSL.relation("schema"),10, 0), + LogicalPlanDSL.limit(LogicalPlanDSL.relation("schema", table), + 10, 0), AstDSL.head(AstDSL.relation("schema"), 10, 0)); } @@ -93,7 +127,7 @@ public void head_relation() { public void analyze_filter_relation() { assertAnalyzeEqual( LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), dsl.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(1)))), filter(relation("schema"), compare("=", field("integer_value"), intLiteral(1)))); } @@ -103,11 +137,11 @@ public void analyze_filter_aggregation_relation() { assertAnalyzeEqual( LogicalPlanDSL.filter( LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableList.of( DSL.named("AVG(integer_value)", dsl.avg(DSL.ref("integer_value", INTEGER))), DSL.named("MIN(integer_value)", dsl.min(DSL.ref("integer_value", INTEGER)))), - ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))), + ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))), dsl.greater(// Expect to be replaced with reference by expression optimizer DSL.ref("MIN(integer_value)", INTEGER), DSL.literal(integerValue(10)))), AstDSL.filter( @@ -116,7 +150,7 @@ public void analyze_filter_aggregation_relation() { ImmutableList.of( alias("AVG(integer_value)", aggregate("AVG", qualifiedName("integer_value"))), alias("MIN(integer_value)", aggregate("MIN", qualifiedName("integer_value")))), - emptyList(), + emptyList(), ImmutableList.of(alias("string_value", qualifiedName("string_value"))), emptyList()), compare(">", @@ -127,7 +161,7 @@ public void analyze_filter_aggregation_relation() { public void rename_relation() { assertAnalyzeEqual( LogicalPlanDSL.rename( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableMap.of(DSL.ref("integer_value", INTEGER), DSL.ref("ivalue", INTEGER))), AstDSL.rename( AstDSL.relation("schema"), @@ -138,7 +172,7 @@ public void rename_relation() { public void stats_source() { assertAnalyzeEqual( LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableList .of(DSL.named("avg(integer_value)", dsl.avg(DSL.ref("integer_value", INTEGER)))), ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))), @@ -159,7 +193,7 @@ public void stats_source() { public void rare_source() { assertAnalyzeEqual( LogicalPlanDSL.rareTopN( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), CommandType.RARE, 10, ImmutableList.of(DSL.ref("string_value", STRING)), @@ -179,7 +213,7 @@ public void rare_source() { public void top_source() { assertAnalyzeEqual( LogicalPlanDSL.rareTopN( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), CommandType.TOP, 5, ImmutableList.of(DSL.ref("string_value", STRING)), @@ -223,7 +257,7 @@ public void rename_to_invalid_expression() { public void project_source() { assertAnalyzeEqual( LogicalPlanDSL.project( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), DSL.named("integer_value", DSL.ref("integer_value", INTEGER)), DSL.named("double_value", DSL.ref("double_value", DOUBLE)) ), @@ -238,7 +272,7 @@ public void project_source() { public void project_highlight() { assertAnalyzeEqual( LogicalPlanDSL.project( - LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table), DSL.literal("fieldA")), DSL.named("highlight(fieldA)", new HighlightExpression(DSL.literal("fieldA"))) ), @@ -254,7 +288,8 @@ public void project_highlight() { public void remove_source() { assertAnalyzeEqual( LogicalPlanDSL.remove( - LogicalPlanDSL.relation("schema"), DSL.ref("integer_value", INTEGER), DSL.ref( + LogicalPlanDSL.relation("schema", table), + DSL.ref("integer_value", INTEGER), DSL.ref( "double_value", DOUBLE)), AstDSL.projectWithArg( AstDSL.relation("schema"), @@ -306,7 +341,7 @@ public void sort_with_aggregator() { LogicalPlanDSL.project( LogicalPlanDSL.sort( LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("test"), + LogicalPlanDSL.relation("test", table), ImmutableList.of( DSL.named( "avg(integer_value)", @@ -338,25 +373,25 @@ public void sort_with_aggregator() { public void sort_with_options() { ImmutableMap argOptions = ImmutableMap.builder() - .put(new Argument[]{argument("asc", booleanLiteral(true))}, + .put(new Argument[] {argument("asc", booleanLiteral(true))}, new SortOption(SortOrder.ASC, NullOrder.NULL_FIRST)) - .put(new Argument[]{argument("asc", booleanLiteral(false))}, + .put(new Argument[] {argument("asc", booleanLiteral(false))}, new SortOption(SortOrder.DESC, NullOrder.NULL_LAST)) - .put(new Argument[]{ - argument("asc", booleanLiteral(true)), - argument("nullFirst", booleanLiteral(true))}, + .put(new Argument[] { + argument("asc", booleanLiteral(true)), + argument("nullFirst", booleanLiteral(true))}, new SortOption(SortOrder.ASC, NullOrder.NULL_FIRST)) - .put(new Argument[]{ - argument("asc", booleanLiteral(true)), - argument("nullFirst", booleanLiteral(false))}, + .put(new Argument[] { + argument("asc", booleanLiteral(true)), + argument("nullFirst", booleanLiteral(false))}, new SortOption(SortOrder.ASC, NullOrder.NULL_LAST)) - .put(new Argument[]{ - argument("asc", booleanLiteral(false)), - argument("nullFirst", booleanLiteral(true))}, + .put(new Argument[] { + argument("asc", booleanLiteral(false)), + argument("nullFirst", booleanLiteral(true))}, new SortOption(SortOrder.DESC, NullOrder.NULL_FIRST)) - .put(new Argument[]{ - argument("asc", booleanLiteral(false)), - argument("nullFirst", booleanLiteral(false))}, + .put(new Argument[] { + argument("asc", booleanLiteral(false)), + argument("nullFirst", booleanLiteral(false))}, new SortOption(SortOrder.DESC, NullOrder.NULL_LAST)) .build(); @@ -364,7 +399,7 @@ public void sort_with_options() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.sort( - LogicalPlanDSL.relation("test"), + LogicalPlanDSL.relation("test", table), Pair.of(expectOption, DSL.ref("integer_value", INTEGER))), DSL.named("string_value", DSL.ref("string_value", STRING))), AstDSL.project( @@ -381,7 +416,7 @@ public void window_function() { LogicalPlanDSL.project( LogicalPlanDSL.window( LogicalPlanDSL.sort( - LogicalPlanDSL.relation("test"), + LogicalPlanDSL.relation("test", table), ImmutablePair.of(DEFAULT_ASC, DSL.ref("string_value", STRING)), ImmutablePair.of(DEFAULT_ASC, DSL.ref("integer_value", INTEGER))), DSL.named("window_function", dsl.rowNumber()), @@ -406,7 +441,7 @@ public void window_function() { /** * SELECT name FROM ( - * SELECT name, age FROM test + * SELECT name, age FROM test * ) AS schema. */ @Test @@ -414,7 +449,7 @@ public void from_subquery() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.project( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), DSL.named("string_value", DSL.ref("string_value", STRING)), DSL.named("integer_value", DSL.ref("integer_value", INTEGER)) ), @@ -436,7 +471,7 @@ public void from_subquery() { /** * SELECT * FROM ( - * SELECT name FROM test + * SELECT name FROM test * ) AS schema. */ @Test @@ -444,7 +479,7 @@ public void select_all_from_subquery() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.project( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), DSL.named("string_value", DSL.ref("string_value", STRING))), DSL.named("string_value", DSL.ref("string_value", STRING)) ), @@ -469,7 +504,7 @@ public void sql_group_by_field() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableList .of(DSL .named("AVG(integer_value)", dsl.avg(DSL.ref("integer_value", INTEGER)))), @@ -497,7 +532,7 @@ public void sql_group_by_function() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableList .of(DSL .named("AVG(integer_value)", dsl.avg(DSL.ref("integer_value", INTEGER)))), @@ -527,7 +562,7 @@ public void sql_group_by_function_in_uppercase() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableList .of(DSL .named("AVG(integer_value)", dsl.avg(DSL.ref("integer_value", INTEGER)))), @@ -557,7 +592,7 @@ public void sql_expression_over_one_aggregation() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableList .of(DSL.named("avg(integer_value)", dsl.avg(DSL.ref("integer_value", INTEGER)))), @@ -588,10 +623,10 @@ public void sql_expression_over_two_aggregation() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableList .of(DSL.named("sum(integer_value)", - dsl.sum(DSL.ref("integer_value", INTEGER))), + dsl.sum(DSL.ref("integer_value", INTEGER))), DSL.named("avg(integer_value)", dsl.avg(DSL.ref("integer_value", INTEGER)))), ImmutableList.of(DSL.named("abs(long_value)", @@ -622,7 +657,7 @@ public void limit_offset() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.limit( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), 1, 1 ), DSL.named("integer_value", DSL.ref("integer_value", INTEGER)) @@ -647,7 +682,7 @@ public void named_aggregator_with_condition() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableList.of( DSL.named("count(string_value) filter(where integer_value > 1)", dsl.count(DSL.ref("string_value", STRING)).condition(dsl.greater(DSL.ref( @@ -683,7 +718,7 @@ public void named_aggregator_with_condition() { public void ppl_stats_by_fieldAndSpan() { assertAnalyzeEqual( LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableList.of( DSL.named("AVG(integer_value)", dsl.avg(DSL.ref("integer_value", INTEGER)))), ImmutableList.of( @@ -703,7 +738,7 @@ public void ppl_stats_by_fieldAndSpan() { public void parse_relation() { assertAnalyzeEqual( LogicalPlanDSL.project( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING))), ImmutableList.of(DSL.named("group", DSL.parsed(DSL.ref("string_value", STRING), DSL.literal("(?.*)"), @@ -717,7 +752,7 @@ public void parse_relation() { AstDSL.alias("string_value", qualifiedName("string_value")) )); } - + @Test public void kmeanns_relation() { Map argumentMap = new HashMap() {{ @@ -726,9 +761,9 @@ public void kmeanns_relation() { put("distance_type", new Literal("COSINE", DataType.STRING)); }}; assertAnalyzeEqual( - new LogicalMLCommons(LogicalPlanDSL.relation("schema"), - "kmeans", argumentMap), - new Kmeans(AstDSL.relation("schema"), argumentMap) + new LogicalMLCommons(LogicalPlanDSL.relation("schema", table), + "kmeans", argumentMap), + new Kmeans(AstDSL.relation("schema"), argumentMap) ); } @@ -739,7 +774,7 @@ public void ad_batchRCF_relation() { put("shingle_size", new Literal(8, DataType.INTEGER)); }}; assertAnalyzeEqual( - new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap), + new LogicalAD(LogicalPlanDSL.relation("schema", table), argumentMap), new AD(AstDSL.relation("schema"), argumentMap) ); } @@ -752,8 +787,9 @@ public void ad_fitRCF_relation() { put("time_field", new Literal("timestamp", DataType.STRING)); }}; assertAnalyzeEqual( - new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap), - new AD(AstDSL.relation("schema"), argumentMap) + new LogicalAD(LogicalPlanDSL.relation("schema", table), + argumentMap), + new AD(AstDSL.relation("schema"), argumentMap) ); } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index 09ddca1645..3f912b8fde 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -8,11 +8,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import com.google.common.collect.ImmutableSet; import java.util.Map; +import java.util.Set; import org.opensearch.sql.analysis.symbol.Namespace; import org.opensearch.sql.analysis.symbol.Symbol; import org.opensearch.sql.analysis.symbol.SymbolTable; import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.config.TestConfig; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.exception.ExpressionEvaluationException; @@ -40,21 +43,31 @@ protected StorageEngine storageEngine() { return new StorageEngine() { @Override public Table getTable(String name) { - return new Table() { - @Override - public Map getFieldTypes() { - return typeMapping(); - } - - @Override - public PhysicalPlan implement(LogicalPlan plan) { - throw new UnsupportedOperationException(); - } - }; + return table; } }; } + @Bean + protected Table table() { + return new Table() { + @Override + public Map getFieldTypes() { + return typeMapping(); + } + + @Override + public PhysicalPlan implement(LogicalPlan plan) { + throw new UnsupportedOperationException(); + } + }; + } + + @Bean + protected CatalogService catalogService() { + return new DefaultCatalogService(); + } + @Bean protected SymbolTable symbolTable() { @@ -94,12 +107,17 @@ protected Environment typeEnv() { @Autowired protected Analyzer analyzer; + @Autowired + protected Table table; + @Autowired protected Environment typeEnv; @Bean - protected Analyzer analyzer(ExpressionAnalyzer expressionAnalyzer, StorageEngine engine) { - return new Analyzer(expressionAnalyzer, engine); + protected Analyzer analyzer(ExpressionAnalyzer expressionAnalyzer, CatalogService catalogService, + StorageEngine storageEngine) { + catalogService.registerOpenSearchStorageEngine(storageEngine); + return new Analyzer(expressionAnalyzer, catalogService); } @Bean @@ -124,4 +142,24 @@ protected void assertAnalyzeEqual(LogicalPlan expected, UnresolvedPlan unresolve protected LogicalPlan analyze(UnresolvedPlan unresolvedPlan) { return analyzer.analyze(unresolvedPlan, analysisContext); } + + private class DefaultCatalogService implements CatalogService { + + private StorageEngine storageEngine; + + @Override + public StorageEngine getStorageEngine(String catalog) { + return storageEngine; + } + + @Override + public Set getCatalogs() { + return ImmutableSet.of("prometheus"); + } + + @Override + public void registerOpenSearchStorageEngine(StorageEngine storageEngine) { + this.storageEngine = storageEngine; + } + } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index 72db402552..c8ce70c418 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.opensearch.sql.ast.dsl.AstDSL.field; +import static org.opensearch.sql.ast.dsl.AstDSL.floatLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; @@ -355,6 +356,14 @@ void match_bool_prefix_expression() { AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } + @Test + void match_bool_prefix_wrong_expression() { + assertThrows(SemanticCheckException.class, + () -> analyze(AstDSL.function("match_bool_prefix", + AstDSL.unresolvedArg("field", stringLiteral("fieldA")), + AstDSL.unresolvedArg("query", floatLiteral(1.2f))))); + } + @Test void visit_span() { assertAnalyzeEqual( diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizerTest.java index 1c914990f1..105d8f965d 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizerTest.java @@ -72,7 +72,7 @@ void case_clause_should_be_replaced() { LogicalPlan logicalPlan = LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("test"), + LogicalPlanDSL.relation("test", table), emptyList(), ImmutableList.of(DSL.named( "CaseClause(whenClauses=[WhenClause(condition==(age, 30), result=\"true\")]," @@ -96,7 +96,7 @@ void aggregation_in_case_when_clause_should_be_replaced() { LogicalPlan logicalPlan = LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("test"), + LogicalPlanDSL.relation("test", table), ImmutableList.of(DSL.named("AVG(age)", dsl.avg(DSL.ref("age", INTEGER)))), ImmutableList.of(DSL.named("name", DSL.ref("name", STRING)))); @@ -119,7 +119,7 @@ void aggregation_in_case_else_clause_should_be_replaced() { LogicalPlan logicalPlan = LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("test"), + LogicalPlanDSL.relation("test", table), ImmutableList.of(DSL.named("AVG(age)", dsl.avg(DSL.ref("age", INTEGER)))), ImmutableList.of(DSL.named("name", DSL.ref("name", STRING)))); @@ -137,7 +137,7 @@ void window_expression_should_be_replaced() { LogicalPlan logicalPlan = LogicalPlanDSL.window( LogicalPlanDSL.window( - LogicalPlanDSL.relation("test"), + LogicalPlanDSL.relation("test", table), DSL.named(dsl.rank()), new WindowDefinition(emptyList(), emptyList())), DSL.named(dsl.denseRank()), @@ -163,7 +163,7 @@ Expression optimize(Expression expression, LogicalPlan logicalPlan) { LogicalPlan logicalPlan() { return LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableList .of(DSL.named("AVG(age)", dsl.avg(DSL.ref("age", INTEGER))), DSL.named("SUM(age)", dsl.sum(DSL.ref("age", INTEGER)))), diff --git a/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java b/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java index 14aff853aa..7ffc97db3b 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java @@ -47,7 +47,7 @@ protected Map typeMapping() { public void project_all_from_source() { assertAnalyzeEqual( LogicalPlanDSL.project( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), DSL.named("integer_value", DSL.ref("integer_value", INTEGER)), DSL.named("double_value", DSL.ref("double_value", DOUBLE)), DSL.named("integer_value", DSL.ref("integer_value", INTEGER)), @@ -67,7 +67,7 @@ public void select_and_project_all() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.project( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), DSL.named("integer_value", DSL.ref("integer_value", INTEGER)), DSL.named("double_value", DSL.ref("double_value", DOUBLE)) ), @@ -90,7 +90,7 @@ public void remove_and_project_all() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.remove( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), DSL.ref("integer_value", INTEGER), DSL.ref("double_value", DOUBLE) ), @@ -112,7 +112,7 @@ public void stats_and_project_all() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.aggregation( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableList.of(DSL .named("avg(integer_value)", dsl.avg(DSL.ref("integer_value", INTEGER)))), ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))), @@ -135,7 +135,7 @@ public void rename_and_project_all() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.rename( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutableMap.of(DSL.ref("integer_value", INTEGER), DSL.ref("ivalue", INTEGER))), DSL.named("double_value", DSL.ref("double_value", DOUBLE)), DSL.named("string_value", DSL.ref("string_value", STRING)), diff --git a/core/src/test/java/org/opensearch/sql/analysis/WindowExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/WindowExpressionAnalyzerTest.java index afc7f33370..3ef279156b 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/WindowExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/WindowExpressionAnalyzerTest.java @@ -20,6 +20,7 @@ import com.google.common.collect.ImmutableMap; import java.util.Collections; import org.apache.commons.lang3.tuple.ImmutablePair; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; @@ -45,12 +46,13 @@ @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class WindowExpressionAnalyzerTest extends AnalyzerTestBase { - private final LogicalPlan child = new LogicalRelation("test"); + private LogicalPlan child; private WindowExpressionAnalyzer analyzer; @BeforeEach void setUp() { + child = new LogicalRelation("test", table); analyzer = new WindowExpressionAnalyzer(expressionAnalyzer, child); } @@ -60,7 +62,7 @@ void should_wrap_child_with_window_and_sort_operator_if_project_item_windowed() assertEquals( LogicalPlanDSL.window( LogicalPlanDSL.sort( - LogicalPlanDSL.relation("test"), + LogicalPlanDSL.relation("test", table), ImmutablePair.of(DEFAULT_ASC, DSL.ref("string_value", STRING)), ImmutablePair.of(DEFAULT_DESC, DSL.ref("integer_value", INTEGER))), DSL.named("row_number", dsl.rowNumber()), @@ -83,7 +85,7 @@ void should_wrap_child_with_window_and_sort_operator_if_project_item_windowed() void should_not_generate_sort_operator_if_no_partition_by_and_order_by_list() { assertEquals( LogicalPlanDSL.window( - LogicalPlanDSL.relation("test"), + LogicalPlanDSL.relation("test", table), DSL.named("row_number", dsl.rowNumber()), new WindowDefinition( ImmutableList.of(), diff --git a/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java b/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java index eca6408d17..61cc560670 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java @@ -47,7 +47,7 @@ @ExtendWith(MockitoExtension.class) class BuiltinFunctionRepositoryTest { @Mock - private FunctionResolver mockfunctionResolver; + private DefaultFunctionResolver mockfunctionResolver; @Mock private Map mockMap; @Mock @@ -182,7 +182,7 @@ private FunctionSignature registerFunctionResolver(FunctionName funcName, FunctionSignature resolvedSignature = new FunctionSignature( funcName, ImmutableList.of(targetType)); - FunctionResolver funcResolver = mock(FunctionResolver.class); + DefaultFunctionResolver funcResolver = mock(DefaultFunctionResolver.class); FunctionBuilder funcBuilder = mock(FunctionBuilder.class); when(mockMap.containsKey(eq(funcName))).thenReturn(true); diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java similarity index 90% rename from core/src/test/java/org/opensearch/sql/expression/function/FunctionResolverTest.java rename to core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java index 141c1fbd54..baa299b60b 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/FunctionResolverTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java @@ -22,7 +22,7 @@ @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) -class FunctionResolverTest { +class DefaultFunctionResolverTest { @Mock private FunctionSignature exactlyMatchFS; @Mock @@ -47,7 +47,7 @@ class FunctionResolverTest { @Test void resolve_function_signature_exactly_match() { when(functionSignature.match(exactlyMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL); - FunctionResolver resolver = new FunctionResolver(functionName, + DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(exactlyMatchFS, exactlyMatchBuilder)); assertEquals(exactlyMatchBuilder, resolver.resolve(functionSignature).getValue()); @@ -57,7 +57,7 @@ void resolve_function_signature_exactly_match() { void resolve_function_signature_best_match() { when(functionSignature.match(bestMatchFS)).thenReturn(1); when(functionSignature.match(leastMatchFS)).thenReturn(2); - FunctionResolver resolver = new FunctionResolver(functionName, + DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(bestMatchFS, bestMatchBuilder, leastMatchFS, leastMatchBuilder)); assertEquals(bestMatchBuilder, resolver.resolve(functionSignature).getValue()); @@ -68,7 +68,7 @@ void resolve_function_not_match() { when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); when(notMatchFS.formatTypes()).thenReturn("[INTEGER,INTEGER]"); when(functionSignature.formatTypes()).thenReturn("[BOOLEAN,BOOLEAN]"); - FunctionResolver resolver = new FunctionResolver(functionName, + DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(notMatchFS, notMatchBuilder)); ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, diff --git a/core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java new file mode 100644 index 0000000000..d8547057c4 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.exception.SemanticCheckException; + +class RelevanceFunctionResolverTest { + private final FunctionName sampleFuncName = FunctionName.of("sample_function"); + private RelevanceFunctionResolver resolver; + + @BeforeEach + void setUp() { + resolver = new RelevanceFunctionResolver(sampleFuncName, STRING); + } + + @Test + void resolve_correct_name_test() { + var sig = new FunctionSignature(sampleFuncName, List.of(STRING)); + Pair builderPair = resolver.resolve(sig); + assertEquals(sampleFuncName, builderPair.getKey().getFunctionName()); + } + + @Test + void resolve_invalid_name_test() { + var wrongFuncName = FunctionName.of("wrong_func"); + var sig = new FunctionSignature(wrongFuncName, List.of(STRING)); + Exception exception = assertThrows(SemanticCheckException.class, + () -> resolver.resolve(sig)); + assertEquals("Expected 'sample_function' but got 'wrong_func'", + exception.getMessage()); + } + + @Test + void resolve_invalid_first_param_type_test() { + var sig = new FunctionSignature(sampleFuncName, List.of(INTEGER)); + Exception exception = assertThrows(SemanticCheckException.class, + () -> resolver.resolve(sig)); + assertEquals("Expected type STRING instead of INTEGER for parameter #1", + exception.getMessage()); + } + + @Test + void resolve_invalid_third_param_type_test() { + var sig = new FunctionSignature(sampleFuncName, List.of(STRING, STRING, INTEGER, STRING)); + Exception exception = assertThrows(SemanticCheckException.class, + () -> resolver.resolve(sig)); + assertEquals("Expected type STRING instead of INTEGER for parameter #3", + exception.getMessage()); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java index 91315a7edc..3a6a95764c 100644 --- a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java @@ -54,6 +54,7 @@ import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; +import org.opensearch.sql.storage.Table; @ExtendWith(MockitoExtension.class) class DefaultImplementorTest { @@ -67,6 +68,9 @@ class DefaultImplementorTest { @Mock private NamedExpression groupBy; + @Mock + private Table table; + private final DefaultImplementor implementor = new DefaultImplementor<>(); @Test @@ -150,7 +154,7 @@ public void visitShouldReturnDefaultPhysicalOperator() { @Test public void visitRelationShouldThrowException() { assertThrows(UnsupportedOperationException.class, - () -> new LogicalRelation("test").accept(implementor, null)); + () -> new LogicalRelation("test", table).accept(implementor, null)); } @SuppressWarnings({"rawtypes", "unchecked"}) diff --git a/core/src/test/java/org/opensearch/sql/planner/PlannerTest.java b/core/src/test/java/org/opensearch/sql/planner/PlannerTest.java index c34091dbf7..32e9d1b45b 100644 --- a/core/src/test/java/org/opensearch/sql/planner/PlannerTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/PlannerTest.java @@ -77,7 +77,7 @@ public void planner_test() { LogicalPlanDSL.rename( LogicalPlanDSL.aggregation( LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", storageEngine.getTable("schema")), dsl.equal(DSL.ref("response", INTEGER), DSL.literal(10)) ), ImmutableList.of(DSL.named("avg(response)", dsl.avg(DSL.ref("response", INTEGER)))), @@ -114,7 +114,7 @@ protected void assertPhysicalPlan(PhysicalPlan expected, LogicalPlan logicalPlan } protected PhysicalPlan analyze(LogicalPlan logicalPlan) { - return new Planner(storageEngine, optimizer).plan(logicalPlan); + return new Planner(optimizer).plan(logicalPlan); } protected class MockTable extends LogicalPlanNodeVisitor implements Table { diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalDedupeTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalDedupeTest.java index 6b5300441b..be6d1fa48c 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalDedupeTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalDedupeTest.java @@ -34,7 +34,7 @@ class LogicalDedupeTest extends AnalyzerTestBase { public void analyze_dedup_with_two_field_with_default_option() { assertAnalyzeEqual( LogicalPlanDSL.dedupe( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), DSL.ref("integer_value", INTEGER), DSL.ref("double_value", DOUBLE)), dedupe( @@ -48,7 +48,7 @@ public void analyze_dedup_with_two_field_with_default_option() { public void analyze_dedup_with_one_field_with_customize_option() { assertAnalyzeEqual( LogicalPlanDSL.dedupe( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), 3, false, true, DSL.ref("integer_value", INTEGER), DSL.ref("double_value", DOUBLE)), diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalEvalTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalEvalTest.java index e59599cd58..d08e7c7ee8 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalEvalTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalEvalTest.java @@ -31,7 +31,7 @@ public class LogicalEvalTest extends AnalyzerTestBase { public void analyze_eval_with_one_field() { assertAnalyzeEqual( LogicalPlanDSL.eval( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutablePair .of(DSL.ref("absValue", INTEGER), dsl.abs(DSL.ref("integer_value", INTEGER)))), AstDSL.eval( @@ -43,7 +43,7 @@ public void analyze_eval_with_one_field() { public void analyze_eval_with_two_field() { assertAnalyzeEqual( LogicalPlanDSL.eval( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutablePair .of(DSL.ref("absValue", INTEGER), dsl.abs(DSL.ref("integer_value", INTEGER))), ImmutablePair.of(DSL.ref("iValue", INTEGER), dsl.abs(DSL.ref("absValue", INTEGER)))), diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index 1b81856296..c90ea365d2 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -30,6 +30,7 @@ import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.Aggregator; import org.opensearch.sql.expression.window.WindowDefinition; +import org.opensearch.sql.storage.Table; /** * Todo. Temporary added for UT coverage, Will be removed. @@ -43,6 +44,8 @@ class LogicalPlanNodeVisitorTest { ReferenceExpression ref; @Mock Aggregator aggregator; + @Mock + Table table; @Test public void logicalPlanShouldTraversable() { @@ -50,7 +53,7 @@ public void logicalPlanShouldTraversable() { LogicalPlanDSL.rename( LogicalPlanDSL.aggregation( LogicalPlanDSL.rareTopN( - LogicalPlanDSL.filter(LogicalPlanDSL.relation("schema"), expression), + LogicalPlanDSL.filter(LogicalPlanDSL.relation("schema", table), expression), CommandType.TOP, ImmutableList.of(expression), expression), @@ -64,7 +67,7 @@ public void logicalPlanShouldTraversable() { @Test public void testAbstractPlanNodeVisitorShouldReturnNull() { - LogicalPlan relation = LogicalPlanDSL.relation("schema"); + LogicalPlan relation = LogicalPlanDSL.relation("schema", table); assertNull(relation.accept(new LogicalPlanNodeVisitor() { }, null)); @@ -119,7 +122,7 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { assertNull(highlight.accept(new LogicalPlanNodeVisitor() { }, null)); - LogicalPlan mlCommons = new LogicalMLCommons(LogicalPlanDSL.relation("schema"), + LogicalPlan mlCommons = new LogicalMLCommons(LogicalPlanDSL.relation("schema", table), "kmeans", ImmutableMap.builder() .put("centroids", new Literal(3, DataType.INTEGER)) @@ -129,7 +132,7 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { assertNull(mlCommons.accept(new LogicalPlanNodeVisitor() { }, null)); - LogicalPlan ad = new LogicalAD(LogicalPlanDSL.relation("schema"), + LogicalPlan ad = new LogicalAD(LogicalPlanDSL.relation("schema", table), new HashMap() {{ put("shingle_size", new Literal(8, DataType.INTEGER)); put("time_decay", new Literal(0.0001, DataType.DOUBLE)); diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalRelationTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalRelationTest.java index 2e5c099d5f..93448185cd 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalRelationTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalRelationTest.java @@ -9,12 +9,28 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.storage.Table; +@ExtendWith(MockitoExtension.class) class LogicalRelationTest { + @Mock + Table table; + @Test public void logicalRelationHasNoInput() { - LogicalPlan relation = LogicalPlanDSL.relation("index"); + LogicalPlan relation = LogicalPlanDSL.relation("index", table); + assertEquals(0, relation.getChild().size()); + } + + @Test + public void logicalRelationWithCatalogHasNoInput() { + LogicalPlan relation = LogicalPlanDSL.relation("prometheus.index", table); assertEquals(0, relation.getChild().size()); } -} + +} \ No newline at end of file diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalSortTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalSortTest.java index b8178de41f..dd8e76d694 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalSortTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalSortTest.java @@ -36,7 +36,7 @@ class LogicalSortTest extends AnalyzerTestBase { public void analyze_sort_with_two_field_with_default_option() { assertAnalyzeEqual( LogicalPlanDSL.sort( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutablePair.of(SortOption.DEFAULT_ASC, DSL.ref("integer_value", INTEGER)), ImmutablePair.of(SortOption.DEFAULT_ASC, DSL.ref("double_value", DOUBLE))), sort( @@ -49,7 +49,7 @@ public void analyze_sort_with_two_field_with_default_option() { public void analyze_sort_with_two_field() { assertAnalyzeEqual( LogicalPlanDSL.sort( - LogicalPlanDSL.relation("schema"), + LogicalPlanDSL.relation("schema", table), ImmutablePair.of(SortOption.DEFAULT_DESC, DSL.ref("integer_value", INTEGER)), ImmutablePair.of(SortOption.DEFAULT_ASC, DSL.ref("double_value", DOUBLE))), sort( diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java index 2732ef8d61..d81bcf66cd 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java @@ -38,14 +38,14 @@ class LogicalPlanOptimizerTest extends AnalyzerTestBase { void filter_merge_filter() { assertEquals( filter( - relation("schema"), + relation("schema", table), dsl.and(dsl.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(2))), dsl.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(1)))) ), optimize( filter( filter( - relation("schema"), + relation("schema", table), dsl.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(1))) ), dsl.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(2))) @@ -62,7 +62,7 @@ void push_filter_under_sort() { assertEquals( sort( filter( - relation("schema"), + relation("schema", table), dsl.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) ), Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) @@ -70,7 +70,7 @@ void push_filter_under_sort() { optimize( filter( sort( - relation("schema"), + relation("schema", table), Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) ), dsl.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) @@ -87,7 +87,7 @@ void multiple_filter_should_eventually_be_merged() { assertEquals( sort( filter( - relation("schema"), + relation("schema", table), dsl.and(dsl.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), dsl.less(DSL.ref("longV", INTEGER), DSL.literal(longValue(1L)))) ), @@ -97,7 +97,7 @@ void multiple_filter_should_eventually_be_merged() { filter( sort( filter( - relation("schema"), + relation("schema", table), dsl.less(DSL.ref("longV", INTEGER), DSL.literal(longValue(1L))) ), Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/AggregationOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/AggregationOperatorTest.java index 3b45a11c6c..318499c075 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/AggregationOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/AggregationOperatorTest.java @@ -495,4 +495,17 @@ public void twoBucketsSpanAndLong() { "span", new ExprDateValue("2021-01-07"), "region","iad", "host", "h2", "max", 8)) )); } + + @Test + public void copyOfAggregationOperatorShouldSame() { + AggregationOperator plan = new AggregationOperator(testScan(datetimeInputs), + Collections.singletonList(DSL + .named("count", dsl.count(DSL.ref("second", TIMESTAMP)))), + Collections.singletonList(DSL + .named("span", DSL.span(DSL.ref("second", TIMESTAMP), DSL.literal(6 * 1000), "ms")))); + AggregationOperator copy = new AggregationOperator(plan.getInput(), plan.getAggregatorList(), + plan.getGroupByExprList()); + + assertEquals(plan, copy); + } } diff --git a/docs/user/ppl/admin/security.rst b/docs/user/ppl/admin/security.rst new file mode 100644 index 0000000000..529704574b --- /dev/null +++ b/docs/user/ppl/admin/security.rst @@ -0,0 +1,69 @@ +.. highlight:: sh + +================= +Security Settings +================= + +.. rubric:: Table of contents + +.. contents:: + :local: + :depth: 1 + +Introduction +============ + +User needs ``cluster:admin/opensearch/ppl`` permission to use PPL plugin. User also needs indices level permission ``indices:admin/mappings/get`` to get field mappings and ``indices:data/read/search*`` to search index. + +Using Rest API +============== +**--INTRODUCED 2.1--** + +Example: Create the ppl_role for test_user. then test_user could use PPL to query ``ppl-security-demo`` index. + +1. Create the ppl_role and grand permission to access PPL plugin and access ppl-security-demo index:: + + PUT _plugins/_security/api/roles/ppl_role + { + "cluster_permissions": [ + "cluster:admin/opensearch/ppl" + ], + "index_permissions": [{ + "index_patterns": [ + "ppl-security-demo" + ], + "allowed_actions": [ + "indices:data/read/search*", + "indices:admin/mappings/get" + ] + }] + } + +2. Mapping the test_user to the ppl_role:: + + PUT _plugins/_security/api/rolesmapping/ppl_role + { + "backend_roles" : [], + "hosts" : [], + "users" : ["test_user"] + } + + +Using Security Dashboard +======================== +**--INTRODUCED 2.1--** + +Example: Create ppl_access permission and add to existing role + +1. Create the ppl_access permission:: + + PUT _plugins/_security/api/actiongroups/ppl_access + { + "allowed_actions": [ + "cluster:admin/opensearch/ppl" + ] + } + +2. Grant the ppl_access permission to ppl_test_role + +.. image:: https://user-images.githubusercontent.com/2969395/185448976-6c0aed6b-7540-4b99-92c3-362da8ae3763.png diff --git a/docs/user/ppl/index.rst b/docs/user/ppl/index.rst index 39adfa0902..e4f6224535 100644 --- a/docs/user/ppl/index.rst +++ b/docs/user/ppl/index.rst @@ -30,6 +30,8 @@ The query start with search command and then flowing a set of command delimited - `Plugin Settings `_ + - `Security Settings `_ + - `Monitoring `_ * **Commands** diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 429c360a1b..5e0a53bf1a 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -56,6 +56,8 @@ configurations.all { resolutionStrategy.force "com.fasterxml.jackson.core:jackson-core:${jackson_version}" resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${jackson_version}" resolutionStrategy.force "com.fasterxml.jackson.core:jackson-databind:${jackson_version}" + resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib:1.6.0" + resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib-common:1.6.0" } dependencies { @@ -187,21 +189,17 @@ task compileJdbc(type: Exec) { } } -/* -BWC test suite was running on OpenDistro which was discontinued and no available anymore for testing. -Test suite is not removed, because it could be reused later between different OpenSearch versions. -*/ -String bwcVersion = "1.13.2.0"; +String bwcVersion = "1.1.0.0"; String baseName = "sqlBwcCluster" String bwcFilePath = "src/test/resources/bwc/" -String bwcOpenDistroPlugin = "opendistro-sql-" + bwcVersion + ".zip" -String bwcRemoteFile = 'https://d3g5vo6xdbdb9a.cloudfront.net/downloads/elasticsearch-plugins/opendistro-sql/' + bwcOpenDistroPlugin +String bwcSqlPlugin = "opensearch-sql-" + bwcVersion + ".zip" +String bwcRemoteFile = "https://ci.opensearch.org/ci/dbc/bundle-build/1.1.0/20210930/linux/x64/builds/opensearch/plugins/" + bwcSqlPlugin 2.times { i -> testClusters { "${baseName}$i" { testDistribution = "ARCHIVE" - versions = ["7.10.2", opensearch_version] + versions = ["1.1.0", opensearch_version] numberOfNodes = 3 plugin(provider(new Callable() { @Override @@ -213,7 +211,7 @@ String bwcRemoteFile = 'https://d3g5vo6xdbdb9a.cloudfront.net/downloads/elastics if (!dir.exists()) { dir.mkdirs() } - File f = new File(dir, bwcOpenDistroPlugin) + File f = new File(dir, bwcSqlPlugin) if (!f.exists()) { new URL(bwcRemoteFile).withInputStream{ ins -> f.withOutputStream{ it << ins }} } diff --git a/integ-test/src/test/java/org/opensearch/sql/bwc/SQLBackwardsCompatibilityIT.java b/integ-test/src/test/java/org/opensearch/sql/bwc/SQLBackwardsCompatibilityIT.java index 079980248f..c32a3336c0 100644 --- a/integ-test/src/test/java/org/opensearch/sql/bwc/SQLBackwardsCompatibilityIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/bwc/SQLBackwardsCompatibilityIT.java @@ -96,7 +96,7 @@ public void testBackwardsCompatibility() throws Exception { Set pluginNames = plugins.stream().map(map -> map.get("name")).collect(Collectors.toSet()); switch (CLUSTER_TYPE) { case OLD: - Assert.assertTrue(pluginNames.contains("opendistro-sql")); + Assert.assertTrue(pluginNames.contains("opensearch-sql")); updateLegacySQLSettings(); loadIndex(Index.ACCOUNT); verifySQLQueries(LEGACY_QUERY_API_ENDPOINT); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java index 4385c44571..e6845cb154 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java @@ -18,6 +18,7 @@ import org.opensearch.client.Request; import org.opensearch.client.RestClient; import org.opensearch.client.RestHighLevelClient; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.executor.ExecutionEngine; @@ -28,6 +29,7 @@ import org.opensearch.sql.opensearch.executor.OpenSearchExecutionEngine; import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; +import org.opensearch.sql.plugin.catalog.CatalogServiceImpl; import org.opensearch.sql.ppl.config.PPLServiceConfig; import org.opensearch.sql.ppl.domain.PPLQueryRequest; import org.opensearch.sql.protocol.response.QueryResult; @@ -53,11 +55,12 @@ public void init() { OpenSearchClient client = new OpenSearchRestClient(restClient); AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); - context.registerBean(StorageEngine.class, - () -> new OpenSearchStorageEngine(client, defaultSettings())); context.registerBean(ExecutionEngine.class, () -> new OpenSearchExecutionEngine(client, new OpenSearchExecutionProtector(new AlwaysHealthyMonitor()))); context.register(PPLServiceConfig.class); + OpenSearchStorageEngine openSearchStorageEngine = new OpenSearchStorageEngine(client, defaultSettings()); + CatalogServiceImpl.getInstance().registerOpenSearchStorageEngine(openSearchStorageEngine); + context.registerBean(CatalogService.class, CatalogServiceImpl::getInstance); context.refresh(); pplService = context.getBean(PPLService.class); diff --git a/integ-test/src/test/resources/bwc/.gitignore b/integ-test/src/test/resources/bwc/.gitignore new file mode 100644 index 0000000000..d6b7ef32c8 --- /dev/null +++ b/integ-test/src/test/resources/bwc/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/legacy/build.gradle b/legacy/build.gradle index f605ced7ba..db9d6138f0 100644 --- a/legacy/build.gradle +++ b/legacy/build.gradle @@ -92,6 +92,8 @@ dependencies { implementation group: 'org.json', name: 'json', version:'20180813' implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.10' implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" + // add geo module as dependency. https://github.com/opensearch-project/OpenSearch/pull/4180/. + implementation group: 'org.opensearch.plugin', name: 'geo', version: "${opensearch_version}" api project(':sql') api project(':common') api project(':opensearch') diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResultsExtractor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResultsExtractor.java index 5a16a9ab61..70cdd91452 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResultsExtractor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResultsExtractor.java @@ -21,7 +21,7 @@ import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; import org.opensearch.search.aggregations.bucket.SingleBucketAggregation; import org.opensearch.search.aggregations.metrics.ExtendedStats; -import org.opensearch.search.aggregations.metrics.GeoBounds; +import org.opensearch.geo.search.aggregations.metrics.GeoBounds; import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation; import org.opensearch.search.aggregations.metrics.Percentile; import org.opensearch.search.aggregations.metrics.Percentiles; diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/OpenSearchSQLPluginConfig.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/OpenSearchSQLPluginConfig.java index 91b3a58925..b396d896b0 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/OpenSearchSQLPluginConfig.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/OpenSearchSQLPluginConfig.java @@ -7,7 +7,6 @@ package org.opensearch.sql.legacy.plugin; import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.service.ClusterService; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.config.ExpressionConfig; @@ -34,8 +33,6 @@ @Configuration @Import({ExpressionConfig.class}) public class OpenSearchSQLPluginConfig { - @Autowired - private ClusterService clusterService; @Autowired private NodeClient nodeClient; @@ -48,7 +45,7 @@ public class OpenSearchSQLPluginConfig { @Bean public OpenSearchClient client() { - return new OpenSearchNodeClient(clusterService, nodeClient); + return new OpenSearchNodeClient(nodeClient); } @Bean diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java index 51484feda7..0db08398b8 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java @@ -14,6 +14,7 @@ import java.io.IOException; import java.security.PrivilegedExceptionAction; import java.util.List; +import javax.xml.catalog.Catalog; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.node.NodeClient; @@ -23,6 +24,7 @@ import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestStatus; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.common.setting.Settings; @@ -61,13 +63,16 @@ public class RestSQLQueryAction extends BaseRestHandler { */ private final Settings pluginSettings; + private final CatalogService catalogService; + /** * Constructor of RestSQLQueryAction. */ - public RestSQLQueryAction(ClusterService clusterService, Settings pluginSettings) { + public RestSQLQueryAction(ClusterService clusterService, Settings pluginSettings, CatalogService catalogService) { super(); this.clusterService = clusterService; this.pluginSettings = pluginSettings; + this.catalogService = catalogService; } @Override @@ -124,6 +129,7 @@ private SQLService createSQLService(NodeClient client) { context.registerBean(ClusterService.class, () -> clusterService); context.registerBean(NodeClient.class, () -> client); context.registerBean(Settings.class, () -> pluginSettings); + context.registerBean(CatalogService.class, () -> catalogService); context.register(OpenSearchSQLPluginConfig.class); context.register(SQLServiceConfig.class); context.refresh(); diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java index 06d1ba1c73..ab146404f8 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java @@ -35,6 +35,7 @@ import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestStatus; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.QueryContext; import org.opensearch.sql.exception.ExpressionEvaluationException; @@ -89,10 +90,11 @@ public class RestSqlAction extends BaseRestHandler { private final RestSQLQueryAction newSqlQueryHandler; public RestSqlAction(Settings settings, ClusterService clusterService, - org.opensearch.sql.common.setting.Settings pluginSettings) { + org.opensearch.sql.common.setting.Settings pluginSettings, + CatalogService catalogService) { super(); this.allowExplicitIndex = MULTI_ALLOW_EXPLICIT_INDEX.get(settings); - this.newSqlQueryHandler = new RestSQLQueryAction(clusterService, pluginSettings); + this.newSqlQueryHandler = new RestSQLQueryAction(clusterService, pluginSettings, catalogService); } @Override diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/maker/AggMaker.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/maker/AggMaker.java index b56692e453..87125721c0 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/maker/AggMaker.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/maker/AggMaker.java @@ -25,6 +25,7 @@ import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.common.xcontent.json.JsonXContentParser; +import org.opensearch.geo.search.aggregations.bucket.geogrid.GeoHashGridAggregationBuilder; import org.opensearch.join.aggregations.JoinAggregationBuilders; import org.opensearch.script.Script; import org.opensearch.script.ScriptType; @@ -34,7 +35,7 @@ import org.opensearch.search.aggregations.BucketOrder; import org.opensearch.search.aggregations.InternalOrder; import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; -import org.opensearch.search.aggregations.bucket.geogrid.GeoGridAggregationBuilder; +import org.opensearch.geo.search.aggregations.bucket.geogrid.GeoGridAggregationBuilder; import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder; import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; import org.opensearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder; @@ -44,7 +45,7 @@ import org.opensearch.search.aggregations.bucket.range.RangeAggregationBuilder; import org.opensearch.search.aggregations.bucket.terms.IncludeExclude; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; -import org.opensearch.search.aggregations.metrics.GeoBoundsAggregationBuilder; +import org.opensearch.geo.search.aggregations.metrics.GeoBoundsAggregationBuilder; import org.opensearch.search.aggregations.metrics.PercentilesAggregationBuilder; import org.opensearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder; import org.opensearch.search.aggregations.metrics.TopHitsAggregationBuilder; @@ -285,7 +286,7 @@ private AggregationBuilder makeRangeGroup(MethodField field) throws SqlParseExce private AggregationBuilder geoBounds(MethodField field) throws SqlParseException { String aggName = gettAggNameFromParamsOrAlias(field); - GeoBoundsAggregationBuilder boundsBuilder = AggregationBuilders.geoBounds(aggName); + GeoBoundsAggregationBuilder boundsBuilder = new GeoBoundsAggregationBuilder(aggName); String value; for (KVValue kv : field.getParams()) { value = kv.value.toString(); @@ -472,7 +473,7 @@ private AbstractAggregationBuilder scriptedMetric(MethodField field) throws SqlP private AggregationBuilder geohashGrid(MethodField field) throws SqlParseException { String aggName = gettAggNameFromParamsOrAlias(field); - GeoGridAggregationBuilder geoHashGrid = AggregationBuilders.geohashGrid(aggName); + GeoGridAggregationBuilder geoHashGrid = new GeoHashGridAggregationBuilder(aggName); String value; for (KVValue kv : field.getParams()) { value = kv.value.toString(); diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java index c3046785dc..56d153eb9d 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java @@ -22,6 +22,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.sql.domain.SQLQueryRequest; import org.opensearch.threadpool.ThreadPool; @@ -40,6 +41,9 @@ public class RestSQLQueryActionTest { @Mock private Settings settings; + @Mock + private CatalogService catalogService; + @Before public void setup() { nodeClient = new NodeClient(org.opensearch.common.settings.Settings.EMPTY, threadPool); @@ -55,7 +59,7 @@ public void handleQueryThatCanSupport() { QUERY_API_ENDPOINT, ""); - RestSQLQueryAction queryAction = new RestSQLQueryAction(clusterService, settings); + RestSQLQueryAction queryAction = new RestSQLQueryAction(clusterService, settings, catalogService); assertNotSame(NOT_SUPPORTED_YET, queryAction.prepareRequest(request, nodeClient)); } @@ -67,7 +71,7 @@ public void handleExplainThatCanSupport() { EXPLAIN_API_ENDPOINT, ""); - RestSQLQueryAction queryAction = new RestSQLQueryAction(clusterService, settings); + RestSQLQueryAction queryAction = new RestSQLQueryAction(clusterService, settings, catalogService); assertNotSame(NOT_SUPPORTED_YET, queryAction.prepareRequest(request, nodeClient)); } @@ -80,7 +84,7 @@ public void skipQueryThatNotSupport() { QUERY_API_ENDPOINT, ""); - RestSQLQueryAction queryAction = new RestSQLQueryAction(clusterService, settings); + RestSQLQueryAction queryAction = new RestSQLQueryAction(clusterService, settings, catalogService); assertSame(NOT_SUPPORTED_YET, queryAction.prepareRequest(request, nodeClient)); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java index db35f3580c..80a2fb8604 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java @@ -9,7 +9,7 @@ import com.carrotsearch.hppc.cursors.ObjectObjectCursor; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import java.io.IOException; +import com.google.common.collect.Streams; import java.util.Arrays; import java.util.Collection; import java.util.List; @@ -18,24 +18,17 @@ import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; -import org.apache.logging.log4j.ThreadContext; import org.opensearch.action.admin.indices.get.GetIndexResponse; -import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.AliasMetadata; -import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; -import org.opensearch.cluster.metadata.MappingMetadata; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.collect.ImmutableOpenMap; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; import org.opensearch.index.IndexSettings; import org.opensearch.sql.opensearch.mapping.IndexMapping; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; -import org.opensearch.threadpool.ThreadPool; /** OpenSearch connection by node client. */ public class OpenSearchNodeClient implements OpenSearchClient { @@ -43,23 +36,16 @@ public class OpenSearchNodeClient implements OpenSearchClient { public static final Function> ALL_FIELDS = (anyIndex -> (anyField -> true)); - /** Current cluster state on local node. */ - private final ClusterService clusterService; - /** Node client provided by OpenSearch container. */ private final NodeClient client; /** Index name expression resolver to get concrete index name. */ private final IndexNameExpressionResolver resolver; - private static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker"; - /** * Constructor of ElasticsearchNodeClient. */ - public OpenSearchNodeClient(ClusterService clusterService, - NodeClient client) { - this.clusterService = clusterService; + public OpenSearchNodeClient(NodeClient client) { this.client = client; this.resolver = new IndexNameExpressionResolver(client.threadPool().getThreadContext()); } @@ -78,14 +64,16 @@ public OpenSearchNodeClient(ClusterService clusterService, @Override public Map getIndexMappings(String... indexExpression) { try { - ClusterState state = clusterService.state(); - String[] concreteIndices = resolveIndexExpression(state, indexExpression); - - return populateIndexMappings( - state.metadata().findMappings(concreteIndices, ALL_FIELDS)); - } catch (IOException e) { + GetMappingsResponse mappingsResponse = client.admin().indices() + .prepareGetMappings(indexExpression) + .setLocal(true) + .get(); + return Streams.stream(mappingsResponse.mappings().iterator()) + .collect(Collectors.toMap(cursor -> cursor.key, + cursor -> new IndexMapping(cursor.value))); + } catch (Exception e) { throw new IllegalStateException( - "Failed to read mapping in cluster state for index pattern [" + indexExpression + "]", e); + "Failed to read mapping for index pattern [" + indexExpression + "]", e); } } @@ -97,19 +85,24 @@ public Map getIndexMappings(String... indexExpression) { */ @Override public Map getIndexMaxResultWindows(String... indexExpression) { - ClusterState state = clusterService.state(); - ImmutableOpenMap indicesMetadata = state.metadata().getIndices(); - String[] concreteIndices = resolveIndexExpression(state, indexExpression); - - ImmutableMap.Builder result = ImmutableMap.builder(); - for (String index : concreteIndices) { - Settings settings = indicesMetadata.get(index).getSettings(); - Integer maxResultWindow = settings.getAsInt("index.max_result_window", - IndexSettings.MAX_RESULT_WINDOW_SETTING.getDefault(settings)); - result.put(index, maxResultWindow); + try { + GetSettingsResponse settingsResponse = + client.admin().indices().prepareGetSettings(indexExpression).setLocal(true).get(); + ImmutableMap.Builder result = ImmutableMap.builder(); + for (ObjectObjectCursor indexToSetting : + settingsResponse.getIndexToSettings()) { + Settings settings = indexToSetting.value; + result.put( + indexToSetting.key, + settings.getAsInt( + IndexSettings.MAX_RESULT_WINDOW_SETTING.getKey(), + IndexSettings.MAX_RESULT_WINDOW_SETTING.getDefault(settings))); + } + return result.build(); + } catch (Exception e) { + throw new IllegalStateException( + "Failed to read setting for index pattern [" + indexExpression + "]", e); } - - return result.build(); } /** @@ -149,9 +142,8 @@ public List indices() { */ @Override public Map meta() { - final ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); - builder.put(META_CLUSTER_NAME, clusterService.getClusterName().value()); - return builder.build(); + return ImmutableMap.of(META_CLUSTER_NAME, + client.settings().get("cluster.name", "opensearch")); } @Override @@ -161,40 +153,12 @@ public void cleanup(OpenSearchRequest request) { @Override public void schedule(Runnable task) { - ThreadPool threadPool = client.threadPool(); - threadPool.schedule( - withCurrentContext(task), - new TimeValue(0), - SQL_WORKER_THREAD_POOL_NAME - ); + // at that time, task already running the sql-worker ThreadPool. + task.run(); } @Override public NodeClient getNodeClient() { return client; } - - private String[] resolveIndexExpression(ClusterState state, String[] indices) { - return resolver.concreteIndexNames(state, IndicesOptions.strictExpandOpen(), true, indices); - } - - private Map populateIndexMappings( - ImmutableOpenMap indexMappings) { - - ImmutableMap.Builder result = ImmutableMap.builder(); - for (ObjectObjectCursor cursor: - indexMappings) { - result.put(cursor.key, new IndexMapping(cursor.value)); - } - return result.build(); - } - - /** Copy from LogUtils. */ - private static Runnable withCurrentContext(final Runnable task) { - final Map currentContext = ThreadContext.getImmutableContext(); - return () -> { - ThreadContext.putAll(currentContext); - task.run(); - }; - } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchBoolPrefixQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchBoolPrefixQuery.java index 754a09259d..33e357afe3 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchBoolPrefixQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchBoolPrefixQuery.java @@ -14,7 +14,7 @@ * Initializes MatchBoolPrefixQueryBuilder from a FunctionExpression. */ public class MatchBoolPrefixQuery - extends RelevanceQuery { + extends SingleFieldQuery { /** * Constructor for MatchBoolPrefixQuery to configure RelevanceQuery * with support of optional parameters. @@ -41,7 +41,12 @@ public MatchBoolPrefixQuery() { * @return Object of executed query */ @Override - protected MatchBoolPrefixQueryBuilder createQueryBuilder(String field, String query) { + protected MatchBoolPrefixQueryBuilder createBuilder(String field, String query) { return QueryBuilders.matchBoolPrefixQuery(field, query); } + + @Override + protected String getQueryName() { + return MatchBoolPrefixQueryBuilder.NAME; + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhrasePrefixQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhrasePrefixQuery.java index b8d0d4f18d..6d181daa4c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhrasePrefixQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhrasePrefixQuery.java @@ -12,7 +12,7 @@ /** * Lucene query that builds a match_phrase_prefix query. */ -public class MatchPhrasePrefixQuery extends RelevanceQuery { +public class MatchPhrasePrefixQuery extends SingleFieldQuery { /** * Default constructor for MatchPhrasePrefixQuery configures how RelevanceQuery.build() handles * named arguments. @@ -29,7 +29,12 @@ public MatchPhrasePrefixQuery() { } @Override - protected MatchPhrasePrefixQueryBuilder createQueryBuilder(String field, String query) { + protected MatchPhrasePrefixQueryBuilder createBuilder(String field, String query) { return QueryBuilders.matchPhrasePrefixQuery(field, query); } + + @Override + protected String getQueryName() { + return MatchPhrasePrefixQueryBuilder.NAME; + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhraseQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhraseQuery.java index 333d8eff89..6a7694f629 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhraseQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhraseQuery.java @@ -23,7 +23,7 @@ /** * Lucene query that builds a match_phrase query. */ -public class MatchPhraseQuery extends RelevanceQuery { +public class MatchPhraseQuery extends SingleFieldQuery { /** * Default constructor for MatchPhraseQuery configures how RelevanceQuery.build() handles * named arguments. @@ -39,7 +39,12 @@ public MatchPhraseQuery() { } @Override - protected MatchPhraseQueryBuilder createQueryBuilder(String field, String query) { + protected MatchPhraseQueryBuilder createBuilder(String field, String query) { return QueryBuilders.matchPhraseQuery(field, query); } + + @Override + protected String getQueryName() { + return MatchPhraseQueryBuilder.NAME; + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java index 4095ffba4e..f6d88013e4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java @@ -6,7 +6,6 @@ package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; import com.google.common.collect.ImmutableMap; -import java.util.Map; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.Operator; import org.opensearch.index.query.QueryBuilders; @@ -14,7 +13,7 @@ /** * Initializes MatchQueryBuilder from a FunctionExpression. */ -public class MatchQuery extends RelevanceQuery { +public class MatchQuery extends SingleFieldQuery { /** * Default constructor for MatchQuery configures how RelevanceQuery.build() handles * named arguments. @@ -40,7 +39,12 @@ public MatchQuery() { } @Override - protected MatchQueryBuilder createQueryBuilder(String field, String query) { + protected MatchQueryBuilder createBuilder(String field, String query) { return QueryBuilders.matchQuery(field, query); } + + @Override + protected String getQueryName() { + return MatchQueryBuilder.NAME; + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQuery.java new file mode 100644 index 0000000000..b447f2ffe2 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQuery.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.sql.expression.NamedArgumentExpression; + +/** + * Base class to represent relevance queries that search multiple fields. + * @param The builder class for the OpenSearch query. + */ +abstract class MultiFieldQuery extends RelevanceQuery { + + public MultiFieldQuery(Map> queryBuildActions) { + super(queryBuildActions); + } + + @Override + public T createQueryBuilder(NamedArgumentExpression fields, NamedArgumentExpression queryExpr) { + var fieldsAndWeights = fields + .getValue() + .valueOf(null) + .tupleValue() + .entrySet() + .stream() + .collect(ImmutableMap.toImmutableMap(e -> e.getKey(), e -> e.getValue().floatValue())); + var query = queryExpr.getValue().valueOf(null).stringValue(); + return createBuilder(fieldsAndWeights, query); + } + + protected abstract T createBuilder(ImmutableMap fields, String query); +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiMatchQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiMatchQuery.java index 524d42f0b6..549f58cb19 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiMatchQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiMatchQuery.java @@ -6,18 +6,11 @@ package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; import com.google.common.collect.ImmutableMap; -import java.util.Iterator; -import java.util.Objects; import org.opensearch.index.query.MultiMatchQueryBuilder; import org.opensearch.index.query.Operator; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.sql.exception.SemanticCheckException; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.NamedArgumentExpression; -public class MultiMatchQuery extends RelevanceQuery { +public class MultiMatchQuery extends MultiFieldQuery { /** * Default constructor for MultiMatch configures how RelevanceQuery.build() handles * named arguments. @@ -46,43 +39,12 @@ public MultiMatchQuery() { } @Override - public QueryBuilder build(FunctionExpression func) { - if (func.getArguments().size() < 2) { - throw new SemanticCheckException("'multi_match' must have at least two arguments"); - } - Iterator iterator = func.getArguments().iterator(); - var fields = (NamedArgumentExpression) iterator.next(); - var query = (NamedArgumentExpression) iterator.next(); - // Fields is a map already, but we need to convert types. - var fieldsAndWeights = fields - .getValue() - .valueOf(null) - .tupleValue() - .entrySet() - .stream() - .collect(ImmutableMap.toImmutableMap(e -> e.getKey(), e -> e.getValue().floatValue())); - - MultiMatchQueryBuilder queryBuilder = createQueryBuilder(null, - query.getValue().valueOf(null).stringValue()) - .fields(fieldsAndWeights); - while (iterator.hasNext()) { - NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); - String argNormalized = arg.getArgName().toLowerCase(); - if (!queryBuildActions.containsKey(argNormalized)) { - throw new SemanticCheckException( - String.format("Parameter %s is invalid for %s function.", - argNormalized, queryBuilder.getWriteableName())); - } - (Objects.requireNonNull( - queryBuildActions - .get(argNormalized))) - .apply(queryBuilder, arg.getValue().valueOf(null)); - } - return queryBuilder; + protected MultiMatchQueryBuilder createBuilder(ImmutableMap fields, String query) { + return QueryBuilders.multiMatchQuery(query).fields(fields); } @Override - protected MultiMatchQueryBuilder createQueryBuilder(String field, String query) { - return QueryBuilders.multiMatchQuery(query); + protected String getQueryName() { + return MultiMatchQueryBuilder.NAME; } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/QueryStringQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/QueryStringQuery.java index 54ffea6158..21eb3f8837 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/QueryStringQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/QueryStringQuery.java @@ -23,7 +23,7 @@ /** * Class for Lucene query that builds the query_string query. */ -public class QueryStringQuery extends RelevanceQuery { +public class QueryStringQuery extends MultiFieldQuery { /** * Default constructor for QueryString configures how RelevanceQuery.build() handles * named arguments. @@ -66,55 +66,22 @@ public QueryStringQuery() { .build()); } - /** - * Override base build function for multi-field query support. - * @param func function : 'query_string' function - * @return : QueryBuilder for query_string query - */ - @Override - public QueryBuilder build(FunctionExpression func) { - Iterator iterator = func.getArguments().iterator(); - if (func.getArguments().size() < 2) { - throw new SemanticCheckException("'query_string' must have at least two arguments"); - } - NamedArgumentExpression fields = (NamedArgumentExpression) iterator.next(); - NamedArgumentExpression query = (NamedArgumentExpression) iterator.next(); - // Fields is a map already, but we need to convert types. - var fieldsAndWeights = fields - .getValue() - .valueOf(null) - .tupleValue() - .entrySet() - .stream() - .collect(ImmutableMap.toImmutableMap(e -> e.getKey(), e -> e.getValue().floatValue())); - - QueryStringQueryBuilder queryBuilder = createQueryBuilder(null, - query.getValue().valueOf(null).stringValue()) - .fields(fieldsAndWeights); - while (iterator.hasNext()) { - NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); - String argNormalized = arg.getArgName().toLowerCase(); - if (!queryBuildActions.containsKey(argNormalized)) { - throw new SemanticCheckException( - String.format("Parameter %s is invalid for %s function.", - argNormalized, queryBuilder.getWriteableName())); - } - (Objects.requireNonNull( - queryBuildActions - .get(argNormalized))) - .apply(queryBuilder, arg.getValue().valueOf(null)); - } - return queryBuilder; - } /** * Builds QueryBuilder with query value and other default parameter values set. - * @param field : Field value in query_string query + * + * @param fields : A map of field names and their boost values * @param query : Query value for query_string query * @return : Builder for query_string query */ @Override - protected QueryStringQueryBuilder createQueryBuilder(String field, String query) { - return QueryBuilders.queryStringQuery(query); + protected QueryStringQueryBuilder createBuilder(ImmutableMap fields, + String query) { + return QueryBuilders.queryStringQuery(query).fields(fields); + } + + @Override + protected String getQueryName() { + return QueryStringQueryBuilder.NAME; } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java index fb997646f4..282c5478b4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java @@ -5,11 +5,14 @@ package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.function.BiFunction; +import lombok.RequiredArgsConstructor; import org.opensearch.index.query.QueryBuilder; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprValue; @@ -22,31 +25,33 @@ /** * Base class for query abstraction that builds a relevance query from function expression. */ +@RequiredArgsConstructor public abstract class RelevanceQuery extends LuceneQuery { - protected Map> queryBuildActions; - - protected RelevanceQuery(Map> actionMap) { - queryBuildActions = actionMap; - } + private final Map> queryBuildActions; @Override public QueryBuilder build(FunctionExpression func) { List arguments = func.getArguments(); if (arguments.size() < 2) { - String queryName = createQueryBuilder("dummy_field", "").getWriteableName(); throw new SyntaxCheckException( - String.format("%s requires at least two parameters", queryName)); + String.format("%s requires at least two parameters", getQueryName())); } NamedArgumentExpression field = (NamedArgumentExpression) arguments.get(0); NamedArgumentExpression query = (NamedArgumentExpression) arguments.get(1); - T queryBuilder = createQueryBuilder( - field.getValue().valueOf(null).stringValue(), - query.getValue().valueOf(null).stringValue()); + T queryBuilder = createQueryBuilder(field, query); Iterator iterator = arguments.listIterator(2); + Set visitedParms = new HashSet(); while (iterator.hasNext()) { NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); String argNormalized = arg.getArgName().toLowerCase(); + if (visitedParms.contains(argNormalized)) { + throw new SemanticCheckException(String.format("Parameter '%s' can only be specified once.", + argNormalized)); + } else { + visitedParms.add(argNormalized); + } + if (!queryBuildActions.containsKey(argNormalized)) { throw new SemanticCheckException( String.format("Parameter %s is invalid for %s function.", @@ -60,16 +65,19 @@ public QueryBuilder build(FunctionExpression func) { return queryBuilder; } - protected abstract T createQueryBuilder(String field, String query); + protected abstract T createQueryBuilder(NamedArgumentExpression field, + NamedArgumentExpression query); + + protected abstract String getQueryName(); /** * Convenience interface for a function that updates a QueryBuilder * based on ExprValue. + * * @param Concrete query builder */ - public interface QueryBuilderStep extends + protected interface QueryBuilderStep extends BiFunction { - } public static String valueOfToUpper(ExprValue v) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SimpleQueryStringQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SimpleQueryStringQuery.java index 45637e98a6..1b7c18cb2c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SimpleQueryStringQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SimpleQueryStringQuery.java @@ -10,16 +10,11 @@ import java.util.Iterator; import java.util.Objects; import org.opensearch.index.query.Operator; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.SimpleQueryStringBuilder; import org.opensearch.index.query.SimpleQueryStringFlag; -import org.opensearch.sql.exception.SemanticCheckException; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.NamedArgumentExpression; -public class SimpleQueryStringQuery extends RelevanceQuery { +public class SimpleQueryStringQuery extends MultiFieldQuery { /** * Default constructor for SimpleQueryString configures how RelevanceQuery.build() handles * named arguments. @@ -48,43 +43,13 @@ public SimpleQueryStringQuery() { } @Override - public QueryBuilder build(FunctionExpression func) { - if (func.getArguments().size() < 2) { - throw new SemanticCheckException("'simple_query_string' must have at least two arguments"); - } - Iterator iterator = func.getArguments().iterator(); - var fields = (NamedArgumentExpression) iterator.next(); - var query = (NamedArgumentExpression) iterator.next(); - // Fields is a map already, but we need to convert types. - var fieldsAndWeights = fields - .getValue() - .valueOf(null) - .tupleValue() - .entrySet() - .stream() - .collect(ImmutableMap.toImmutableMap(e -> e.getKey(), e -> e.getValue().floatValue())); - - SimpleQueryStringBuilder queryBuilder = createQueryBuilder(null, - query.getValue().valueOf(null).stringValue()) - .fields(fieldsAndWeights); - while (iterator.hasNext()) { - NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); - String argNormalized = arg.getArgName().toLowerCase(); - if (!queryBuildActions.containsKey(argNormalized)) { - throw new SemanticCheckException( - String.format("Parameter %s is invalid for %s function.", - argNormalized, queryBuilder.getWriteableName())); - } - (Objects.requireNonNull( - queryBuildActions - .get(argNormalized))) - .apply(queryBuilder, arg.getValue().valueOf(null)); - } - return queryBuilder; + protected SimpleQueryStringBuilder createBuilder(ImmutableMap fields, + String query) { + return QueryBuilders.simpleQueryStringQuery(query).fields(fields); } @Override - protected SimpleQueryStringBuilder createQueryBuilder(String field, String query) { - return QueryBuilders.simpleQueryStringQuery(query); + protected String getQueryName() { + return SimpleQueryStringBuilder.NAME; } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQuery.java new file mode 100644 index 0000000000..9876c62cce --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQuery.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import java.util.Map; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.sql.expression.NamedArgumentExpression; + +/** + * Base class to represent builder class for relevance queries like match_query, match_bool_prefix, + * and match_phrase that search in a single field only. + * + * @param The builder class for the OpenSearch query class. + */ +abstract class SingleFieldQuery extends RelevanceQuery { + public SingleFieldQuery(Map> queryBuildActions) { + super(queryBuildActions); + } + + @Override + protected T createQueryBuilder(NamedArgumentExpression fields, NamedArgumentExpression query) { + return createBuilder( + fields.getValue().valueOf(null).stringValue(), + query.getValue().valueOf(null).stringValue()); + } + + protected abstract T createBuilder(String field, String query); +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java index 8fdb93427b..ad26d792ed 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java @@ -12,8 +12,9 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; @@ -22,12 +23,10 @@ import com.google.common.base.Charsets; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSortedMap; import com.google.common.io.Resources; import java.io.IOException; import java.net.URL; import java.util.Arrays; -import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -40,24 +39,21 @@ import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.admin.indices.get.GetIndexResponse; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; import org.opensearch.action.search.ClearScrollRequestBuilder; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.ClusterName; -import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.AliasMetadata; -import org.opensearch.cluster.metadata.IndexAbstraction; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.MappingMetadata; -import org.opensearch.cluster.metadata.Metadata; -import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.collect.ImmutableOpenMap; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.DeprecationHandler; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.sql.data.model.ExprIntegerValue; @@ -67,7 +63,6 @@ import org.opensearch.sql.opensearch.mapping.IndexMapping; import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; -import org.opensearch.threadpool.ThreadPool; @ExtendWith(MockitoExtension.class) class OpenSearchNodeClientTest { @@ -139,8 +134,8 @@ public void getIndexMappingsWithEmptyMapping() { @Test public void getIndexMappingsWithIOException() { String indexName = "test"; - ClusterService clusterService = mockClusterService(indexName, new IOException()); - OpenSearchNodeClient client = new OpenSearchNodeClient(clusterService, nodeClient); + when(nodeClient.admin().indices()).thenThrow(RuntimeException.class); + OpenSearchNodeClient client = new OpenSearchNodeClient(nodeClient); assertThrows(IllegalStateException.class, () -> client.getIndexMappings(indexName)); } @@ -148,18 +143,17 @@ public void getIndexMappingsWithIOException() { @Test public void getIndexMappingsWithNonExistIndex() { OpenSearchNodeClient client = - new OpenSearchNodeClient(mockClusterService("test"), nodeClient); - - assertThrows(IndexNotFoundException.class, () -> client.getIndexMappings("non_exist_index")); + new OpenSearchNodeClient(mockNodeClient("test")); + assertTrue(client.getIndexMappings("non_exist_index").isEmpty()); } @Test public void getIndexMaxResultWindows() throws IOException { URL url = Resources.getResource(TEST_MAPPING_SETTINGS_FILE); - String mappings = Resources.toString(url, Charsets.UTF_8); + String indexMetadata = Resources.toString(url, Charsets.UTF_8); String indexName = "accounts"; - ClusterService clusterService = mockClusterServiceForSettings(indexName, mappings); - OpenSearchNodeClient client = new OpenSearchNodeClient(clusterService, nodeClient); + OpenSearchNodeClient client = + new OpenSearchNodeClient(mockNodeClientSettings(indexName, indexMetadata)); Map indexMaxResultWindows = client.getIndexMaxResultWindows(indexName); assertEquals(1, indexMaxResultWindows.size()); @@ -171,10 +165,10 @@ public void getIndexMaxResultWindows() throws IOException { @Test public void getIndexMaxResultWindowsWithDefaultSettings() throws IOException { URL url = Resources.getResource(TEST_MAPPING_FILE); - String mappings = Resources.toString(url, Charsets.UTF_8); + String indexMetadata = Resources.toString(url, Charsets.UTF_8); String indexName = "accounts"; - ClusterService clusterService = mockClusterServiceForSettings(indexName, mappings); - OpenSearchNodeClient client = new OpenSearchNodeClient(clusterService, nodeClient); + OpenSearchNodeClient client = + new OpenSearchNodeClient(mockNodeClientSettings(indexName, indexMetadata)); Map indexMaxResultWindows = client.getIndexMaxResultWindows(indexName); assertEquals(1, indexMaxResultWindows.size()); @@ -183,6 +177,15 @@ public void getIndexMaxResultWindowsWithDefaultSettings() throws IOException { assertEquals(10000, indexMaxResultWindow); } + @Test + public void getIndexMaxResultWindowsWithIOException() { + String indexName = "test"; + when(nodeClient.admin().indices()).thenThrow(RuntimeException.class); + OpenSearchNodeClient client = new OpenSearchNodeClient(nodeClient); + + assertThrows(IllegalStateException.class, () -> client.getIndexMaxResultWindows(indexName)); + } + /** Jacoco enforce this constant lambda be tested. */ @Test public void testAllFieldsPredicate() { @@ -192,7 +195,7 @@ public void testAllFieldsPredicate() { @Test public void search() { OpenSearchNodeClient client = - new OpenSearchNodeClient(mock(ClusterService.class), nodeClient); + new OpenSearchNodeClient(nodeClient); // Mock first scroll request SearchResponse searchResponse = mock(SearchResponse.class); @@ -230,23 +233,12 @@ public void search() { @Test void schedule() { - ThreadPool threadPool = mock(ThreadPool.class); - when(nodeClient.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - - doAnswer( - invocation -> { - Runnable task = invocation.getArgument(0); - task.run(); - return null; - }) - .when(threadPool) - .schedule(any(), any(), any()); - - OpenSearchNodeClient client = - new OpenSearchNodeClient(mock(ClusterService.class), nodeClient); + OpenSearchNodeClient client = new OpenSearchNodeClient(nodeClient); AtomicBoolean isRun = new AtomicBoolean(false); - client.schedule(() -> isRun.set(true)); + client.schedule( + () -> { + isRun.set(true); + }); assertTrue(isRun.get()); } @@ -257,8 +249,7 @@ void cleanup() { when(requestBuilder.addScrollId(any())).thenReturn(requestBuilder); when(requestBuilder.get()).thenReturn(null); - OpenSearchNodeClient client = - new OpenSearchNodeClient(mock(ClusterService.class), nodeClient); + OpenSearchNodeClient client = new OpenSearchNodeClient(nodeClient); OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); request.setScrollId("scroll123"); client.cleanup(request); @@ -272,8 +263,7 @@ void cleanup() { @Test void cleanupWithoutScrollId() { - OpenSearchNodeClient client = - new OpenSearchNodeClient(mock(ClusterService.class), nodeClient); + OpenSearchNodeClient client = new OpenSearchNodeClient(nodeClient); OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); client.cleanup(request); @@ -294,122 +284,80 @@ void getIndices() { when(indexResponse.getIndices()).thenReturn(new String[] {"index"}); when(indexResponse.aliases()).thenReturn(openMap); - OpenSearchNodeClient client = - new OpenSearchNodeClient(mock(ClusterService.class), nodeClient); + OpenSearchNodeClient client = new OpenSearchNodeClient(nodeClient); final List indices = client.indices(); assertEquals(2, indices.size()); } @Test void meta() { - ClusterName clusterName = mock(ClusterName.class); - ClusterService mockService = mock(ClusterService.class); - when(clusterName.value()).thenReturn("cluster-name"); - when(mockService.getClusterName()).thenReturn(clusterName); + Settings settings = mock(Settings.class); + when(nodeClient.settings()).thenReturn(settings); + when(settings.get(anyString(), anyString())).thenReturn("cluster-name"); - OpenSearchNodeClient client = - new OpenSearchNodeClient(mockService, nodeClient); + OpenSearchNodeClient client = new OpenSearchNodeClient(nodeClient); final Map meta = client.meta(); assertEquals("cluster-name", meta.get(META_CLUSTER_NAME)); } @Test void ml() { - OpenSearchNodeClient client = new OpenSearchNodeClient(mock(ClusterService.class), nodeClient); + OpenSearchNodeClient client = new OpenSearchNodeClient(nodeClient); assertNotNull(client.getNodeClient()); } private OpenSearchNodeClient mockClient(String indexName, String mappings) { - ClusterService clusterService = mockClusterService(indexName, mappings); - return new OpenSearchNodeClient(clusterService, nodeClient); + mockNodeClientIndicesMappings(indexName, mappings); + return new OpenSearchNodeClient(nodeClient); } - /** Mock getAliasAndIndexLookup() only for index name resolve test. */ - public ClusterService mockClusterService(String indexName) { - ClusterService mockService = mock(ClusterService.class); - ClusterState mockState = mock(ClusterState.class); - Metadata mockMetaData = mock(Metadata.class); - - when(mockService.state()).thenReturn(mockState); - when(mockState.metadata()).thenReturn(mockMetaData); - when(mockMetaData.getIndicesLookup()) - .thenReturn(ImmutableSortedMap.of(indexName, mock(IndexAbstraction.class))); - return mockService; - } - - public ClusterService mockClusterService(String indexName, String mappings) { - ClusterService mockService = mock(ClusterService.class); - ClusterState mockState = mock(ClusterState.class); - Metadata mockMetaData = mock(Metadata.class); - - when(mockService.state()).thenReturn(mockState); - when(mockState.metadata()).thenReturn(mockMetaData); + public void mockNodeClientIndicesMappings(String indexName, String mappings) { + GetMappingsResponse mockResponse = mock(GetMappingsResponse.class); + MappingMetadata emptyMapping = mock(MappingMetadata.class); + when(nodeClient.admin().indices() + .prepareGetMappings(any()) + .setLocal(anyBoolean()) + .get()).thenReturn(mockResponse); try { - ImmutableOpenMap.Builder builder = - ImmutableOpenMap.builder(); - MappingMetadata metadata; + ImmutableOpenMap metadata; if (mappings.isEmpty()) { - metadata = MappingMetadata.EMPTY_MAPPINGS; + when(emptyMapping.getSourceAsMap()).thenReturn(ImmutableMap.of()); + metadata = + new ImmutableOpenMap.Builder() + .fPut(indexName, emptyMapping) + .build(); } else { - metadata = IndexMetadata.fromXContent(createParser(mappings)).mapping(); + metadata = new ImmutableOpenMap.Builder().fPut(indexName, + IndexMetadata.fromXContent(createParser(mappings)).mapping()).build(); } - - - builder.put(indexName, metadata); - when(mockMetaData.findMappings(any(), any())).thenReturn(builder.build()); - - // IndexNameExpressionResolver use this method to check if index exists. If not, - // IndexNotFoundException is thrown. - when(mockMetaData.getIndicesLookup()) - .thenReturn(ImmutableSortedMap.of(indexName, mock(IndexAbstraction.class))); + when(mockResponse.mappings()).thenReturn(metadata); } catch (IOException e) { - throw new IllegalStateException("Failed to mock cluster service", e); + throw new IllegalStateException("Failed to mock node client", e); } - return mockService; } - public ClusterService mockClusterService(String indexName, Throwable t) { - ClusterService mockService = mock(ClusterService.class); - ClusterState mockState = mock(ClusterState.class); - Metadata mockMetaData = mock(Metadata.class); - - when(mockService.state()).thenReturn(mockState); - when(mockState.metadata()).thenReturn(mockMetaData); - try { - when(mockMetaData.findMappings(any(), any())).thenThrow(t); - when(mockMetaData.getIndicesLookup()) - .thenReturn(ImmutableSortedMap.of(indexName, mock(IndexAbstraction.class))); - } catch (IOException e) { - throw new IllegalStateException("Failed to mock cluster service", e); - } - return mockService; + public NodeClient mockNodeClient(String indexName) { + GetMappingsResponse mockResponse = mock(GetMappingsResponse.class); + when(nodeClient.admin().indices() + .prepareGetMappings(any()) + .setLocal(anyBoolean()) + .get()).thenReturn(mockResponse); + when(mockResponse.mappings()).thenReturn(ImmutableOpenMap.of()); + return nodeClient; } - public ClusterService mockClusterServiceForSettings(String indexName, String mappings) { - ClusterService mockService = mock(ClusterService.class); - ClusterState mockState = mock(ClusterState.class); - Metadata mockMetaData = mock(Metadata.class); - - when(mockService.state()).thenReturn(mockState); - when(mockState.metadata()).thenReturn(mockMetaData); - try { - ImmutableOpenMap.Builder indexBuilder = - ImmutableOpenMap.builder(); - IndexMetadata indexMetadata = IndexMetadata.fromXContent(createParser(mappings)); - - indexBuilder.put(indexName, indexMetadata); - when(mockMetaData.getIndices()).thenReturn(indexBuilder.build()); - - // IndexNameExpressionResolver use this method to check if index exists. If not, - // IndexNotFoundException is thrown. - IndexAbstraction indexAbstraction = mock(IndexAbstraction.class); - when(indexAbstraction.getIndices()).thenReturn(Collections.singletonList(indexMetadata)); - when(mockMetaData.getIndicesLookup()) - .thenReturn(ImmutableSortedMap.of(indexName, indexAbstraction)); - } catch (IOException e) { - throw new IllegalStateException("Failed to mock cluster service", e); - } - return mockService; + private NodeClient mockNodeClientSettings(String indexName, String indexMetadata) + throws IOException { + GetSettingsResponse mockResponse = mock(GetSettingsResponse.class); + when(nodeClient.admin().indices().prepareGetSettings(any()).setLocal(anyBoolean()).get()) + .thenReturn(mockResponse); + ImmutableOpenMap metadata = + new ImmutableOpenMap.Builder() + .fPut(indexName, IndexMetadata.fromXContent(createParser(indexMetadata)).getSettings()) + .build(); + + when(mockResponse.getIndexToSettings()).thenReturn(metadata); + return nodeClient; } private XContentParser createParser(String mappings) throws IOException { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java index 8085a2c0d4..9ad37c6ef3 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java @@ -28,18 +28,25 @@ import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.opensearch.utils.Utils; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; +import org.opensearch.sql.storage.Table; - +@ExtendWith(MockitoExtension.class) class OpenSearchLogicOptimizerTest { private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + @Mock + private Table table; + /** * SELECT intV as i FROM schema WHERE intV = 1. */ @@ -55,7 +62,7 @@ void project_filter_merge_with_relation() { optimize( project( filter( - relation("schema"), + relation("schema", table), dsl.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) ), DSL.named("i", DSL.ref("intV", INTEGER))) @@ -79,7 +86,7 @@ void aggregation_merge_relation() { optimize( project( aggregation( - relation("schema"), + relation("schema", table), ImmutableList .of(DSL.named("AVG(intV)", dsl.avg(DSL.ref("intV", INTEGER)))), @@ -109,7 +116,7 @@ void aggregation_merge_filter_relation() { project( aggregation( filter( - relation("schema"), + relation("schema", table), dsl.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) ), ImmutableList @@ -160,7 +167,7 @@ void sort_merge_with_relation() { indexScan("schema", Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER))), optimize( sort( - relation("schema"), + relation("schema", table), Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER)) ) ) @@ -198,7 +205,7 @@ void sort_filter_merge_with_relation() { optimize( sort( filter( - relation("schema"), + relation("schema", table), dsl.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) ), Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) @@ -211,12 +218,12 @@ void sort_filter_merge_with_relation() { void sort_with_expression_cannot_merge_with_relation() { assertEquals( sort( - relation("schema"), + relation("schema", table), Pair.of(Sort.SortOption.DEFAULT_ASC, dsl.abs(DSL.ref("intV", INTEGER))) ), optimize( sort( - relation("schema"), + relation("schema", table), Pair.of(Sort.SortOption.DEFAULT_ASC, dsl.abs(DSL.ref("intV", INTEGER))) ) ) @@ -240,7 +247,7 @@ void sort_merge_indexagg() { project( sort( aggregation( - relation("schema"), + relation("schema", table), ImmutableList .of(DSL.named("AVG(intV)", dsl.avg(DSL.ref("intV", INTEGER)))), ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), @@ -268,7 +275,7 @@ void sort_merge_indexagg_nulls_last() { project( sort( aggregation( - relation("schema"), + relation("schema", table), ImmutableList .of(DSL.named("AVG(intV)", dsl.avg(DSL.ref("intV", INTEGER)))), ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), @@ -339,7 +346,7 @@ void limit_merge_with_relation() { optimize( project( limit( - relation("schema"), + relation("schema", table), 1, 1 ), DSL.named("intV", DSL.ref("intV", INTEGER)) @@ -363,7 +370,7 @@ void limit_merge_with_index_scan() { project( limit( filter( - relation("schema"), + relation("schema", table), dsl.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) ), 1, 1 ), @@ -389,7 +396,7 @@ void limit_merge_with_index_scan_sort() { limit( sort( filter( - relation("schema"), + relation("schema", table), dsl.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) ), Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) @@ -434,7 +441,7 @@ void push_down_projectList_to_relation() { ), optimize( project( - relation("schema"), + relation("schema", table), DSL.named("i", DSL.ref("intV", INTEGER))) ) ); @@ -455,7 +462,7 @@ void push_down_should_handle_duplication() { ), optimize( project( - relation("schema"), + relation("schema", table), DSL.named("i", DSL.ref("intV", INTEGER)), DSL.named("absi", dsl.abs(DSL.ref("intV", INTEGER)))) ) @@ -483,7 +490,7 @@ void only_one_project_should_be_push() { optimize( project( project( - relation("schema"), + relation("schema", table), DSL.named("i", DSL.ref("intV", INTEGER)), DSL.named("s", DSL.ref("stringV", STRING)) ), @@ -497,12 +504,12 @@ void only_one_project_should_be_push() { void project_literal_no_push() { assertEquals( project( - relation("schema"), + relation("schema", table), DSL.named("i", DSL.literal("str")) ), optimize( project( - relation("schema"), + relation("schema", table), DSL.named("i", DSL.literal("str")) ) ) @@ -524,7 +531,7 @@ void filter_aggregation_merge_relation() { optimize( project( aggregation( - relation("schema"), + relation("schema", table), ImmutableList.of(DSL.named("AVG(intV)", dsl.avg(DSL.ref("intV", INTEGER)) .condition(dsl.greater(DSL.ref("intV", INTEGER), DSL.literal(1))))), @@ -552,7 +559,7 @@ void filter_aggregation_merge_filter_relation() { project( aggregation( filter( - relation("schema"), + relation("schema", table), dsl.less(DSL.ref("longV", LONG), DSL.literal(1)) ), ImmutableList.of(DSL.named("avg(intV)", diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java index b85d60c1fb..64b87aa2c5 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java @@ -25,6 +25,7 @@ import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.storage.Table; @ExtendWith(MockitoExtension.class) public class OpenSearchDefaultImplementorTest { @@ -34,6 +35,9 @@ public class OpenSearchDefaultImplementorTest { @Mock OpenSearchClient client; + @Mock + Table table; + /** * For test coverage. */ @@ -43,8 +47,9 @@ public void visitInvalidTypeShouldThrowException() { new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); final IllegalStateException exception = - assertThrows(IllegalStateException.class, () -> implementor.visitNode(relation("index"), - indexScan)); + assertThrows(IllegalStateException.class, + () -> implementor.visitNode(relation("index", table), + indexScan)); ; assertEquals( "unexpected plan node type " @@ -55,20 +60,20 @@ public void visitInvalidTypeShouldThrowException() { @Test public void visitMachineLearning() { LogicalMLCommons node = Mockito.mock(LogicalMLCommons.class, - Answers.RETURNS_DEEP_STUBS); + Answers.RETURNS_DEEP_STUBS); Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); + new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); assertNotNull(implementor.visitMLCommons(node, indexScan)); } @Test public void visitAD() { LogicalAD node = Mockito.mock(LogicalAD.class, - Answers.RETURNS_DEEP_STUBS); + Answers.RETURNS_DEEP_STUBS); Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); + new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); assertNotNull(implementor.visitAD(node, indexScan)); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java index f1754a455d..82ac3991ac 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java @@ -70,6 +70,7 @@ import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; import org.opensearch.sql.planner.physical.ProjectOperator; +import org.opensearch.sql.storage.Table; @ExtendWith(MockitoExtension.class) class OpenSearchIndexTest { @@ -85,6 +86,9 @@ class OpenSearchIndexTest { @Mock private Settings settings; + @Mock + private Table table; + @Test void getFieldTypes() { when(client.getIndexMappings("test")) @@ -136,7 +140,7 @@ void implementRelationOperatorOnly() { when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); String indexName = "test"; - LogicalPlan plan = relation(indexName); + LogicalPlan plan = relation(indexName, table); OpenSearchIndex index = new OpenSearchIndex(client, settings, indexName); Integer maxResultWindow = index.getMaxResultWindow(); assertEquals( @@ -150,7 +154,7 @@ void implementRelationOperatorWithOptimization() { when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); String indexName = "test"; - LogicalPlan plan = relation(indexName); + LogicalPlan plan = relation(indexName, table); OpenSearchIndex index = new OpenSearchIndex(client, settings, indexName); Integer maxResultWindow = index.getMaxResultWindow(); assertEquals( @@ -187,7 +191,7 @@ void implementOtherLogicalOperators() { eval( remove( rename( - relation(indexName), + relation(indexName, table), mappings), exclude), newEvalField), @@ -255,7 +259,7 @@ void shouldNotPushDownFilterFarFromRelation() { PhysicalPlan plan = index.implement( filter( aggregation( - relation(indexName), + relation(indexName, table), aggregators, groupByExprs ), @@ -319,7 +323,7 @@ void shouldNotPushDownAggregationFarFromRelation() { PhysicalPlan plan = index.implement( aggregation( filter(filter( - relation(indexName), + relation(indexName, table), filterExpr), filterExpr), aggregators, groupByExprs)); @@ -407,7 +411,7 @@ void shouldNotPushDownLimitFarFromRelationButUpdateScanSize() { project( limit( sort( - relation("test"), + relation("test", table), Pair.of(Sort.SortOption.DEFAULT_ASC, dsl.abs(named("intV", ref("intV", INTEGER)))) ), diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java index b1efe86d01..75ddd1dd93 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java @@ -855,41 +855,6 @@ void match_phrase_invalid_value_ztq() { msg); } - @Test - void match_phrase_missing_field() { - var msg = assertThrows(ExpressionEvaluationException.class, () -> - dsl.match_phrase( - dsl.namedArgument("query", literal("search query")))).getMessage(); - assertEquals("match_phrase function expected {[STRING,STRING],[STRING,STRING,STRING]," - + "[STRING,STRING,STRING,STRING],[STRING,STRING,STRING,STRING,STRING]}, but get [STRING]", - msg); - } - - @Test - void match_phrase_missing_query() { - var msg = assertThrows(ExpressionEvaluationException.class, () -> - dsl.match_phrase( - dsl.namedArgument("field", literal("message")))).getMessage(); - assertEquals("match_phrase function expected {[STRING,STRING],[STRING,STRING,STRING]," - + "[STRING,STRING,STRING,STRING],[STRING,STRING,STRING,STRING,STRING]}, but get [STRING]", - msg); - } - - @Test - void match_phrase_too_many_args() { - var msg = assertThrows(ExpressionEvaluationException.class, () -> - dsl.match_phrase( - dsl.namedArgument("one", literal("1")), - dsl.namedArgument("two", literal("2")), - dsl.namedArgument("three", literal("3")), - dsl.namedArgument("four", literal("4")), - dsl.namedArgument("fix", literal("5")), - dsl.namedArgument("six", literal("6")) - )).getMessage(); - assertEquals("match_phrase function expected {[STRING,STRING],[STRING,STRING,STRING]," - + "[STRING,STRING,STRING,STRING],[STRING,STRING,STRING,STRING,STRING]}, but get " - + "[STRING,STRING,STRING,STRING,STRING,STRING]", msg); - } @Test @@ -913,55 +878,6 @@ void should_build_match_bool_prefix_query_with_default_parameters() { dsl.namedArgument("query", literal("search query"))))); } - @Test - void multi_match_missing_fields() { - var msg = assertThrows(ExpressionEvaluationException.class, () -> - dsl.multi_match( - dsl.namedArgument("query", literal("search query")))).getMessage(); - assertEquals("multi_match function expected {[STRUCT,STRING],[STRUCT,STRING,STRING]," - + "[STRUCT,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING],[STRUCT,STRING," - + "STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING]," - + "[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING]," - + "[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING]}, but get [STRING]", - msg); - } - - @Test - void multi_match_missing_query() { - var msg = assertThrows(ExpressionEvaluationException.class, () -> - dsl.multi_match( - dsl.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field1", ExprValueUtils.floatValue(1.F), - "field2", ExprValueUtils.floatValue(.3F)))))))).getMessage(); - assertEquals("multi_match function expected {[STRUCT,STRING],[STRUCT,STRING,STRING]," - + "[STRUCT,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING],[STRUCT,STRING," - + "STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING]," - + "[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING]," - + "[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING]}, but get [STRUCT]", - msg); - } - @Test void should_build_match_phrase_prefix_query_with_default_parameters() { assertJsonEquals( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java index 00cf3158c4..c30e06bc1a 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java @@ -61,8 +61,8 @@ public void test_valid_arguments(List validArgs) { @Test public void test_valid_when_two_arguments() { List arguments = List.of( - namedArgument("field", "field_value"), - namedArgument("query", "query_value")); + dsl.namedArgument("field", "field_value"), + dsl.namedArgument("query", "query_value")); Assertions.assertNotNull(matchBoolPrefixQuery.build(new MatchExpression(arguments))); } @@ -75,7 +75,7 @@ public void test_SyntaxCheckException_when_no_arguments() { @Test public void test_SyntaxCheckException_when_one_argument() { - List arguments = List.of(namedArgument("field", "field_value")); + List arguments = List.of(dsl.namedArgument("field", "field_value")); assertThrows(SyntaxCheckException.class, () -> matchBoolPrefixQuery.build(new MatchExpression(arguments))); } @@ -83,17 +83,13 @@ public void test_SyntaxCheckException_when_one_argument() { @Test public void test_SemanticCheckException_when_invalid_argument() { List arguments = List.of( - namedArgument("field", "field_value"), - namedArgument("query", "query_value"), - namedArgument("unsupported", "unsupported_value")); + dsl.namedArgument("field", "field_value"), + dsl.namedArgument("query", "query_value"), + dsl.namedArgument("unsupported", "unsupported_value")); Assertions.assertThrows(SemanticCheckException.class, () -> matchBoolPrefixQuery.build(new MatchExpression(arguments))); } - private NamedArgumentExpression namedArgument(String name, String value) { - return dsl.namedArgument(name, DSL.literal(value)); - } - private class MatchExpression extends FunctionExpression { public MatchExpression(List arguments) { super(MatchBoolPrefixQueryTest.this.matchBoolPrefix, arguments); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java index 4e8895a12a..09e25fe569 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java @@ -20,7 +20,6 @@ import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.NamedArgumentExpression; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.FunctionName; @@ -33,10 +32,6 @@ public class MatchPhraseQueryTest { private final MatchPhraseQuery matchPhraseQuery = new MatchPhraseQuery(); private final FunctionName matchPhrase = FunctionName.of("match_phrase"); - private NamedArgumentExpression namedArgument(String name, String value) { - return dsl.namedArgument(name, DSL.literal(value)); - } - @Test public void test_SyntaxCheckException_when_no_arguments() { List arguments = List.of(); @@ -46,7 +41,7 @@ public void test_SyntaxCheckException_when_no_arguments() { @Test public void test_SyntaxCheckException_when_one_argument() { - List arguments = List.of(namedArgument("field", "test")); + List arguments = List.of(dsl.namedArgument("field", "test")); assertThrows(SyntaxCheckException.class, () -> matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -54,9 +49,9 @@ public void test_SyntaxCheckException_when_one_argument() { @Test public void test_SyntaxCheckException_when_invalid_parameter() { List arguments = List.of( - namedArgument("field", "test"), - namedArgument("query", "test2"), - namedArgument("unsupported", "3")); + dsl.namedArgument("field", "test"), + dsl.namedArgument("query", "test2"), + dsl.namedArgument("unsupported", "3")); Assertions.assertThrows(SemanticCheckException.class, () -> matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -64,9 +59,9 @@ public void test_SyntaxCheckException_when_invalid_parameter() { @Test public void test_analyzer_parameter() { List arguments = List.of( - namedArgument("field", "t1"), - namedArgument("query", "t2"), - namedArgument("analyzer", "standard") + dsl.namedArgument("field", "t1"), + dsl.namedArgument("query", "t2"), + dsl.namedArgument("analyzer", "standard") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -74,17 +69,17 @@ public void test_analyzer_parameter() { @Test public void build_succeeds_with_two_arguments() { List arguments = List.of( - namedArgument("field", "test"), - namedArgument("query", "test2")); + dsl.namedArgument("field", "test"), + dsl.namedArgument("query", "test2")); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @Test public void test_slop_parameter() { List arguments = List.of( - namedArgument("field", "t1"), - namedArgument("query", "t2"), - namedArgument("slop", "2") + dsl.namedArgument("field", "t1"), + dsl.namedArgument("query", "t2"), + dsl.namedArgument("slop", "2") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -92,9 +87,9 @@ public void test_slop_parameter() { @Test public void test_zero_terms_query_parameter() { List arguments = List.of( - namedArgument("field", "t1"), - namedArgument("query", "t2"), - namedArgument("zero_terms_query", "ALL") + dsl.namedArgument("field", "t1"), + dsl.namedArgument("query", "t2"), + dsl.namedArgument("zero_terms_query", "ALL") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -102,9 +97,9 @@ public void test_zero_terms_query_parameter() { @Test public void test_zero_terms_query_parameter_lower_case() { List arguments = List.of( - namedArgument("field", "t1"), - namedArgument("query", "t2"), - namedArgument("zero_terms_query", "all") + dsl.namedArgument("field", "t1"), + dsl.namedArgument("query", "t2"), + dsl.namedArgument("zero_terms_query", "all") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java index 4a6e1d2ed9..261870ca17 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -137,16 +138,16 @@ public void test_valid_parameters(List validArgs) { } @Test - public void test_SemanticCheckException_when_no_arguments() { + public void test_SyntaxCheckException_when_no_arguments() { List arguments = List.of(); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> multiMatchQuery.build(new MultiMatchExpression(arguments))); } @Test - public void test_SemanticCheckException_when_one_argument() { + public void test_SyntaxCheckException_when_one_argument() { List arguments = List.of(namedArgument("fields", fields_value)); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> multiMatchQuery.build(new MultiMatchExpression(arguments))); } @@ -155,15 +156,11 @@ public void test_SemanticCheckException_when_invalid_parameter() { List arguments = List.of( namedArgument("fields", fields_value), namedArgument("query", query_value), - namedArgument("unsupported", "unsupported_value")); + dsl.namedArgument("unsupported", "unsupported_value")); Assertions.assertThrows(SemanticCheckException.class, () -> multiMatchQuery.build(new MultiMatchExpression(arguments))); } - private NamedArgumentExpression namedArgument(String name, String value) { - return dsl.namedArgument(name, DSL.literal(value)); - } - private NamedArgumentExpression namedArgument(String name, LiteralExpression value) { return dsl.namedArgument(name, value); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java index fce835bf43..21b03abab0 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java @@ -17,6 +17,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -88,16 +89,16 @@ public void test_valid_parameters(List validArgs) { } @Test - public void test_SemanticCheckException_when_no_arguments() { + public void test_SyntaxCheckException_when_no_arguments() { List arguments = List.of(); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> queryStringQuery.build(new QueryStringExpression(arguments))); } @Test - public void test_SemanticCheckException_when_one_argument() { + public void test_SyntaxCheckException_when_one_argument() { List arguments = List.of(namedArgument("fields", fields_value)); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> queryStringQuery.build(new QueryStringExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/SimpleQueryStringTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/SimpleQueryStringTest.java index 048f6e1cb9..8f06f48727 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/SimpleQueryStringTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/SimpleQueryStringTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -161,16 +162,16 @@ public void test_valid_parameters(List validArgs) { } @Test - public void test_SemanticCheckException_when_no_arguments() { + public void test_SyntaxCheckException_when_no_arguments() { List arguments = List.of(); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> simpleQueryStringQuery.build(new SimpleQueryStringExpression(arguments))); } @Test - public void test_SemanticCheckException_when_one_argument() { + public void test_SyntaxCheckException_when_one_argument() { List arguments = List.of(namedArgument("fields", fields_value)); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> simpleQueryStringQuery.build(new SimpleQueryStringExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java new file mode 100644 index 0000000000..7e4c6ea011 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatcher; +import org.mockito.Mockito; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.expression.config.ExpressionConfig; + +class MultiFieldQueryTest { + MultiFieldQuery query; + private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + private final String testQueryName = "test_query"; + private final Map actionMap + = ImmutableMap.of("paramA", (o, v) -> o); + + @BeforeEach + public void setUp() { + query = mock(MultiFieldQuery.class, + Mockito.withSettings().useConstructor(actionMap) + .defaultAnswer(Mockito.CALLS_REAL_METHODS)); + when(query.getQueryName()).thenReturn(testQueryName); + } + + @Test + void createQueryBuilderTest() { + String sampleQuery = "sample query"; + String sampleField = "fieldA"; + float sampleValue = 34f; + + var fieldSpec = ImmutableMap.builder().put(sampleField, + ExprValueUtils.floatValue(sampleValue)).build(); + + query.createQueryBuilder(dsl.namedArgument("fields", + new LiteralExpression(ExprTupleValue.fromExprValueMap(fieldSpec))), + dsl.namedArgument("query", + new LiteralExpression(ExprValueUtils.stringValue(sampleQuery)))); + + verify(query).createBuilder(argThat( + (ArgumentMatcher>) map -> map.size() == 1 + && map.containsKey(sampleField) && map.containsValue(sampleValue)), + eq(sampleQuery)); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java index a67f0f34a7..fa6a43474a 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java @@ -30,7 +30,6 @@ import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprType; -import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; @@ -55,14 +54,20 @@ public void setUp() { .defaultAnswer(Mockito.CALLS_REAL_METHODS)); queryBuilder = mock(QueryBuilder.class); when(query.createQueryBuilder(any(), any())).thenReturn(queryBuilder); - when(queryBuilder.queryName()).thenReturn("mocked_query"); - when(queryBuilder.getWriteableName()).thenReturn("mock_query"); + String queryName = "mock_query"; + when(queryBuilder.queryName()).thenReturn(queryName); + when(queryBuilder.getWriteableName()).thenReturn(queryName); + when(query.getQueryName()).thenReturn(queryName); } @Test - void first_arg_field_second_arg_query_test() { - query.build(createCall(List.of(FIELD_ARG, QUERY_ARG))); - verify(query, times(1)).createQueryBuilder("field_A", "find me"); + void throws_SemanticCheckException_when_same_argument_twice() { + FunctionExpression expr = createCall(List.of(FIELD_ARG, QUERY_ARG, + namedArgument("boost", "2.3"), + namedArgument("boost", "2.4"))); + SemanticCheckException exception = + assertThrows(SemanticCheckException.class, () -> query.build(expr)); + assertEquals("Parameter 'boost' can only be specified once.", exception.getMessage()); } @Test @@ -72,7 +77,8 @@ void throws_SemanticCheckException_when_wrong_argument_name() { SemanticCheckException exception = assertThrows(SemanticCheckException.class, () -> query.build(expr)); - assertEquals("Parameter wrongarg is invalid for mock_query function.", exception.getMessage()); + assertEquals("Parameter wrongarg is invalid for mock_query function.", + exception.getMessage()); } @Test diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java new file mode 100644 index 0000000000..5d35327116 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.expression.config.ExpressionConfig; + +class SingleFieldQueryTest { + SingleFieldQuery query; + private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + private final String testQueryName = "test_query"; + private final Map actionMap + = ImmutableMap.of("paramA", (o, v) -> o); + + @BeforeEach + void setUp() { + query = mock(SingleFieldQuery.class, + Mockito.withSettings().useConstructor(actionMap) + .defaultAnswer(Mockito.CALLS_REAL_METHODS)); + when(query.getQueryName()).thenReturn(testQueryName); + } + + @Test + void createQueryBuilderTest() { + String sampleQuery = "sample query"; + String sampleField = "fieldA"; + + query.createQueryBuilder(dsl.namedArgument("field", + new LiteralExpression(ExprValueUtils.stringValue(sampleField))), + dsl.namedArgument("query", + new LiteralExpression(ExprValueUtils.stringValue(sampleQuery)))); + + verify(query).createBuilder(eq(sampleField), + eq(sampleQuery)); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexTest.java index 685d3e33af..e2efff22cb 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexTest.java @@ -28,6 +28,7 @@ import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.ProjectOperator; +import org.opensearch.sql.storage.Table; @ExtendWith(MockitoExtension.class) class OpenSearchSystemIndexTest { @@ -35,6 +36,9 @@ class OpenSearchSystemIndexTest { @Mock private OpenSearchClient client; + @Mock + private Table table; + @Test void testGetFieldTypesOfMetaTable() { OpenSearchSystemIndex systemIndex = new OpenSearchSystemIndex(client, TABLE_INFO); @@ -61,7 +65,7 @@ void implement() { final PhysicalPlan plan = systemIndex.implement( project( - relation(TABLE_INFO), + relation(TABLE_INFO, table), projectExpr )); assertTrue(plan instanceof ProjectOperator); diff --git a/plugin/build.gradle b/plugin/build.gradle index 5c3b3974ef..c1aae613bd 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -88,6 +88,8 @@ configurations.all { resolutionStrategy.force 'com.google.guava:guava:31.0.1-jre' resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${jackson_version}" resolutionStrategy.force "com.fasterxml.jackson.core:jackson-databind:${jackson_version}" + resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib:1.6.0" + resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib-common:1.6.0" } compileJava { options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) @@ -99,6 +101,10 @@ compileTestJava { dependencies { api group: 'org.springframework', name: 'spring-beans', version: "${spring_version}" + api "com.fasterxml.jackson.core:jackson-core:${jackson_version}" + api "com.fasterxml.jackson.core:jackson-databind:${jackson_version}" + api "com.fasterxml.jackson.core:jackson-annotations:${jackson_version}" + api project(":ppl") api project(':legacy') api project(':opensearch') diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index a4a03fde11..200364580b 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -16,6 +16,7 @@ import org.opensearch.action.ActionResponse; import org.opensearch.action.ActionType; import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; @@ -31,6 +32,7 @@ import org.opensearch.env.NodeEnvironment; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; +import org.opensearch.plugins.ReloadablePlugin; import org.opensearch.plugins.ScriptPlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.rest.RestController; @@ -43,28 +45,37 @@ import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.legacy.plugin.RestSqlAction; import org.opensearch.sql.legacy.plugin.RestSqlStatsAction; +import org.opensearch.sql.opensearch.client.OpenSearchNodeClient; import org.opensearch.sql.opensearch.setting.LegacyOpenDistroSettings; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; import org.opensearch.sql.opensearch.storage.script.ExpressionScriptEngine; import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; +import org.opensearch.sql.plugin.catalog.CatalogServiceImpl; +import org.opensearch.sql.plugin.catalog.CatalogSettings; import org.opensearch.sql.plugin.rest.RestPPLQueryAction; import org.opensearch.sql.plugin.rest.RestPPLStatsAction; import org.opensearch.sql.plugin.rest.RestQuerySettingsAction; import org.opensearch.sql.plugin.transport.PPLQueryAction; import org.opensearch.sql.plugin.transport.TransportPPLQueryAction; import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; +import org.opensearch.sql.storage.StorageEngine; import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.FixedExecutorBuilder; import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; -public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin { +public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin, ReloadablePlugin { private ClusterService clusterService; - /** Settings should be inited when bootstrap the plugin. */ + /** + * Settings should be inited when bootstrap the plugin. + */ private org.opensearch.sql.common.setting.Settings pluginSettings; + private NodeClient client; + public String name() { return "sql"; } @@ -90,13 +101,16 @@ public List getRestHandlers( return Arrays.asList( new RestPPLQueryAction(pluginSettings, settings), - new RestSqlAction(settings, clusterService, pluginSettings), + new RestSqlAction(settings, clusterService, pluginSettings, + CatalogServiceImpl.getInstance()), new RestSqlStatsAction(settings, restController), new RestPPLStatsAction(settings, restController), new RestQuerySettingsAction(settings, restController)); } - /** Register action and handler so that transportClient can find proxy for action. */ + /** + * Register action and handler so that transportClient can find proxy for action. + */ @Override public List> getActions() { return Arrays.asList( @@ -120,7 +134,9 @@ public Collection createComponents( Supplier repositoriesServiceSupplier) { this.clusterService = clusterService; this.pluginSettings = new OpenSearchSettings(clusterService.getClusterSettings()); - + this.client = (NodeClient) client; + CatalogServiceImpl.getInstance().loadConnectors(clusterService.getSettings()); + CatalogServiceImpl.getInstance().registerOpenSearchStorageEngine(openSearchStorageEngine()); LocalClusterState.state().setClusterService(clusterService); LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); @@ -154,6 +170,7 @@ public List> getSettings() { return new ImmutableList.Builder>() .addAll(LegacyOpenDistroSettings.legacySettings()) .addAll(OpenSearchSettings.pluginSettings()) + .add(CatalogSettings.CATALOG_CONFIG) .build(); } @@ -161,4 +178,16 @@ public List> getSettings() { public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { return new ExpressionScriptEngine(new DefaultExpressionSerializer()); } + + @Override + public void reload(Settings settings) { + CatalogServiceImpl.getInstance().loadConnectors(clusterService.getSettings()); + CatalogServiceImpl.getInstance().registerOpenSearchStorageEngine(openSearchStorageEngine()); + } + + private StorageEngine openSearchStorageEngine() { + return new OpenSearchStorageEngine(new OpenSearchNodeClient(client), + pluginSettings); + } + } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/catalog/CatalogServiceImpl.java b/plugin/src/main/java/org/opensearch/sql/plugin/catalog/CatalogServiceImpl.java new file mode 100644 index 0000000000..5a77961d8b --- /dev/null +++ b/plugin/src/main/java/org/opensearch/sql/plugin/catalog/CatalogServiceImpl.java @@ -0,0 +1,168 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.plugin.catalog; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.io.InputStream; +import java.net.URISyntaxException; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.commons.lang3.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.settings.Settings; +import org.opensearch.sql.catalog.CatalogService; +import org.opensearch.sql.catalog.model.CatalogMetadata; +import org.opensearch.sql.catalog.model.ConnectorType; +import org.opensearch.sql.opensearch.security.SecurityAccess; +import org.opensearch.sql.storage.StorageEngine; + +/** + * This class manages catalogs and responsible for creating connectors to these catalogs. + */ +public class CatalogServiceImpl implements CatalogService { + + private static final CatalogServiceImpl INSTANCE = new CatalogServiceImpl(); + + private static final Logger LOG = LogManager.getLogger(); + + public static final String OPEN_SEARCH = "opensearch"; + + private Map storageEngineMap = new HashMap<>(); + + public static CatalogServiceImpl getInstance() { + return INSTANCE; + } + + private CatalogServiceImpl() { + } + + /** + * This function reads settings and loads connectors to the data stores. + * This will be invoked during start up and also when settings are updated. + * + * @param settings settings. + */ + public void loadConnectors(Settings settings) { + doPrivileged(() -> { + InputStream inputStream = CatalogSettings.CATALOG_CONFIG.get(settings); + if (inputStream != null) { + ObjectMapper objectMapper = new ObjectMapper(); + objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + try { + List catalogs = + objectMapper.readValue(inputStream, new TypeReference<>() { + }); + LOG.info(catalogs.toString()); + validateCatalogs(catalogs); + constructConnectors(catalogs); + } catch (IOException e) { + LOG.error("Catalog Configuration File uploaded is malformed. Verify and re-upload."); + throw new IllegalArgumentException( + "Malformed Catalog Configuration Json" + e.getMessage()); + } + } + return null; + }); + } + + @Override + public StorageEngine getStorageEngine(String catalog) { + if (catalog == null || !storageEngineMap.containsKey(catalog)) { + return storageEngineMap.get(OPEN_SEARCH); + } + return storageEngineMap.get(catalog); + } + + @Override + public Set getCatalogs() { + Set catalogs = storageEngineMap.keySet(); + catalogs.remove(OPEN_SEARCH); + return catalogs; + } + + @Override + public void registerOpenSearchStorageEngine(StorageEngine storageEngine) { + storageEngineMap.put(OPEN_SEARCH, storageEngine); + } + + private T doPrivileged(PrivilegedExceptionAction action) { + try { + return SecurityAccess.doPrivileged(action); + } catch (IOException e) { + throw new IllegalStateException("Failed to perform privileged action", e); + } + } + + private StorageEngine createStorageEngine(CatalogMetadata catalog) throws URISyntaxException { + StorageEngine storageEngine; + ConnectorType connector = catalog.getConnector(); + switch (connector) { + case PROMETHEUS: + storageEngine = null; + break; + default: + LOG.info( + "Unknown connector \"{}\". " + + "Please re-upload catalog configuration with a supported connector.", + connector); + throw new IllegalStateException( + "Unknown connector. Connector doesn't exist in the list of supported."); + } + return storageEngine; + } + + private void constructConnectors(List catalogs) throws URISyntaxException { + storageEngineMap = new HashMap<>(); + for (CatalogMetadata catalog : catalogs) { + String catalogName = catalog.getName(); + StorageEngine storageEngine = createStorageEngine(catalog); + storageEngineMap.put(catalogName, storageEngine); + } + } + + /** + * This can be moved to a different validator class + * when we introduce more connectors. + * + * @param catalogs catalogs. + */ + private void validateCatalogs(List catalogs) { + + Set reviewedCatalogs = new HashSet<>(); + for (CatalogMetadata catalog : catalogs) { + + if (StringUtils.isEmpty(catalog.getName())) { + LOG.error("Found a catalog with no name. {}", catalog.toString()); + throw new IllegalArgumentException( + "Missing Name Field from a catalog. Name is a required parameter."); + } + + if (StringUtils.isEmpty(catalog.getUri())) { + LOG.error("Found a catalog with no uri. {}", catalog.toString()); + throw new IllegalArgumentException( + "Missing URI Field from a catalog. URI is a required parameter."); + } + + String catalogName = catalog.getName(); + if (reviewedCatalogs.contains(catalogName)) { + LOG.error("Found duplicate catalog names"); + throw new IllegalArgumentException("Catalogs with same name are not allowed."); + } else { + reviewedCatalogs.add(catalogName); + } + } + } + + +} \ No newline at end of file diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/catalog/CatalogSettings.java b/plugin/src/main/java/org/opensearch/sql/plugin/catalog/CatalogSettings.java new file mode 100644 index 0000000000..20efce1b7a --- /dev/null +++ b/plugin/src/main/java/org/opensearch/sql/plugin/catalog/CatalogSettings.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.plugin.catalog; + +import java.io.InputStream; +import org.opensearch.common.settings.SecureSetting; +import org.opensearch.common.settings.Setting; + +public class CatalogSettings { + + public static final Setting CATALOG_CONFIG = SecureSetting.secureFile( + "plugins.query.federation.catalog.config", + null); +} diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/OpenSearchPluginConfig.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/OpenSearchPluginConfig.java index c1b860877b..24d7e4e7f5 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/OpenSearchPluginConfig.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/OpenSearchPluginConfig.java @@ -7,7 +7,6 @@ package org.opensearch.sql.plugin.rest; import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.service.ClusterService; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.monitor.ResourceMonitor; @@ -31,9 +30,6 @@ @Configuration public class OpenSearchPluginConfig { - @Autowired - private ClusterService clusterService; - @Autowired private NodeClient nodeClient; @@ -42,12 +38,7 @@ public class OpenSearchPluginConfig { @Bean public OpenSearchClient client() { - return new OpenSearchNodeClient(clusterService, nodeClient); - } - - @Bean - public StorageEngine storageEngine() { - return new OpenSearchStorageEngine(client(), settings); + return new OpenSearchNodeClient(nodeClient); } @Bean diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java index 31317c1962..eaad009216 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java @@ -18,6 +18,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.utils.QueryContext; @@ -26,6 +27,7 @@ import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.opensearch.security.SecurityAccess; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; +import org.opensearch.sql.plugin.catalog.CatalogServiceImpl; import org.opensearch.sql.plugin.rest.OpenSearchPluginConfig; import org.opensearch.sql.ppl.PPLService; import org.opensearch.sql.ppl.config.PPLServiceConfig; @@ -53,6 +55,7 @@ public class TransportPPLQueryAction /** Settings required by been initialization. */ private final Settings pluginSettings; + /** Constructor of TransportPPLQueryAction. */ @Inject public TransportPPLQueryAction( @@ -98,6 +101,7 @@ private PPLService createPPLService(NodeClient client) { context.registerBean(ClusterService.class, () -> clusterService); context.registerBean(NodeClient.class, () -> client); context.registerBean(Settings.class, () -> pluginSettings); + context.registerBean(CatalogService.class, CatalogServiceImpl::getInstance); context.register(OpenSearchPluginConfig.class); context.register(PPLServiceConfig.class); context.refresh(); diff --git a/plugin/src/test/java/org/opensearch/sql/plugin/catalog/CatalogServiceImplTest.java b/plugin/src/test/java/org/opensearch/sql/plugin/catalog/CatalogServiceImplTest.java new file mode 100644 index 0000000000..678962cbb5 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/sql/plugin/catalog/CatalogServiceImplTest.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.plugin.catalog; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashSet; +import java.util.Set; +import lombok.SneakyThrows; +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.settings.MockSecureSettings; +import org.opensearch.common.settings.Settings; + + +public class CatalogServiceImplTest { + + public static final String CATALOG_SETTING_METADATA_KEY = + "plugins.query.federation.catalog.config"; + + + @SneakyThrows + @Test + public void testLoadConnectors() { + Settings settings = getCatalogSettings("catalogs.json"); + CatalogServiceImpl.getInstance().loadConnectors(settings); + Set expected = new HashSet<>() {{ + add("prometheus"); + }}; + Assert.assertEquals(expected, CatalogServiceImpl.getInstance().getCatalogs()); + } + + + @SneakyThrows + @Test + public void testLoadConnectorsWithMultipleCatalogs() { + Settings settings = getCatalogSettings("multiple_catalogs.json"); + CatalogServiceImpl.getInstance().loadConnectors(settings); + Set expected = new HashSet<>() {{ + add("prometheus"); + add("prometheus-1"); + }}; + Assert.assertEquals(expected, CatalogServiceImpl.getInstance().getCatalogs()); + } + + @SneakyThrows + @Test + public void testLoadConnectorsWithMissingName() { + Settings settings = getCatalogSettings("catalog_missing_name.json"); + Assert.assertThrows(IllegalArgumentException.class, + () -> CatalogServiceImpl.getInstance().loadConnectors(settings)); + } + + @SneakyThrows + @Test + public void testLoadConnectorsWithDuplicateCatalogNames() { + Settings settings = getCatalogSettings("duplicate_catalog_names.json"); + Assert.assertThrows(IllegalArgumentException.class, + () -> CatalogServiceImpl.getInstance().loadConnectors(settings)); + } + + @SneakyThrows + @Test + public void testLoadConnectorsWithMalformedJson() { + Settings settings = getCatalogSettings("malformed_catalogs.json"); + Assert.assertThrows(IllegalArgumentException.class, + () -> CatalogServiceImpl.getInstance().loadConnectors(settings)); + } + + + private Settings getCatalogSettings(String filename) throws URISyntaxException, IOException { + MockSecureSettings mockSecureSettings = new MockSecureSettings(); + ClassLoader classLoader = getClass().getClassLoader(); + Path filepath = Paths.get(classLoader.getResource(filename).toURI()); + mockSecureSettings.setFile(CATALOG_SETTING_METADATA_KEY, Files.readAllBytes(filepath)); + return Settings.builder().setSecureSettings(mockSecureSettings).build(); + } + +} diff --git a/plugin/src/test/resources/catalog_missing_name.json b/plugin/src/test/resources/catalog_missing_name.json new file mode 100644 index 0000000000..86dc752cf0 --- /dev/null +++ b/plugin/src/test/resources/catalog_missing_name.json @@ -0,0 +1,11 @@ +[ + { + "connector": "prometheus", + "uri" : "http://localhost:9090", + "authentication" : { + "type" : "basicauth", + "username" : "admin", + "password" : "password" + } + } +] \ No newline at end of file diff --git a/plugin/src/test/resources/catalogs.json b/plugin/src/test/resources/catalogs.json new file mode 100644 index 0000000000..aae3403462 --- /dev/null +++ b/plugin/src/test/resources/catalogs.json @@ -0,0 +1,12 @@ +[ + { + "name" : "prometheus", + "connector": "prometheus", + "uri" : "http://localhost:9090", + "authentication" : { + "type" : "basicauth", + "username" : "admin", + "password" : "password" + } + } +] \ No newline at end of file diff --git a/plugin/src/test/resources/duplicate_catalog_names.json b/plugin/src/test/resources/duplicate_catalog_names.json new file mode 100644 index 0000000000..dab85770e9 --- /dev/null +++ b/plugin/src/test/resources/duplicate_catalog_names.json @@ -0,0 +1,20 @@ +[ + { + "connector": "prometheus", + "uri" : "http://localhost:9090", + "authentication" : { + "type" : "basicauth", + "username" : "admin", + "password" : "password" + } + }, + { + "connector": "prometheus", + "uri" : "http://localhost:9219", + "authentication" : { + "type" : "basicauth", + "username" : "admin", + "password" : "password" + } + } +] \ No newline at end of file diff --git a/plugin/src/test/resources/malformed_catalogs.json b/plugin/src/test/resources/malformed_catalogs.json new file mode 100644 index 0000000000..716bd363ce --- /dev/null +++ b/plugin/src/test/resources/malformed_catalogs.json @@ -0,0 +1 @@ +fasdfasdfasdf diff --git a/plugin/src/test/resources/multiple_catalogs.json b/plugin/src/test/resources/multiple_catalogs.json new file mode 100644 index 0000000000..112ecad858 --- /dev/null +++ b/plugin/src/test/resources/multiple_catalogs.json @@ -0,0 +1,22 @@ +[ + { + "name" : "prometheus", + "connector": "prometheus", + "uri" : "http://localhost:9090", + "authentication" : { + "type" : "basicauth", + "username" : "admin", + "password" : "password" + } + }, + { + "name" : "prometheus-1", + "connector": "prometheus", + "uri" : "http://localhost:9090", + "authentication" : { + "type" : "basicauth", + "username" : "admin", + "password" : "password" + } + } +] \ No newline at end of file diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java b/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java index 866326f562..ce5ba0f56f 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java @@ -14,7 +14,9 @@ import org.apache.logging.log4j.Logger; import org.opensearch.sql.analysis.AnalysisContext; import org.opensearch.sql.analysis.Analyzer; +import org.opensearch.sql.analysis.ExpressionAnalyzer; import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.common.utils.QueryContext; import org.opensearch.sql.executor.ExecutionEngine; @@ -31,20 +33,17 @@ import org.opensearch.sql.ppl.parser.AstExpressionBuilder; import org.opensearch.sql.ppl.utils.PPLQueryDataAnonymizer; import org.opensearch.sql.ppl.utils.UnresolvedPlanHelper; -import org.opensearch.sql.storage.StorageEngine; @RequiredArgsConstructor public class PPLService { private final PPLSyntaxParser parser; - private final Analyzer analyzer; - - private final StorageEngine storageEngine; - - private final ExecutionEngine executionEngine; + private final ExecutionEngine openSearchExecutionEngine; private final BuiltinFunctionRepository repository; + private final CatalogService catalogService; + private final PPLQueryDataAnonymizer anonymizer = new PPLQueryDataAnonymizer(); private static final Logger LOG = LogManager.getLogger(); @@ -57,7 +56,7 @@ public class PPLService { */ public void execute(PPLQueryRequest request, ResponseListener listener) { try { - executionEngine.execute(plan(request), listener); + openSearchExecutionEngine.execute(plan(request), listener); } catch (Exception e) { listener.onFailure(e); } @@ -67,12 +66,12 @@ public void execute(PPLQueryRequest request, ResponseListener lis * Explain the query in {@link PPLQueryRequest} using {@link ResponseListener} to * get and format explain response. * - * @param request {@link PPLQueryRequest} + * @param request {@link PPLQueryRequest} * @param listener {@link ResponseListener} for explain response */ public void explain(PPLQueryRequest request, ResponseListener listener) { try { - executionEngine.explain(plan(request), listener); + openSearchExecutionEngine.explain(plan(request), listener); } catch (Exception e) { listener.onFailure(e); } @@ -83,16 +82,16 @@ private PhysicalPlan plan(PPLQueryRequest request) { ParseTree cst = parser.parse(request.getRequest()); UnresolvedPlan ast = cst.accept( new AstBuilder(new AstExpressionBuilder(), request.getRequest())); - LOG.info("[{}] Incoming request {}", QueryContext.getRequestId(), anonymizer.anonymizeData(ast)); - // 2.Analyze abstract syntax to generate logical plan - LogicalPlan logicalPlan = analyzer.analyze(UnresolvedPlanHelper.addSelectAll(ast), - new AnalysisContext()); + LogicalPlan logicalPlan = + new Analyzer(new ExpressionAnalyzer(repository), catalogService).analyze( + UnresolvedPlanHelper.addSelectAll(ast), + new AnalysisContext()); // 3.Generate optimal physical plan from logical plan - return new Planner(storageEngine, LogicalPlanOptimizer.create(new DSL(repository))) + return new Planner(LogicalPlanOptimizer.create(new DSL(repository))) .plan(logicalPlan); } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/config/PPLServiceConfig.java b/ppl/src/main/java/org/opensearch/sql/ppl/config/PPLServiceConfig.java index 72eb991671..bd6c4e3937 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/config/PPLServiceConfig.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/config/PPLServiceConfig.java @@ -6,14 +6,12 @@ package org.opensearch.sql.ppl.config; -import org.opensearch.sql.analysis.Analyzer; -import org.opensearch.sql.analysis.ExpressionAnalyzer; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.ppl.PPLService; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; -import org.opensearch.sql.storage.StorageEngine; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -24,23 +22,24 @@ public class PPLServiceConfig { @Autowired - private StorageEngine storageEngine; + private ExecutionEngine executionEngine; @Autowired - private ExecutionEngine executionEngine; + private CatalogService catalogService; @Autowired private BuiltinFunctionRepository functionRepository; - @Bean - public Analyzer analyzer() { - return new Analyzer(new ExpressionAnalyzer(functionRepository), storageEngine); - } - + /** + * The registration of OpenSearch storage engine happens here because + * OpenSearchStorageEngine is dependent on NodeClient. + * + * @return PPLService. + */ @Bean public PPLService pplService() { - return new PPLService(new PPLSyntaxParser(), analyzer(), storageEngine, executionEngine, - functionRepository); + return new PPLService(new PPLSyntaxParser(), executionEngine, + functionRepository, catalogService); } } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index d7f97e3d35..6d5de4dcc6 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -174,10 +174,10 @@ public UnresolvedPlan visitStatsCommand(StatsCommandContext ctx) { Optional.ofNullable(ctx.statsByClause()) .map(OpenSearchPPLParser.StatsByClauseContext::fieldList) .map(expr -> expr.fieldExpression().stream() - .map(groupCtx -> - (UnresolvedExpression) new Alias(getTextInQuery(groupCtx), - internalVisitExpression(groupCtx))) - .collect(Collectors.toList())) + .map(groupCtx -> + (UnresolvedExpression) new Alias(getTextInQuery(groupCtx), + internalVisitExpression(groupCtx))) + .collect(Collectors.toList())) .orElse(Collections.emptyList()); UnresolvedExpression span = @@ -334,10 +334,10 @@ protected UnresolvedPlan aggregateResult(UnresolvedPlan aggregate, UnresolvedPla public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { ImmutableMap.Builder builder = ImmutableMap.builder(); ctx.kmeansParameter() - .forEach(x -> { - builder.put(x.children.get(0).toString(), - (Literal) internalVisitExpression(x.children.get(2))); - }); + .forEach(x -> { + builder.put(x.children.get(0).toString(), + (Literal) internalVisitExpression(x.children.get(2))); + }); return new Kmeans(builder.build()); } @@ -348,10 +348,10 @@ public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { public UnresolvedPlan visitAdCommand(AdCommandContext ctx) { ImmutableMap.Builder builder = ImmutableMap.builder(); ctx.adParameter() - .forEach(x -> { - builder.put(x.children.get(0).toString(), - (Literal) internalVisitExpression(x.children.get(2))); - }); + .forEach(x -> { + builder.put(x.children.get(0).toString(), + (Literal) internalVisitExpression(x.children.get(2))); + }); return new AD(builder.build()); } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 6e5893d6a3..99483d2403 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -271,7 +271,11 @@ public UnresolvedExpression visitMultiFieldRelevanceFunction( @Override public UnresolvedExpression visitTableSource(TableSourceContext ctx) { - return visitIdentifiers(Arrays.asList(ctx)); + if (ctx.getChild(0) instanceof IdentsAsQualifiedNameContext) { + return visitIdentifiers(((IdentsAsQualifiedNameContext) ctx.getChild(0)).ident()); + } else { + return visitIdentifiers(Arrays.asList(ctx)); + } } /** @@ -374,4 +378,5 @@ private List multiFieldRelevanceArguments( v.relevanceArgValue().getText()), DataType.STRING)))); return builder.build(); } + } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index 0123d3a40b..ec513c7c4d 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -75,7 +75,7 @@ public String anonymizeData(UnresolvedPlan plan) { @Override public String visitRelation(Relation node, String context) { - return StringUtils.format("source=%s", node.getTableName()); + return StringUtils.format("source=%s", node.getFullyQualifiedTableNameWithCatalog()); } @Override diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/PPLServiceTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/PPLServiceTest.java index 7f28aeee40..8c8760c66d 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/PPLServiceTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/PPLServiceTest.java @@ -11,13 +11,16 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import java.util.Collections; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.executor.ExecutionEngine; @@ -43,6 +46,9 @@ public class PPLServiceTest { @Mock private ExecutionEngine executionEngine; + @Mock + private CatalogService catalogService; + @Mock private Table table; @@ -63,6 +69,7 @@ public void setUp() { context.registerBean(StorageEngine.class, () -> storageEngine); context.registerBean(ExecutionEngine.class, () -> executionEngine); + context.registerBean(CatalogService.class, () -> catalogService); context.register(PPLServiceConfig.class); context.refresh(); pplService = context.getBean(PPLService.class); @@ -70,6 +77,7 @@ public void setUp() { @Test public void testExecuteShouldPass() { + when(catalogService.getStorageEngine(any())).thenReturn(storageEngine); doAnswer(invocation -> { ResponseListener listener = invocation.getArgument(1); listener.onResponse(new QueryResponse(schema, Collections.emptyList())); @@ -92,6 +100,7 @@ public void onFailure(Exception e) { @Test public void testExecuteCsvFormatShouldPass() { + when(catalogService.getStorageEngine(any())).thenReturn(storageEngine); doAnswer(invocation -> { ResponseListener listener = invocation.getArgument(1); listener.onResponse(new QueryResponse(schema, Collections.emptyList())); @@ -113,6 +122,7 @@ public void onFailure(Exception e) { @Test public void testExplainShouldPass() { + when(catalogService.getStorageEngine(any())).thenReturn(storageEngine); doAnswer(invocation -> { ResponseListener listener = invocation.getArgument(1); listener.onResponse(new ExplainResponse(new ExplainResponseNode("test"))); @@ -151,7 +161,7 @@ public void onFailure(Exception e) { @Test public void testExplainWithIllegalQueryShouldBeCaughtByHandler() { pplService.explain(new PPLQueryRequest("search", null, null), - new ResponseListener() { + new ResponseListener<>() { @Override public void onResponse(ExplainResponse pplQueryResponse) { Assert.fail(); @@ -164,6 +174,29 @@ public void onFailure(Exception e) { }); } + @Test + public void testPrometheusQuery() { + when(catalogService.getStorageEngine(any())).thenReturn(storageEngine); + doAnswer(invocation -> { + ResponseListener listener = invocation.getArgument(1); + listener.onResponse(new QueryResponse(schema, Collections.emptyList())); + return null; + }).when(executionEngine).execute(any(), any()); + + pplService.execute(new PPLQueryRequest("source = prometheus.http_requests_total", null, null), + new ResponseListener<>() { + @Override + public void onResponse(QueryResponse pplQueryResponse) { + + } + + @Override + public void onFailure(Exception e) { + Assert.fail(); + } + }); + } + @Test public void test() { pplService.execute(new PPLQueryRequest("search", null, null), diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/config/PPLServiceConfigTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/config/PPLServiceConfigTest.java deleted file mode 100644 index a63b3b6899..0000000000 --- a/ppl/src/test/java/org/opensearch/sql/ppl/config/PPLServiceConfigTest.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.ppl.config; - -import static org.junit.Assert.assertNotNull; - -import org.junit.Test; -import org.opensearch.sql.ppl.PPLService; - -public class PPLServiceConfigTest { - @Test - public void testConfigPPLServiceShouldPass() { - PPLServiceConfig config = new PPLServiceConfig(); - PPLService service = config.pplService(); - assertNotNull(service); - } -} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index ce5f8f9ec5..8fbf502019 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -30,7 +30,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.map; import static org.opensearch.sql.ast.dsl.AstDSL.nullLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.parse; -import static org.opensearch.sql.ast.dsl.AstDSL.project; import static org.opensearch.sql.ast.dsl.AstDSL.projectWithArg; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; import static org.opensearch.sql.ast.dsl.AstDSL.rareTopN; @@ -47,7 +46,6 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.sql.ast.Node; -import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.SpanUnit; @@ -61,7 +59,7 @@ public class AstBuilderTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - private final PPLSyntaxParser parser = new PPLSyntaxParser(); + private PPLSyntaxParser parser = new PPLSyntaxParser(); @Test public void testSearchCommand() { @@ -73,6 +71,27 @@ public void testSearchCommand() { ); } + @Test + public void testPrometheusSearchCommand() { + assertEqual("search source = prometheus.http_requests_total", + relation(qualifiedName("http_requests_total")) + ); + } + + @Test + public void testSearchCommandWithCatalogEscape() { + assertEqual("search source = `prometheus.http_requests_total`", + relation("prometheus.http_requests_total") + ); + } + + @Test + public void testSearchCommandWithDotInIndexName() { + assertEqual("search source = http_requests_total.test", + relation("test") + ); + } + @Test public void testSearchCommandString() { assertEqual("search source=t a=\"a\"", @@ -610,18 +629,18 @@ public void testParseCommand() { @Test public void testKmeansCommand() { assertEqual("source=t | kmeans centroids=3 iterations=2 distance_type='l1'", - new Kmeans(relation("t"), ImmutableMap.builder() - .put("centroids", new Literal(3, DataType.INTEGER)) - .put("iterations", new Literal(2, DataType.INTEGER)) - .put("distance_type", new Literal("l1", DataType.STRING)) - .build() - )); + new Kmeans(relation("t"), ImmutableMap.builder() + .put("centroids", new Literal(3, DataType.INTEGER)) + .put("iterations", new Literal(2, DataType.INTEGER)) + .put("distance_type", new Literal("l1", DataType.STRING)) + .build() + )); } @Test public void testKmeansCommandWithoutParameter() { assertEqual("source=t | kmeans", - new Kmeans(relation("t"), ImmutableMap.of())); + new Kmeans(relation("t"), ImmutableMap.of())); } @Test @@ -639,50 +658,50 @@ public void testDescribeCommandWithMultipleIndices() { @Test public void test_fitRCFADCommand_withoutDataFormat() { assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp' " - + "anomaly_rate=0.1 anomaly_score_threshold=0.1 sample_size=256 " - + "number_of_trees=256 time_zone='PST' output_after=256 " - + "training_data_size=256", - new AD(relation("t"), ImmutableMap.builder() - .put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)) - .put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)) - .put("sample_size", new Literal(256, DataType.INTEGER)) - .put("number_of_trees", new Literal(256, DataType.INTEGER)) - .put("time_zone", new Literal("PST", DataType.STRING)) - .put("output_after", new Literal(256, DataType.INTEGER)) - .put("shingle_size", new Literal(10, DataType.INTEGER)) - .put("time_decay", new Literal(0.0001, DataType.DOUBLE)) - .put("time_field", new Literal("timestamp", DataType.STRING)) - .put("training_data_size", new Literal(256, DataType.INTEGER)) - .build() - )); + + "anomaly_rate=0.1 anomaly_score_threshold=0.1 sample_size=256 " + + "number_of_trees=256 time_zone='PST' output_after=256 " + + "training_data_size=256", + new AD(relation("t"), ImmutableMap.builder() + .put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)) + .put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)) + .put("sample_size", new Literal(256, DataType.INTEGER)) + .put("number_of_trees", new Literal(256, DataType.INTEGER)) + .put("time_zone", new Literal("PST", DataType.STRING)) + .put("output_after", new Literal(256, DataType.INTEGER)) + .put("shingle_size", new Literal(10, DataType.INTEGER)) + .put("time_decay", new Literal(0.0001, DataType.DOUBLE)) + .put("time_field", new Literal("timestamp", DataType.STRING)) + .put("training_data_size", new Literal(256, DataType.INTEGER)) + .build() + )); } @Test public void test_fitRCFADCommand_withDataFormat() { assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp' " - + "anomaly_rate=0.1 anomaly_score_threshold=0.1 sample_size=256 " - + "number_of_trees=256 time_zone='PST' output_after=256 " - + "training_data_size=256 date_format='HH:mm:ss yyyy-MM-dd'", - new AD(relation("t"), ImmutableMap.builder() - .put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)) - .put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)) - .put("sample_size", new Literal(256, DataType.INTEGER)) - .put("number_of_trees", new Literal(256, DataType.INTEGER)) - .put("date_format", new Literal("HH:mm:ss yyyy-MM-dd", DataType.STRING)) - .put("time_zone", new Literal("PST", DataType.STRING)) - .put("output_after", new Literal(256, DataType.INTEGER)) - .put("shingle_size", new Literal(10, DataType.INTEGER)) - .put("time_decay", new Literal(0.0001, DataType.DOUBLE)) - .put("time_field", new Literal("timestamp", DataType.STRING)) - .put("training_data_size", new Literal(256, DataType.INTEGER)) - .build() - )); + + "anomaly_rate=0.1 anomaly_score_threshold=0.1 sample_size=256 " + + "number_of_trees=256 time_zone='PST' output_after=256 " + + "training_data_size=256 date_format='HH:mm:ss yyyy-MM-dd'", + new AD(relation("t"), ImmutableMap.builder() + .put("anomaly_rate", new Literal(0.1, DataType.DOUBLE)) + .put("anomaly_score_threshold", new Literal(0.1, DataType.DOUBLE)) + .put("sample_size", new Literal(256, DataType.INTEGER)) + .put("number_of_trees", new Literal(256, DataType.INTEGER)) + .put("date_format", new Literal("HH:mm:ss yyyy-MM-dd", DataType.STRING)) + .put("time_zone", new Literal("PST", DataType.STRING)) + .put("output_after", new Literal(256, DataType.INTEGER)) + .put("shingle_size", new Literal(10, DataType.INTEGER)) + .put("time_decay", new Literal(0.0001, DataType.DOUBLE)) + .put("time_field", new Literal("timestamp", DataType.STRING)) + .put("training_data_size", new Literal(256, DataType.INTEGER)) + .build() + )); } @Test public void test_batchRCFADCommand() { assertEqual("source=t | AD", - new AD(relation("t"),ImmutableMap.of())); + new AD(relation("t"), ImmutableMap.of())); } protected void assertEqual(String query, Node expectedPlan) { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index f2aff5a7e7..bb3315d5c8 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -25,7 +25,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.exprList; import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.filter; -import static org.opensearch.sql.ast.dsl.AstDSL.floatLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.in; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; @@ -44,10 +43,14 @@ import static org.opensearch.sql.ast.dsl.AstDSL.xor; import com.google.common.collect.ImmutableMap; +import java.util.Arrays; +import java.util.Collections; import org.junit.Ignore; import org.junit.Test; import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.RelevanceFieldList; public class AstExpressionBuilderTest extends AstBuilderTest { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java index 46af993fc1..7caa4bab13 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java @@ -7,17 +7,26 @@ package org.opensearch.sql.ppl.utils; import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.when; import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.projectWithArg; import static org.opensearch.sql.ast.dsl.AstDSL.relation; +import com.google.common.collect.ImmutableSet; import java.util.Collections; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; import org.opensearch.sql.ppl.parser.AstBuilder; import org.opensearch.sql.ppl.parser.AstExpressionBuilder; +@RunWith(MockitoJUnitRunner.class) public class PPLQueryDataAnonymizerTest { private final PPLSyntaxParser parser = new PPLSyntaxParser(); @@ -29,6 +38,13 @@ public void testSearchCommand() { ); } + @Test + public void testPrometheusPPLCommand() { + assertEquals("source=prometheus.http_requests_process", + anonymize("source=prometheus.http_requests_process") + ); + } + @Test public void testWhereCommand() { assertEquals("source=t | where a = ***", diff --git a/release-notes/opensearch-sql.release-notes-2.3.0.0.md b/release-notes/opensearch-sql.release-notes-2.3.0.0.md new file mode 100644 index 0000000000..9ad5daa256 --- /dev/null +++ b/release-notes/opensearch-sql.release-notes-2.3.0.0.md @@ -0,0 +1,24 @@ +### Version 2.3.0.0 Release Notes + +Compatible with OpenSearch and OpenSearch Dashboards Version 2.3.0 + +### Features +* Add maketime and makedate datetime functions ([#755](https://github.com/opensearch-project/sql/pull/755)) + +### Enhancements +* Refactor implementation of relevance queries ([#746](https://github.com/opensearch-project/sql/pull/746)) +* Extend query size limit using scroll ([#716](https://github.com/opensearch-project/sql/pull/716)) +* Add any case of arguments in relevancy based functions to be allowed ([#744](https://github.com/opensearch-project/sql/pull/744)) + +### Bug Fixes +* Fix unit test in PowerBI connector ([#800](https://github.com/opensearch-project/sql/pull/800)) + +### Infrastructure +* Schedule request in worker thread ([#748](https://github.com/opensearch-project/sql/pull/748)) +* Deprecated ClusterService and Using NodeClient to fetch metadata ([#774](https://github.com/opensearch-project/sql/pull/774)) +* Change master node timeout to new API ([#793](https://github.com/opensearch-project/sql/pull/793)) + +### Documentation +* Adding documentation about double quote implementation ([#723](https://github.com/opensearch-project/sql/pull/723)) +* Add PPL security setting documentation ([#777](https://github.com/opensearch-project/sql/pull/777)) +* Update PPL docs link for workbench ([#758](https://github.com/opensearch-project/sql/pull/758)) diff --git a/scripts/bwctest.sh b/scripts/bwctest.sh old mode 100644 new mode 100755 diff --git a/sql-jdbc/build.gradle b/sql-jdbc/build.gradle index dd629e438f..a696a7c973 100644 --- a/sql-jdbc/build.gradle +++ b/sql-jdbc/build.gradle @@ -24,7 +24,7 @@ plugins { group 'org.opensearch.client' // keep version in sync with version in Driver source -version '2.2.0.0' +version '1.1.0.1' boolean snapshot = "true".equals(System.getProperty("build.snapshot", "false")); if (snapshot) { diff --git a/sql-odbc/scripts/build_libcurl-vcpkg.ps1 b/sql-odbc/scripts/build_libcurl-vcpkg.ps1 deleted file mode 100644 index 8fa08b228f..0000000000 --- a/sql-odbc/scripts/build_libcurl-vcpkg.ps1 +++ /dev/null @@ -1,11 +0,0 @@ -$SRC_DIR = $args[0] -$LIBCURL_WIN_ARCH = $args[1] - -if (!("${SRC_DIR}/packages/curl_${LIBCURL_WIN_ARCH}-windows" | Test-Path)) -{ - git clone https://github.com/Microsoft/vcpkg.git $SRC_DIR - Set-Location $SRC_DIR - cmd.exe /c bootstrap-vcpkg.bat - .\vcpkg.exe integrate install - .\vcpkg.exe install curl[tool]:${LIBCURL_WIN_ARCH}-windows -} diff --git a/sql-odbc/scripts/build_windows.ps1 b/sql-odbc/scripts/build_windows.ps1 index 48e32345b6..49b857ed8d 100644 --- a/sql-odbc/scripts/build_windows.ps1 +++ b/sql-odbc/scripts/build_windows.ps1 @@ -21,9 +21,8 @@ $BUILD_DIR = "${WORKING_DIR}\build" # $BUILD_DIR = "${WORKING_DIR}\build\${CONFIGURATION}${BITNESS}" New-Item -Path $BUILD_DIR -ItemType Directory -Force | Out-Null -$VCPKG_DIR = "${WORKING_DIR}/src/vcpkg" - -.\scripts\build_libcurl-vcpkg.ps1 $VCPKG_DIR $LIBCURL_WIN_ARCH +$VCPKG_DIR = $Env:VCPKG_ROOT +vcpkg.exe install curl[tool]:${LIBCURL_WIN_ARCH}-windows Set-Location $CURRENT_DIR diff --git a/sql/src/main/java/org/opensearch/sql/sql/SQLService.java b/sql/src/main/java/org/opensearch/sql/sql/SQLService.java index 991e9df12a..76de0f6249 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/SQLService.java +++ b/sql/src/main/java/org/opensearch/sql/sql/SQLService.java @@ -36,8 +36,6 @@ public class SQLService { private final Analyzer analyzer; - private final StorageEngine storageEngine; - private final ExecutionEngine executionEngine; private final BuiltinFunctionRepository repository; @@ -103,7 +101,7 @@ public LogicalPlan analyze(UnresolvedPlan ast) { * Generate optimal physical plan from logical plan. */ public PhysicalPlan plan(LogicalPlan logicalPlan) { - return new Planner(storageEngine, LogicalPlanOptimizer.create(new DSL(repository))) + return new Planner(LogicalPlanOptimizer.create(new DSL(repository))) .plan(logicalPlan); } diff --git a/sql/src/main/java/org/opensearch/sql/sql/config/SQLServiceConfig.java b/sql/src/main/java/org/opensearch/sql/sql/config/SQLServiceConfig.java index 61807f084b..2d22d92081 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/config/SQLServiceConfig.java +++ b/sql/src/main/java/org/opensearch/sql/sql/config/SQLServiceConfig.java @@ -8,6 +8,7 @@ import org.opensearch.sql.analysis.Analyzer; import org.opensearch.sql.analysis.ExpressionAnalyzer; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; @@ -27,22 +28,28 @@ public class SQLServiceConfig { @Autowired - private StorageEngine storageEngine; + private ExecutionEngine executionEngine; @Autowired - private ExecutionEngine executionEngine; + private CatalogService catalogService; @Autowired private BuiltinFunctionRepository functionRepository; @Bean public Analyzer analyzer() { - return new Analyzer(new ExpressionAnalyzer(functionRepository), storageEngine); + return new Analyzer(new ExpressionAnalyzer(functionRepository), catalogService); } + /** + * The registration of OpenSearch storage engine happens here because + * OpenSearchStorageEngine is dependent on NodeClient. + * + * @return SQLService. + */ @Bean public SQLService sqlService() { - return new SQLService(new SQLSyntaxParser(), analyzer(), storageEngine, executionEngine, + return new SQLService(new SQLSyntaxParser(), analyzer(), executionEngine, functionRepository); } diff --git a/sql/src/test/java/org/opensearch/sql/sql/SQLServiceTest.java b/sql/src/test/java/org/opensearch/sql/sql/SQLServiceTest.java index 1c49d8d2d4..774c5e2d52 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/SQLServiceTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/SQLServiceTest.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.catalog.CatalogService; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse; @@ -44,6 +45,9 @@ class SQLServiceTest { @Mock private ExecutionEngine executionEngine; + @Mock + private CatalogService catalogService; + @Mock private ExecutionEngine.Schema schema; @@ -51,6 +55,7 @@ class SQLServiceTest { public void setUp() { context.registerBean(StorageEngine.class, () -> storageEngine); context.registerBean(ExecutionEngine.class, () -> executionEngine); + context.registerBean(CatalogService.class, () -> catalogService); context.register(SQLServiceConfig.class); context.refresh(); sqlService = context.getBean(SQLService.class); diff --git a/sql/src/test/java/org/opensearch/sql/sql/config/SQLServiceConfigTest.java b/sql/src/test/java/org/opensearch/sql/sql/config/SQLServiceConfigTest.java deleted file mode 100644 index e52dbaa13a..0000000000 --- a/sql/src/test/java/org/opensearch/sql/sql/config/SQLServiceConfigTest.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.sql.config; - -import static org.junit.jupiter.api.Assertions.assertNotNull; - -import org.junit.jupiter.api.Test; - -class SQLServiceConfigTest { - - @Test - public void shouldReturnSQLService() { - SQLServiceConfig config = new SQLServiceConfig(); - assertNotNull(config.sqlService()); - } - -} diff --git a/workbench/opensearch_dashboards.json b/workbench/opensearch_dashboards.json index b992549d7d..79aefec25f 100644 --- a/workbench/opensearch_dashboards.json +++ b/workbench/opensearch_dashboards.json @@ -1,7 +1,7 @@ { "id": "queryWorkbenchDashboards", - "version": "2.2.0.0", - "opensearchDashboardsVersion": "2.2.0", + "version": "2.3.0.0", + "opensearchDashboardsVersion": "2.3.0", "server": true, "ui": true, "requiredPlugins": ["navigation"], diff --git a/workbench/package.json b/workbench/package.json index 74cf2c9f41..2fddbb9937 100644 --- a/workbench/package.json +++ b/workbench/package.json @@ -1,6 +1,6 @@ { "name": "opensearch-query-workbench", - "version": "2.2.0.0", + "version": "2.3.0.0", "description": "Query Workbench", "main": "index.js", "license": "Apache-2.0",