diff --git a/DEVELOPER_GUIDE.rst b/DEVELOPER_GUIDE.rst
index 516cf23556..923cb459f9 100644
--- a/DEVELOPER_GUIDE.rst
+++ b/DEVELOPER_GUIDE.rst
@@ -147,6 +147,7 @@ The plugin codebase is in standard layout of Gradle project::
     ├── plugin
     ├── protocol
     ├── ppl
+    ├── spark
     ├── sql
     ├── sql-cli
     ├── sql-jdbc
@@ -161,6 +162,7 @@ Here are sub-folders (Gradle modules) for plugin source code:
 - ``core``: core query engine.
 - ``opensearch``: OpenSearch storage engine.
 - ``prometheus``: Prometheus storage engine.
+- ``spark``: Spark storage engine.
 - ``protocol``: request/response protocol formatter.
 - ``common``: common util code.
 - ``integ-test``: integration and comparison test.
diff --git a/docs/user/ppl/admin/spark_connector.rst b/docs/user/ppl/admin/spark_connector.rst
new file mode 100644
index 0000000000..8ff8dd944e
--- /dev/null
+++ b/docs/user/ppl/admin/spark_connector.rst
@@ -0,0 +1,92 @@
+.. highlight:: sh
+
+====================
+Spark Connector
+====================
+
+.. rubric:: Table of contents
+
+.. contents::
+   :local:
+   :depth: 1
+
+
+Introduction
+============
+
+This page covers the Spark connector properties for dataSource configuration
+and the nuances associated with the Spark connector.
+
+
+Spark Connector Properties in DataSource Configuration
+========================================================
+The Spark connector supports the following properties.
+
+* ``spark.connector`` [Required].
+  * This parameter specifies the Spark client used for the connection, for example ``emr``.
+* ``spark.sql.application`` [Optional].
+  * This parameter specifies the Spark SQL application jar. Default value is ``s3://spark-datasource/sql-job.jar``.
+* ``emr.cluster`` [Required].
+  * This parameter specifies the EMR cluster id.
+* ``emr.auth.type`` [Required]
+  * This parameter specifies the authentication type.
+  * The Spark EMR connector currently supports the ``awssigv4`` authentication mechanism, which requires the following parameters:
+  * ``emr.auth.region``, ``emr.auth.access_key`` and ``emr.auth.secret_key``
+* ``spark.datasource.flint.*`` [Optional]
+  * These parameters provide the OpenSearch domain information for the Flint integration.
+  * ``spark.datasource.flint.integration`` [Optional]
+    * Default value for the integration jar is ``s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar``.
+  * ``spark.datasource.flint.host`` [Optional]
+    * Default value for host is ``localhost``.
+  * ``spark.datasource.flint.port`` [Optional]
+    * Default value for port is ``9200``.
+  * ``spark.datasource.flint.scheme`` [Optional]
+    * Default value for scheme is ``http``.
+  * ``spark.datasource.flint.auth`` [Optional]
+    * Default value for auth is ``false``.
+  * ``spark.datasource.flint.region`` [Optional]
+    * Default value for region is ``us-west-2``.
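+
+As a quick sketch (complementing the full configuration example in the next section), a Spark
+dataSource with the minimal required properties can be registered through the plugin's
+dataSource management API. The endpoint below assumes the standard dataSource management API
+described in the dataSource documentation, and the ``{{...}}`` values are placeholders to replace::
+
+    curl -XPOST "http://localhost:9200/_plugins/_query/_datasources" \
+    -H 'Content-Type: application/json' \
+    -d '{
+        "name" : "my_spark",
+        "connector" : "spark",
+        "properties" : {
+            "spark.connector" : "emr",
+            "emr.cluster" : "{{clusterId}}",
+            "emr.auth.type" : "awssigv4",
+            "emr.auth.region" : "us-east-1",
+            "emr.auth.access_key" : "{{accessKey}}",
+            "emr.auth.secret_key" : "{{secretKey}}"
+        }
+    }'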
+
+Example Spark dataSource configuration
+========================================
+
+AWSSigV4 Auth::
+
+    [{
+        "name" : "my_spark",
+        "connector": "spark",
+        "properties" : {
+            "spark.connector": "emr",
+            "emr.cluster" : "{{clusterId}}",
+            "emr.auth.type" : "awssigv4",
+            "emr.auth.region" : "us-east-1",
+            "emr.auth.access_key" : "{{accessKey}}",
+            "emr.auth.secret_key" : "{{secretKey}}",
+            "spark.datasource.flint.host" : "{{opensearchHost}}",
+            "spark.datasource.flint.port" : "{{opensearchPort}}",
+            "spark.datasource.flint.scheme" : "{{opensearchScheme}}",
+            "spark.datasource.flint.auth" : "{{opensearchAuth}}",
+            "spark.datasource.flint.region" : "{{opensearchRegion}}"
+        }
+    }]
+
+
+Spark SQL Support
+==================
+
+`sql` Function
+----------------------------
+The Spark connector offers a `sql` function, which can be used to run a Spark SQL query.
+The function takes the Spark SQL query as input. The argument can be passed either by name or by position:
+`source=my_spark.sql('select 1')`
+or
+`source=my_spark.sql(query='select 1')`
+Example::
+
+    > source=my_spark.sql('select 1')
+    +---+
+    | 1 |
+    |---|
+    | 1 |
+    +---+
+
diff --git a/spark/build.gradle b/spark/build.gradle
index 58103dc67f..89842e5ea8 100644
--- a/spark/build.gradle
+++ b/spark/build.gradle
@@ -19,10 +19,12 @@ dependencies {
     implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
     implementation group: 'org.json', name: 'json', version: '20230227'
+    implementation group: 'com.amazonaws', name: 'aws-java-sdk-emr', version: '1.12.1'
     testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
-    testImplementation group: 'org.mockito', name: 'mockito-core', version: '3.12.4'
-    testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '3.12.4'
+    testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0'
+    testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.2.0'
+    testImplementation 'junit:junit:4.13.1'
 }
 
 test {
diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java
new file mode 100644
index 0000000000..1e2475c196
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java
@@ -0,0 +1,121 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.client;
+
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR;
+
+import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce;
+import com.amazonaws.services.elasticmapreduce.model.ActionOnFailure;
+import com.amazonaws.services.elasticmapreduce.model.AddJobFlowStepsRequest;
+import com.amazonaws.services.elasticmapreduce.model.AddJobFlowStepsResult;
+import com.amazonaws.services.elasticmapreduce.model.DescribeStepRequest;
+import com.amazonaws.services.elasticmapreduce.model.HadoopJarStepConfig;
+import com.amazonaws.services.elasticmapreduce.model.StepConfig;
+import com.amazonaws.services.elasticmapreduce.model.StepStatus;
+import com.google.common.annotations.VisibleForTesting;
+import java.io.IOException;
+import lombok.SneakyThrows;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.json.JSONObject;
+import org.opensearch.sql.spark.helper.FlintHelper;
+import org.opensearch.sql.spark.response.SparkResponse;
+
+public class EmrClientImpl implements SparkClient {
+  private final AmazonElasticMapReduce emr;
+  private final String emrCluster;
+  private final FlintHelper flint;
+  private final String sparkApplicationJar;
+  private static final Logger logger = LogManager.getLogger(EmrClientImpl.class);
+  private SparkResponse sparkResponse;
+
+  /**
+   * Constructor for the EMR client implementation.
+   *
+   * @param emr AWS EMR client
+   * @param emrCluster EMR cluster id the Spark SQL job is submitted to
+   * @param flint OpenSearch arguments for the Flint integration jar
+   * @param sparkResponse response object that retrieves results from the OpenSearch index
+   * @param sparkApplicationJar location of the Spark SQL application jar
+   */
+  public EmrClientImpl(AmazonElasticMapReduce emr, String emrCluster, FlintHelper flint,
+                       SparkResponse sparkResponse, String sparkApplicationJar) {
+    this.emr = emr;
+    this.emrCluster = emrCluster;
+    this.flint = flint;
+    this.sparkResponse = sparkResponse;
+    this.sparkApplicationJar =
+        sparkApplicationJar == null ? SPARK_SQL_APPLICATION_JAR : sparkApplicationJar;
+  }
+
+  @Override
+  public JSONObject sql(String query) throws IOException {
+    runEmrApplication(query);
+    return sparkResponse.getResultFromOpensearchIndex();
+  }
+
+  @VisibleForTesting
+  void runEmrApplication(String query) {
+
+    HadoopJarStepConfig stepConfig = new HadoopJarStepConfig()
+        .withJar("command-runner.jar")
+        .withArgs("spark-submit",
+            "--class","org.opensearch.sql.SQLJob",
+            "--jars",
+            flint.getFlintIntegrationJar(),
+            sparkApplicationJar,
+            query,
+            SPARK_INDEX_NAME,
+            flint.getFlintHost(),
+            flint.getFlintPort(),
+            flint.getFlintScheme(),
+            flint.getFlintAuth(),
+            flint.getFlintRegion()
+        );
+
+    StepConfig emrstep = new StepConfig()
+        .withName("Spark Application")
+        .withActionOnFailure(ActionOnFailure.CONTINUE)
+        .withHadoopJarStep(stepConfig);
+
+    AddJobFlowStepsRequest request = new AddJobFlowStepsRequest()
+        .withJobFlowId(emrCluster)
+        .withSteps(emrstep);
+
+    AddJobFlowStepsResult result = emr.addJobFlowSteps(request);
+    logger.info("EMR step ID: " + result.getStepIds());
+
+    String stepId = result.getStepIds().get(0);
+    DescribeStepRequest stepRequest = new DescribeStepRequest()
+        .withClusterId(emrCluster)
+        .withStepId(stepId);
+
+    waitForStepExecution(stepRequest);
+    sparkResponse.setValue(stepId);
+  }
+
+  @SneakyThrows
+  private void waitForStepExecution(DescribeStepRequest stepRequest) {
+    // Wait for the step to complete
+    boolean completed = false;
+    while (!completed) {
+      // Get the step status
+      StepStatus statusDetail = emr.describeStep(stepRequest).getStep().getStatus();
+      // Check if the step has completed
+      if (statusDetail.getState().equals("COMPLETED")) {
+        completed = true;
+        logger.info("EMR step completed successfully.");
+      } else if (statusDetail.getState().equals("FAILED")
+          || statusDetail.getState().equals("CANCELLED")) {
+        logger.error("EMR step failed or cancelled.");
+        throw new RuntimeException("Spark SQL application failed.");
+      } else {
+        // Sleep for some time before checking the status again
+        Thread.sleep(2500);
+      }
+    }
+  }
+
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java
new file mode 100644
index 0000000000..65d5a01ba2
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java
@@ -0,0 +1,20 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.data.constants;
+
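+/**
+ * Constants shared across the Spark connector: the EMR connector type, the step id field
+ * used to correlate results, default S3 locations for the Spark SQL application and Flint
+ * integration jars, the OpenSearch result index name, and the default values applied when
+ * the Flint dataSource properties are not configured.
+ */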
+public class SparkConstants { + public static final String EMR = "emr"; + public static final String STEP_ID_FIELD = "stepId.keyword"; + public static final String SPARK_SQL_APPLICATION_JAR = "s3://spark-datasource/sql-job.jar"; + public static final String SPARK_INDEX_NAME = ".query_execution_result"; + public static final String FLINT_INTEGRATION_JAR = + "s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar"; + public static final String FLINT_DEFAULT_HOST = "localhost"; + public static final String FLINT_DEFAULT_PORT = "9200"; + public static final String FLINT_DEFAULT_SCHEME = "http"; + public static final String FLINT_DEFAULT_AUTH = "-1"; + public static final String FLINT_DEFAULT_REGION = "us-west-2"; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java new file mode 100644 index 0000000000..cb2b31ddc1 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions.response; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.json.JSONArray; +import org.json.JSONObject; +import org.opensearch.sql.data.model.ExprBooleanValue; +import org.opensearch.sql.data.model.ExprByteValue; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprFloatValue; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprShortValue; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTimestampValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.executor.ExecutionEngine; + +/** + * Default implementation of SparkSqlFunctionResponseHandle. + */ +public class DefaultSparkSqlFunctionResponseHandle implements SparkSqlFunctionResponseHandle { + private Iterator responseIterator; + private ExecutionEngine.Schema schema; + private static final Logger logger = + LogManager.getLogger(DefaultSparkSqlFunctionResponseHandle.class); + + /** + * Constructor. + * + * @param responseObject Spark responseObject. 
+ */ + public DefaultSparkSqlFunctionResponseHandle(JSONObject responseObject) { + constructIteratorAndSchema(responseObject); + } + + private void constructIteratorAndSchema(JSONObject responseObject) { + List result = new ArrayList<>(); + List columnList; + JSONObject items = responseObject.getJSONObject("data"); + logger.info("Spark Application ID: " + items.getString("applicationId")); + columnList = getColumnList(items.getJSONArray("schema")); + for (int i = 0; i < items.getJSONArray("result").length(); i++) { + JSONObject row = new JSONObject( + items.getJSONArray("result").get(i).toString().replace("'", "\"")); + LinkedHashMap linkedHashMap = extractRow(row, columnList); + result.add(new ExprTupleValue(linkedHashMap)); + } + this.schema = new ExecutionEngine.Schema(columnList); + this.responseIterator = result.iterator(); + } + + private static LinkedHashMap extractRow( + JSONObject row, List columnList) { + LinkedHashMap linkedHashMap = new LinkedHashMap<>(); + for (ExecutionEngine.Schema.Column column : columnList) { + ExprType type = column.getExprType(); + if (type == ExprCoreType.BOOLEAN) { + linkedHashMap.put(column.getName(), ExprBooleanValue.of(row.getBoolean(column.getName()))); + } else if (type == ExprCoreType.LONG) { + linkedHashMap.put(column.getName(), new ExprLongValue(row.getLong(column.getName()))); + } else if (type == ExprCoreType.INTEGER) { + linkedHashMap.put(column.getName(), new ExprIntegerValue(row.getInt(column.getName()))); + } else if (type == ExprCoreType.SHORT) { + linkedHashMap.put(column.getName(), new ExprShortValue(row.getInt(column.getName()))); + } else if (type == ExprCoreType.BYTE) { + linkedHashMap.put(column.getName(), new ExprByteValue(row.getInt(column.getName()))); + } else if (type == ExprCoreType.DOUBLE) { + linkedHashMap.put(column.getName(), new ExprDoubleValue(row.getDouble(column.getName()))); + } else if (type == ExprCoreType.FLOAT) { + linkedHashMap.put(column.getName(), new ExprFloatValue(row.getFloat(column.getName()))); + } else if (type == ExprCoreType.DATE) { + linkedHashMap.put(column.getName(), new ExprDateValue(row.getString(column.getName()))); + } else if (type == ExprCoreType.TIMESTAMP) { + linkedHashMap.put(column.getName(), + new ExprTimestampValue(row.getString(column.getName()))); + } else if (type == ExprCoreType.STRING) { + linkedHashMap.put(column.getName(), new ExprStringValue(row.getString(column.getName()))); + } else { + throw new RuntimeException("Result contains invalid data type"); + } + } + + return linkedHashMap; + } + + private List getColumnList(JSONArray schema) { + List columnList = new ArrayList<>(); + for (int i = 0; i < schema.length(); i++) { + JSONObject column = new JSONObject(schema.get(i).toString().replace("'", "\"")); + columnList.add(new ExecutionEngine.Schema.Column( + column.get("column_name").toString(), + column.get("column_name").toString(), + getDataType(column.get("data_type").toString()))); + } + return columnList; + } + + private ExprCoreType getDataType(String sparkDataType) { + switch (sparkDataType) { + case "boolean": + return ExprCoreType.BOOLEAN; + case "long": + return ExprCoreType.LONG; + case "integer": + return ExprCoreType.INTEGER; + case "short": + return ExprCoreType.SHORT; + case "byte": + return ExprCoreType.BYTE; + case "double": + return ExprCoreType.DOUBLE; + case "float": + return ExprCoreType.FLOAT; + case "timestamp": + return ExprCoreType.DATE; + case "date": + return ExprCoreType.TIMESTAMP; + case "string": + case "varchar": + case "char": + return 
ExprCoreType.STRING;
+      default:
+        return ExprCoreType.UNKNOWN;
+    }
+  }
+
+  @Override
+  public boolean hasNext() {
+    return responseIterator.hasNext();
+  }
+
+  @Override
+  public ExprValue next() {
+    return responseIterator.next();
+  }
+
+  @Override
+  public ExecutionEngine.Schema schema() {
+    return schema;
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/response/SparkSqlFunctionResponseHandle.java b/spark/src/main/java/org/opensearch/sql/spark/functions/response/SparkSqlFunctionResponseHandle.java
new file mode 100644
index 0000000000..da68b591eb
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/functions/response/SparkSqlFunctionResponseHandle.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.functions.response;
+
+import org.opensearch.sql.data.model.ExprValue;
+import org.opensearch.sql.executor.ExecutionEngine;
+
+/**
+ * Handle Spark response.
+ */
+public interface SparkSqlFunctionResponseHandle {
+
+  /**
+   * Return true if the Spark response has more results.
+   */
+  boolean hasNext();
+
+  /**
+   * Return the Spark response as {@link ExprValue}. Attention: the method must only be
+   * called when hasNext returns true.
+   */
+  ExprValue next();
+
+  /**
+   * Return the ExecutionEngine.Schema of the Spark response.
+   */
+  ExecutionEngine.Schema schema();
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java
index 561f6f2933..28ce7dd19a 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java
@@ -24,8 +24,7 @@ public class SparkSqlFunctionTableScanBuilder extends TableScanBuilder {
 
   @Override
   public TableScanOperator build() {
-    //TODO: return SqlFunctionTableScanOperator
-    return null;
+    return new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
   }
 
   @Override
diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java
new file mode 100644
index 0000000000..85e854e422
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.functions.scan;
+
+import java.io.IOException;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+import java.util.Locale;
+import lombok.RequiredArgsConstructor;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.json.JSONObject;
+import org.opensearch.sql.data.model.ExprValue;
+import org.opensearch.sql.executor.ExecutionEngine;
+import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.spark.functions.response.DefaultSparkSqlFunctionResponseHandle;
+import org.opensearch.sql.spark.functions.response.SparkSqlFunctionResponseHandle;
+import org.opensearch.sql.spark.request.SparkQueryRequest;
+import org.opensearch.sql.storage.TableScanOperator;
+
+/**
+ * This is a table scan operator that handles the sql table function.
+ */ +@RequiredArgsConstructor +public class SparkSqlFunctionTableScanOperator extends TableScanOperator { + private final SparkClient sparkClient; + private final SparkQueryRequest request; + private SparkSqlFunctionResponseHandle sparkResponseHandle; + private static final Logger LOG = LogManager.getLogger(); + + @Override + public void open() { + super.open(); + this.sparkResponseHandle = AccessController.doPrivileged( + (PrivilegedAction) () -> { + try { + JSONObject responseObject = sparkClient.sql(request.getSql()); + return new DefaultSparkSqlFunctionResponseHandle(responseObject); + } catch (IOException e) { + LOG.error(e.getMessage()); + throw new RuntimeException( + String.format("Error fetching data from spark server: %s", e.getMessage())); + } + }); + } + + @Override + public boolean hasNext() { + return this.sparkResponseHandle.hasNext(); + } + + @Override + public ExprValue next() { + return this.sparkResponseHandle.next(); + } + + @Override + public String explain() { + return String.format(Locale.ROOT, "sql(%s)", request.getSql()); + } + + @Override + public ExecutionEngine.Schema schema() { + return this.sparkResponseHandle.schema(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/helper/FlintHelper.java b/spark/src/main/java/org/opensearch/sql/spark/helper/FlintHelper.java new file mode 100644 index 0000000000..b3c3c0871a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/helper/FlintHelper.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.helper; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_AUTH; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_HOST; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_PORT; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_REGION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_SCHEME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INTEGRATION_JAR; + +import lombok.Getter; + +public class FlintHelper { + @Getter + private final String flintIntegrationJar; + @Getter + private final String flintHost; + @Getter + private final String flintPort; + @Getter + private final String flintScheme; + @Getter + private final String flintAuth; + @Getter + private final String flintRegion; + + /** Arguments required to write data to opensearch index using flint integration. + * + * @param flintHost Opensearch host for flint + * @param flintPort Opensearch port for flint integration + * @param flintScheme Opensearch scheme for flint integration + * @param flintAuth Opensearch auth for flint integration + * @param flintRegion Opensearch region for flint integration + */ + public FlintHelper( + String flintIntegrationJar, + String flintHost, + String flintPort, + String flintScheme, + String flintAuth, + String flintRegion) { + this.flintIntegrationJar = + flintIntegrationJar == null ? FLINT_INTEGRATION_JAR : flintIntegrationJar; + this.flintHost = flintHost != null ? flintHost : FLINT_DEFAULT_HOST; + this.flintPort = flintPort != null ? flintPort : FLINT_DEFAULT_PORT; + this.flintScheme = flintScheme != null ? flintScheme : FLINT_DEFAULT_SCHEME; + this.flintAuth = flintAuth != null ? flintAuth : FLINT_DEFAULT_AUTH; + this.flintRegion = flintRegion != null ? 
flintRegion : FLINT_DEFAULT_REGION; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java new file mode 100644 index 0000000000..3e348381f2 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.response; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME; + +import com.google.common.annotations.VisibleForTesting; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.Setter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.json.JSONObject; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.ActionFuture; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException; + +@Data +public class SparkResponse { + private final Client client; + private String value; + private final String field; + private static final Logger LOG = LogManager.getLogger(); + + /** + * Response for spark sql query. 
+ * + * @param client Opensearch client + * @param value Identifier field value + * @param field Identifier field name + */ + public SparkResponse(Client client, String value, String field) { + this.client = client; + this.value = value; + this.field = field; + } + + public JSONObject getResultFromOpensearchIndex() { + return searchInSparkIndex(QueryBuilders.termQuery(field, value)); + } + + private JSONObject searchInSparkIndex(QueryBuilder query) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(SPARK_INDEX_NAME); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(query); + searchRequest.source(searchSourceBuilder); + ActionFuture searchResponseActionFuture; + try { + searchResponseActionFuture = client.search(searchRequest); + } catch (Exception e) { + throw new RuntimeException(e); + } + SearchResponse searchResponse = searchResponseActionFuture.actionGet(); + if (searchResponse.status().getStatus() != 200) { + throw new RuntimeException( + "Fetching result from " + SPARK_INDEX_NAME + " index failed with status : " + + searchResponse.status()); + } else { + JSONObject data = new JSONObject(); + for (SearchHit searchHit : searchResponse.getHits().getHits()) { + data.put("data", searchHit.getSourceAsMap()); + deleteInSparkIndex(searchHit.getId()); + } + return data; + } + } + + @VisibleForTesting + void deleteInSparkIndex(String id) { + DeleteRequest deleteRequest = new DeleteRequest(SPARK_INDEX_NAME); + deleteRequest.id(id); + ActionFuture deleteResponseActionFuture; + try { + deleteResponseActionFuture = client.delete(deleteRequest); + } catch (Exception e) { + throw new RuntimeException(e); + } + DeleteResponse deleteResponse = deleteResponseActionFuture.actionGet(); + if (deleteResponse.getResult().equals(DocWriteResponse.Result.DELETED)) { + LOG.debug("Spark result successfully deleted ", id); + } else if (deleteResponse.getResult().equals(DocWriteResponse.Result.NOT_FOUND)) { + throw new ResourceNotFoundException("Spark result with id " + + id + " doesn't exist"); + } else { + throw new RuntimeException("Deleting spark result information failed with : " + + deleteResponse.getResult().getLowercase()); + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java new file mode 100644 index 0000000000..3897e8690e --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.storage.TableScanOperator; + +/** + * Spark scan operator. + */ +@EqualsAndHashCode(onlyExplicitlyIncluded = true, callSuper = false) +@ToString(onlyExplicitlyIncluded = true) +public class SparkScan extends TableScanOperator { + + private final SparkClient sparkClient; + + @EqualsAndHashCode.Include + @Getter + @Setter + @ToString.Include + private SparkQueryRequest request; + + + /** + * Constructor. + * + * @param sparkClient sparkClient. 
+ */ + public SparkScan(SparkClient sparkClient) { + this.sparkClient = sparkClient; + this.request = new SparkQueryRequest(); + } + + @Override + public boolean hasNext() { + return false; + } + + @Override + public ExprValue next() { + return null; + } + + @Override + public String explain() { + return getRequest().toString(); + } + +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java index e4da29f6b4..937679b50e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java @@ -5,6 +5,16 @@ package org.opensearch.sql.spark.storage; +import static org.opensearch.sql.spark.data.constants.SparkConstants.EMR; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STEP_ID_FIELD; + +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; +import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder; +import java.security.AccessController; +import java.security.InvalidParameterException; +import java.security.PrivilegedAction; import java.util.Map; import lombok.RequiredArgsConstructor; import org.opensearch.client.Client; @@ -12,7 +22,11 @@ import org.opensearch.sql.datasource.model.DataSource; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.datasources.auth.AuthenticationType; +import org.opensearch.sql.spark.client.EmrClientImpl; import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.helper.FlintHelper; +import org.opensearch.sql.spark.response.SparkResponse; import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.sql.storage.StorageEngine; @@ -24,6 +38,26 @@ public class SparkStorageFactory implements DataSourceFactory { private final Client client; private final Settings settings; + // Spark datasource configuration properties + public static final String CONNECTOR_TYPE = "spark.connector"; + public static final String SPARK_SQL_APPLICATION = "spark.sql.application"; + + // EMR configuration properties + public static final String EMR_CLUSTER = "emr.cluster"; + public static final String EMR_AUTH_TYPE = "emr.auth.type"; + public static final String EMR_REGION = "emr.auth.region"; + public static final String EMR_ROLE_ARN = "emr.auth.role_arn"; + public static final String EMR_ACCESS_KEY = "emr.auth.access_key"; + public static final String EMR_SECRET_KEY = "emr.auth.secret_key"; + + // Flint integration jar configuration properties + public static final String FLINT_INTEGRATION = "spark.datasource.flint.integration"; + public static final String FLINT_HOST = "spark.datasource.flint.host"; + public static final String FLINT_PORT = "spark.datasource.flint.port"; + public static final String FLINT_SCHEME = "spark.datasource.flint.scheme"; + public static final String FLINT_AUTH = "spark.datasource.flint.auth"; + public static final String FLINT_REGION = "spark.datasource.flint.region"; + @Override public DataSourceType getDataSourceType() { return DataSourceType.SPARK; @@ -41,11 +75,58 @@ public DataSource createDataSource(DataSourceMetadata metadata) { * This function gets spark storage engine. 
* * @param requiredConfig spark config options - * @return spark storage engine object + * @return spark storage engine object */ StorageEngine getStorageEngine(Map requiredConfig) { - SparkClient sparkClient = null; - //TODO: Initialize spark client + SparkClient sparkClient; + if (requiredConfig.get(CONNECTOR_TYPE).equals(EMR)) { + sparkClient = + AccessController.doPrivileged((PrivilegedAction) () -> { + validateEMRConfigProperties(requiredConfig); + return new EmrClientImpl( + getEMRClient( + requiredConfig.get(EMR_ACCESS_KEY), + requiredConfig.get(EMR_SECRET_KEY), + requiredConfig.get(EMR_REGION)), + requiredConfig.get(EMR_CLUSTER), + new FlintHelper( + requiredConfig.get(FLINT_INTEGRATION), + requiredConfig.get(FLINT_HOST), + requiredConfig.get(FLINT_PORT), + requiredConfig.get(FLINT_SCHEME), + requiredConfig.get(FLINT_AUTH), + requiredConfig.get(FLINT_REGION)), + new SparkResponse(client, null, STEP_ID_FIELD), + requiredConfig.get(SPARK_SQL_APPLICATION)); + }); + } else { + throw new InvalidParameterException("Spark connector type is invalid."); + } return new SparkStorageEngine(sparkClient); } + + private void validateEMRConfigProperties(Map dataSourceMetadataConfig) + throws IllegalArgumentException { + if (dataSourceMetadataConfig.get(EMR_CLUSTER) == null + || dataSourceMetadataConfig.get(EMR_AUTH_TYPE) == null) { + throw new IllegalArgumentException("EMR config properties are missing."); + } else if (dataSourceMetadataConfig.get(EMR_AUTH_TYPE) + .equals(AuthenticationType.AWSSIGV4AUTH.getName()) + && (dataSourceMetadataConfig.get(EMR_ACCESS_KEY) == null + || dataSourceMetadataConfig.get(EMR_SECRET_KEY) == null)) { + throw new IllegalArgumentException("EMR auth keys are missing."); + } else if (!dataSourceMetadataConfig.get(EMR_AUTH_TYPE) + .equals(AuthenticationType.AWSSIGV4AUTH.getName())) { + throw new IllegalArgumentException("Invalid auth type."); + } + } + + private AmazonElasticMapReduce getEMRClient( + String emrAccessKey, String emrSecretKey, String emrRegion) { + return AmazonElasticMapReduceClientBuilder.standard() + .withCredentials(new AWSStaticCredentialsProvider( + new BasicAWSCredentials(emrAccessKey, emrSecretKey))) + .withRegion(emrRegion) + .build(); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java index 344db8ab7a..5151405db9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java @@ -7,9 +7,9 @@ import java.util.HashMap; import java.util.Map; -import javax.annotation.Nonnull; import lombok.Getter; import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.planner.DefaultImplementor; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.spark.client.SparkClient; @@ -32,8 +32,7 @@ public class SparkTable implements Table { /** * Constructor for entire Sql Request. 
*/ - public SparkTable(SparkClient sparkService, - @Nonnull SparkQueryRequest sparkQueryRequest) { + public SparkTable(SparkClient sparkService, SparkQueryRequest sparkQueryRequest) { this.sparkClient = sparkService; this.sparkQueryRequest = sparkQueryRequest; } @@ -57,8 +56,10 @@ public Map getFieldTypes() { @Override public PhysicalPlan implement(LogicalPlan plan) { - //TODO: Add plan - return null; + SparkScan metricScan = + new SparkScan(sparkClient); + metricScan.setRequest(sparkQueryRequest); + return plan.accept(new DefaultImplementor(), metricScan); } @Override diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java new file mode 100644 index 0000000000..a94ac01f2f --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java @@ -0,0 +1,160 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_CLUSTER_ID; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.utils.TestUtils.getJson; + +import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; +import com.amazonaws.services.elasticmapreduce.model.AddJobFlowStepsResult; +import com.amazonaws.services.elasticmapreduce.model.DescribeStepResult; +import com.amazonaws.services.elasticmapreduce.model.Step; +import com.amazonaws.services.elasticmapreduce.model.StepStatus; +import lombok.SneakyThrows; +import org.json.JSONObject; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.helper.FlintHelper; +import org.opensearch.sql.spark.response.SparkResponse; + +@ExtendWith(MockitoExtension.class) +public class EmrClientImplTest { + + @Mock + private AmazonElasticMapReduce emr; + @Mock + private FlintHelper flint; + @Mock + private SparkResponse sparkResponse; + + @Test + @SneakyThrows + void testRunEmrApplication() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus stepStatus = new StepStatus(); + stepStatus.setState("COMPLETED"); + Step step = new Step(); + step.setStatus(stepStatus); + DescribeStepResult describeStepResult = new DescribeStepResult(); + describeStepResult.setStep(step); + when(emr.describeStep(any())).thenReturn(describeStepResult); + + EmrClientImpl emrClientImpl = new EmrClientImpl( + emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + emrClientImpl.runEmrApplication(QUERY); + } + + @Test + @SneakyThrows + void testRunEmrApplicationFailed() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus stepStatus = new StepStatus(); + stepStatus.setState("FAILED"); + Step step = new Step(); + step.setStatus(stepStatus); + DescribeStepResult describeStepResult = new DescribeStepResult(); + describeStepResult.setStep(step); + when(emr.describeStep(any())).thenReturn(describeStepResult); + + EmrClientImpl emrClientImpl = new EmrClientImpl( + emr, EMR_CLUSTER_ID, flint, 
sparkResponse, null); + RuntimeException exception = Assertions.assertThrows(RuntimeException.class, + () -> emrClientImpl.runEmrApplication(QUERY)); + Assertions.assertEquals("Spark SQL application failed.", + exception.getMessage()); + } + + @Test + @SneakyThrows + void testRunEmrApplicationCancelled() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus stepStatus = new StepStatus(); + stepStatus.setState("CANCELLED"); + Step step = new Step(); + step.setStatus(stepStatus); + DescribeStepResult describeStepResult = new DescribeStepResult(); + describeStepResult.setStep(step); + when(emr.describeStep(any())).thenReturn(describeStepResult); + + EmrClientImpl emrClientImpl = new EmrClientImpl( + emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + RuntimeException exception = Assertions.assertThrows(RuntimeException.class, + () -> emrClientImpl.runEmrApplication(QUERY)); + Assertions.assertEquals("Spark SQL application failed.", + exception.getMessage()); + } + + @Test + @SneakyThrows + void testRunEmrApplicationRunnning() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus runningStatus = new StepStatus(); + runningStatus.setState("RUNNING"); + Step runningStep = new Step(); + runningStep.setStatus(runningStatus); + DescribeStepResult runningDescribeStepResult = new DescribeStepResult(); + runningDescribeStepResult.setStep(runningStep); + + StepStatus completedStatus = new StepStatus(); + completedStatus.setState("COMPLETED"); + Step completedStep = new Step(); + completedStep.setStatus(completedStatus); + DescribeStepResult completedDescribeStepResult = new DescribeStepResult(); + completedDescribeStepResult.setStep(completedStep); + + when(emr.describeStep(any())).thenReturn(runningDescribeStepResult) + .thenReturn(completedDescribeStepResult); + + EmrClientImpl emrClientImpl = new EmrClientImpl( + emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + emrClientImpl.runEmrApplication(QUERY); + } + + @Test + @SneakyThrows + void testSql() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus runningStatus = new StepStatus(); + runningStatus.setState("RUNNING"); + Step runningStep = new Step(); + runningStep.setStatus(runningStatus); + DescribeStepResult runningDescribeStepResult = new DescribeStepResult(); + runningDescribeStepResult.setStep(runningStep); + + StepStatus completedStatus = new StepStatus(); + completedStatus.setState("COMPLETED"); + Step completedStep = new Step(); + completedStep.setStatus(completedStatus); + DescribeStepResult completedDescribeStepResult = new DescribeStepResult(); + completedDescribeStepResult.setStep(completedStep); + + when(emr.describeStep(any())).thenReturn(runningDescribeStepResult) + .thenReturn(completedDescribeStepResult); + when(sparkResponse.getResultFromOpensearchIndex()) + .thenReturn(new JSONObject(getJson("select_query_response.json"))); + + EmrClientImpl emrClientImpl = new EmrClientImpl( + emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + emrClientImpl.sql(QUERY); + + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java new file mode 100644 index 
0000000000..2b1020568a --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.constants; + +public class TestConstants { + public static final String QUERY = "select 1"; + public static final String EMR_CLUSTER_ID = "j-123456789"; +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java index 33c65ff278..18db5b9471 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java @@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; import java.util.List; import org.junit.jupiter.api.Test; @@ -33,7 +34,7 @@ public class SparkSqlFunctionImplementationTest { void testValueOfAndTypeToString() { FunctionName functionName = new FunctionName("sql"); List namedArgumentExpressionList - = List.of(DSL.namedArgument("query", DSL.literal("select 1"))); + = List.of(DSL.namedArgument("query", DSL.literal(QUERY))); SparkSqlFunctionImplementation sparkSqlFunctionImplementation = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); UnsupportedOperationException exception = assertThrows(UnsupportedOperationException.class, @@ -45,12 +46,11 @@ void testValueOfAndTypeToString() { assertEquals(ExprCoreType.STRUCT, sparkSqlFunctionImplementation.type()); } - @Test void testApplyArguments() { FunctionName functionName = new FunctionName("sql"); List namedArgumentExpressionList - = List.of(DSL.namedArgument("query", DSL.literal("select 1"))); + = List.of(DSL.namedArgument("query", DSL.literal(QUERY))); SparkSqlFunctionImplementation sparkSqlFunctionImplementation = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); SparkTable sparkTable @@ -58,14 +58,14 @@ void testApplyArguments() { assertNotNull(sparkTable.getSparkQueryRequest()); SparkQueryRequest sparkQueryRequest = sparkTable.getSparkQueryRequest(); - assertEquals("select 1", sparkQueryRequest.getSql()); + assertEquals(QUERY, sparkQueryRequest.getSql()); } @Test void testApplyArgumentsException() { FunctionName functionName = new FunctionName("sql"); List namedArgumentExpressionList - = List.of(DSL.namedArgument("query", DSL.literal("select 1")), + = List.of(DSL.namedArgument("query", DSL.literal(QUERY)), DSL.namedArgument("tmp", DSL.literal(12345))); SparkSqlFunctionImplementation sparkSqlFunctionImplementation = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java index f5fb0983cc..94c87602b7 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java @@ -5,12 +5,15 @@ package org.opensearch.sql.spark.functions; +import 
static org.opensearch.sql.spark.constants.TestConstants.QUERY; + import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.spark.client.SparkClient; import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanBuilder; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanOperator; import org.opensearch.sql.spark.request.SparkQueryRequest; import org.opensearch.sql.storage.TableScanOperator; @@ -24,19 +27,20 @@ public class SparkSqlFunctionTableScanBuilderTest { @Test void testBuild() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql("select 1"); + sparkQueryRequest.setSql(QUERY); SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder = new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest); TableScanOperator sqlFunctionTableScanOperator = sparkSqlFunctionTableScanBuilder.build(); - Assertions.assertNull(sqlFunctionTableScanOperator); + Assertions.assertTrue(sqlFunctionTableScanOperator + instanceof SparkSqlFunctionTableScanOperator); } @Test void testPushProject() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql("select 1"); + sparkQueryRequest.setSql(QUERY); SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder = new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest); diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java new file mode 100644 index 0000000000..f6807f9913 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.utils.TestUtils.getJson; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import lombok.SneakyThrows; +import org.json.JSONObject; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprBooleanValue; +import org.opensearch.sql.data.model.ExprByteValue; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprFloatValue; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprShortValue; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTimestampValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.spark.client.SparkClient; +import 
org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanOperator; +import org.opensearch.sql.spark.request.SparkQueryRequest; + +@ExtendWith(MockitoExtension.class) +public class SparkSqlFunctionTableScanOperatorTest { + + @Mock + private SparkClient sparkClient; + + @Test + @SneakyThrows + void testEmptyQueryWithException() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())) + .thenThrow(new IOException("Error Message")); + RuntimeException runtimeException + = assertThrows(RuntimeException.class, sparkSqlFunctionTableScanOperator::open); + assertEquals("Error fetching data from spark server: Error Message", + runtimeException.getMessage()); + } + + @Test + @SneakyThrows + void testClose() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + sparkSqlFunctionTableScanOperator.close(); + } + + @Test + @SneakyThrows + void testExplain() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + Assertions.assertEquals("sql(select 1)", + sparkSqlFunctionTableScanOperator.explain()); + } + + @Test + @SneakyThrows + void testQueryResponseIterator() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())) + .thenReturn(new JSONObject(getJson("select_query_response.json"))); + sparkSqlFunctionTableScanOperator.open(); + assertTrue(sparkSqlFunctionTableScanOperator.hasNext()); + ExprTupleValue firstRow = new ExprTupleValue(new LinkedHashMap<>() { + { + put("1", new ExprIntegerValue(1)); + } + }); + assertEquals(firstRow, sparkSqlFunctionTableScanOperator.next()); + Assertions.assertFalse(sparkSqlFunctionTableScanOperator.hasNext()); + } + + @Test + @SneakyThrows + void testQueryResponseAllTypes() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())) + .thenReturn(new JSONObject(getJson("all_data_type.json"))); + sparkSqlFunctionTableScanOperator.open(); + assertTrue(sparkSqlFunctionTableScanOperator.hasNext()); + ExprTupleValue firstRow = new ExprTupleValue(new LinkedHashMap<>() { + { + put("boolean", ExprBooleanValue.of(true)); + put("long", new ExprLongValue(922337203)); + put("integer", new ExprIntegerValue(2147483647)); + put("short", new ExprShortValue(32767)); + put("byte", new ExprByteValue(127)); + put("double", new ExprDoubleValue(9223372036854.775807)); + put("float", new ExprFloatValue(21474.83647)); + put("timestamp", new ExprDateValue("2023-07-01 10:31:30")); + put("date", new ExprTimestampValue("2023-07-01 10:31:30")); + put("string", new ExprStringValue("ABC")); + put("char", new ExprStringValue("A")); + } + }); + 
assertEquals(firstRow, sparkSqlFunctionTableScanOperator.next()); + Assertions.assertFalse(sparkSqlFunctionTableScanOperator.hasNext()); + } + + @Test + @SneakyThrows + void testQueryResponseInvalidDataType() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())) + .thenReturn(new JSONObject(getJson("invalid_data_type.json"))); + + RuntimeException exception = Assertions.assertThrows(RuntimeException.class, + () -> sparkSqlFunctionTableScanOperator.open()); + Assertions.assertEquals("Result contains invalid data type", + exception.getMessage()); + } + + @Test + @SneakyThrows + void testQuerySchema() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())) + .thenReturn( + new JSONObject(getJson("select_query_response.json"))); + sparkSqlFunctionTableScanOperator.open(); + ArrayList columns = new ArrayList<>(); + columns.add(new ExecutionEngine.Schema.Column("1", "1", ExprCoreType.INTEGER)); + ExecutionEngine.Schema expectedSchema = new ExecutionEngine.Schema(columns); + assertEquals(expectedSchema, sparkSqlFunctionTableScanOperator.schema()); + } + +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java index 491e9bbd73..e18fac36de 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; import java.util.List; import java.util.stream.Collectors; @@ -46,7 +47,7 @@ void testResolve() { = new SparkSqlTableFunctionResolver(client); FunctionName functionName = FunctionName.of("sql"); List expressions - = List.of(DSL.namedArgument("query", DSL.literal("select 1"))); + = List.of(DSL.namedArgument("query", DSL.literal(QUERY))); FunctionSignature functionSignature = new FunctionSignature(functionName, expressions .stream().map(Expression::type).collect(Collectors.toList())); Pair resolution @@ -63,7 +64,7 @@ void testResolve() { assertNotNull(sparkTable.getSparkQueryRequest()); SparkQueryRequest sparkQueryRequest = sparkTable.getSparkQueryRequest(); - assertEquals("select 1", sparkQueryRequest.getSql()); + assertEquals(QUERY, sparkQueryRequest.getSql()); } @Test @@ -72,7 +73,7 @@ void testArgumentsPassedByPosition() { = new SparkSqlTableFunctionResolver(client); FunctionName functionName = FunctionName.of("sql"); List expressions - = List.of(DSL.namedArgument(null, DSL.literal("select 1"))); + = List.of(DSL.namedArgument(null, DSL.literal(QUERY))); FunctionSignature functionSignature = new FunctionSignature(functionName, expressions .stream().map(Expression::type).collect(Collectors.toList())); @@ -91,7 +92,7 @@ void testArgumentsPassedByPosition() { 
assertNotNull(sparkTable.getSparkQueryRequest()); SparkQueryRequest sparkQueryRequest = sparkTable.getSparkQueryRequest(); - assertEquals("select 1", sparkQueryRequest.getSql()); + assertEquals(QUERY, sparkQueryRequest.getSql()); } @Test @@ -100,7 +101,7 @@ void testMixedArgumentTypes() { = new SparkSqlTableFunctionResolver(client); FunctionName functionName = FunctionName.of("sql"); List expressions - = List.of(DSL.namedArgument("query", DSL.literal("select 1")), + = List.of(DSL.namedArgument("query", DSL.literal(QUERY)), DSL.namedArgument(null, DSL.literal(12345))); FunctionSignature functionSignature = new FunctionSignature(functionName, expressions .stream().map(Expression::type).collect(Collectors.toList())); diff --git a/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java b/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java new file mode 100644 index 0000000000..20210ea7e5 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.response; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_CLUSTER_ID; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME; + +import java.util.Map; +import org.apache.lucene.search.TotalHits; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.ActionFuture; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.rest.RestStatus; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; + +@ExtendWith(MockitoExtension.class) +public class SparkResponseTest { + @Mock + private Client client; + @Mock + private SearchResponse searchResponse; + @Mock + private DeleteResponse deleteResponse; + @Mock + private SearchHit searchHit; + @Mock + private ActionFuture searchResponseActionFuture; + @Mock + private ActionFuture deleteResponseActionFuture; + + @Test + public void testGetResultFromOpensearchIndex() { + when(client.search(any())).thenReturn(searchResponseActionFuture); + when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + when(searchResponse.status()).thenReturn(RestStatus.OK); + when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit}, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 1.0F)); + Mockito.when(searchHit.getSourceAsMap()) + .thenReturn(Map.of("stepId", EMR_CLUSTER_ID)); + + + when(client.delete(any())).thenReturn(deleteResponseActionFuture); + when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); + when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.DELETED); + + SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); + assertFalse(sparkResponse.getResultFromOpensearchIndex().isEmpty()); + } + 
+  @Test
+  public void testInvalidSearchResponse() {
+    when(client.search(any())).thenReturn(searchResponseActionFuture);
+    when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
+    when(searchResponse.status()).thenReturn(RestStatus.NO_CONTENT);
+
+    SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId");
+    RuntimeException exception = assertThrows(RuntimeException.class,
+        () -> sparkResponse.getResultFromOpensearchIndex());
+    Assertions.assertEquals(
+        "Fetching result from " + SPARK_INDEX_NAME
+            + " index failed with status : " + RestStatus.NO_CONTENT,
+        exception.getMessage());
+  }
+
+  @Test
+  public void testSearchFailure() {
+    when(client.search(any())).thenThrow(RuntimeException.class);
+    SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId");
+    assertThrows(RuntimeException.class, () -> sparkResponse.getResultFromOpensearchIndex());
+  }
+
+  @Test
+  public void testDeleteFailure() {
+    when(client.delete(any())).thenThrow(RuntimeException.class);
+    SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId");
+    assertThrows(RuntimeException.class, () -> sparkResponse.deleteInSparkIndex("id"));
+  }
+
+  @Test
+  public void testNotFoundDeleteResponse() {
+    when(client.delete(any())).thenReturn(deleteResponseActionFuture);
+    when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse);
+    when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND);
+
+    SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId");
+    RuntimeException exception = assertThrows(ResourceNotFoundException.class,
+        () -> sparkResponse.deleteInSparkIndex("123"));
+    Assertions.assertEquals("Spark result with id 123 doesn't exist", exception.getMessage());
+  }
+
+  @Test
+  public void testInvalidDeleteResponse() {
+    when(client.delete(any())).thenReturn(deleteResponseActionFuture);
+    when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse);
+    when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOOP);
+
+    SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId");
+    RuntimeException exception = assertThrows(RuntimeException.class,
+        () -> sparkResponse.deleteInSparkIndex("123"));
+    Assertions.assertEquals(
+        "Deleting spark result information failed with : noop", exception.getMessage());
+  }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java
new file mode 100644
index 0000000000..c57142f580
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.storage;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.opensearch.sql.spark.constants.TestConstants.QUERY;
+
+import lombok.SneakyThrows;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.spark.client.SparkClient;
+
+@ExtendWith(MockitoExtension.class)
+public class SparkScanTest {
+  @Mock
+  private SparkClient sparkClient;
+
+  @Test
+  @SneakyThrows
+  void testQueryResponseIteratorForQueryRangeFunction() {
+    SparkScan sparkScan = new SparkScan(sparkClient);
+    sparkScan.getRequest().setSql(QUERY);
+    Assertions.assertFalse(sparkScan.hasNext());
+    assertNull(sparkScan.next());
+  }
+
+  @Test
+  @SneakyThrows
+  void testExplain() {
+    SparkScan sparkScan = new SparkScan(sparkClient);
+    sparkScan.getRequest().setSql(QUERY);
+    assertEquals(
+        "SparkQueryRequest(sql=select 1)",
+        sparkScan.explain());
+  }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java
index 7adcc725fa..d42e123678 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java
@@ -7,7 +7,6 @@
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
-import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
@@ -20,7 +19,6 @@
 import org.opensearch.sql.expression.function.FunctionResolver;
 import org.opensearch.sql.spark.client.SparkClient;
 import org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver;
-import org.opensearch.sql.storage.Table;
 
 @ExtendWith(MockitoExtension.class)
 public class SparkStorageEngineTest {
diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java
index 4142cfe355..c68adf2039 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java
@@ -5,6 +5,9 @@
 
 package org.opensearch.sql.spark.storage;
 
+import static org.opensearch.sql.spark.constants.TestConstants.EMR_CLUSTER_ID;
+
+import java.security.InvalidParameterException;
 import java.util.HashMap;
 import lombok.SneakyThrows;
 import org.junit.jupiter.api.Assertions;
@@ -37,18 +40,140 @@ void testGetConnectorType() {
   @Test
   @SneakyThrows
   void testGetStorageEngine() {
+    HashMap<String, String> properties = new HashMap<>();
+    properties.put("spark.connector", "emr");
+    properties.put("emr.cluster", EMR_CLUSTER_ID);
+    properties.put("emr.auth.type", "awssigv4");
+    properties.put("emr.auth.access_key", "access_key");
+    properties.put("emr.auth.secret_key", "secret_key");
+    properties.put("emr.auth.region", "region");
     SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
     StorageEngine storageEngine
-        = sparkStorageFactory.getStorageEngine(new HashMap<>());
+        = sparkStorageFactory.getStorageEngine(properties);
     Assertions.assertTrue(storageEngine instanceof SparkStorageEngine);
   }
 
   @Test
-  void createDataSourceSuccessWithLocalhost() {
+  @SneakyThrows
+  void testInvalidConnectorType() {
+    HashMap<String, String> properties = new HashMap<>();
+    properties.put("spark.connector", "random");
+    SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
+    InvalidParameterException exception = Assertions.assertThrows(InvalidParameterException.class,
+        () -> sparkStorageFactory.getStorageEngine(properties));
+    Assertions.assertEquals("Spark connector type is invalid.",
+        exception.getMessage());
+  }
+
+  @Test
+  @SneakyThrows
+  void testMissingAuth() {
+    HashMap<String, String> properties = new HashMap<>();
+    properties.put("spark.connector", "emr");
+    properties.put("emr.cluster", EMR_CLUSTER_ID);
+    SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
+    IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class,
+        () -> sparkStorageFactory.getStorageEngine(properties));
+    Assertions.assertEquals("EMR config properties are missing.",
+        exception.getMessage());
+  }
+
+  @Test
+  @SneakyThrows
+  void testUnsupportedEmrAuth() {
+    HashMap<String, String> properties = new HashMap<>();
+    properties.put("spark.connector", "emr");
+    properties.put("emr.cluster", EMR_CLUSTER_ID);
+    properties.put("emr.auth.type", "basic");
+    SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
+    IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class,
+        () -> sparkStorageFactory.getStorageEngine(properties));
+    Assertions.assertEquals("Invalid auth type.",
+        exception.getMessage());
+  }
+
+  @Test
+  @SneakyThrows
+  void testMissingCluster() {
+    HashMap<String, String> properties = new HashMap<>();
+    properties.put("spark.connector", "emr");
+    properties.put("emr.auth.type", "awssigv4");
+    SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
+    IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class,
+        () -> sparkStorageFactory.getStorageEngine(properties));
+    Assertions.assertEquals("EMR config properties are missing.",
+        exception.getMessage());
+  }
+
+  @Test
+  @SneakyThrows
+  void testMissingAuthKeys() {
+    HashMap<String, String> properties = new HashMap<>();
+    properties.put("spark.connector", "emr");
+    properties.put("emr.cluster", EMR_CLUSTER_ID);
+    properties.put("emr.auth.type", "awssigv4");
+    SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
+    IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class,
+        () -> sparkStorageFactory.getStorageEngine(properties));
+    Assertions.assertEquals("EMR auth keys are missing.",
+        exception.getMessage());
+  }
+
+  @Test
+  @SneakyThrows
+  void testMissingAuthSecretKey() {
+    HashMap<String, String> properties = new HashMap<>();
+    properties.put("spark.connector", "emr");
+    properties.put("emr.cluster", EMR_CLUSTER_ID);
+    properties.put("emr.auth.type", "awssigv4");
+    properties.put("emr.auth.access_key", "test");
+    SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
+    IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class,
+        () -> sparkStorageFactory.getStorageEngine(properties));
+    Assertions.assertEquals("EMR auth keys are missing.",
+        exception.getMessage());
+  }
+
+  @Test
+  void testCreateDataSourceSuccess() {
+    HashMap<String, String> properties = new HashMap<>();
+    properties.put("spark.connector", "emr");
+    properties.put("emr.cluster", EMR_CLUSTER_ID);
+    properties.put("emr.auth.type", "awssigv4");
+    properties.put("emr.auth.access_key", "access_key");
+    properties.put("emr.auth.secret_key", "secret_key");
+    properties.put("emr.auth.region", "region");
+    properties.put("spark.datasource.flint.host", "localhost");
+    properties.put("spark.datasource.flint.port", "9200");
+    properties.put("spark.datasource.flint.scheme", "http");
+    properties.put("spark.datasource.flint.auth", "false");
+    properties.put("spark.datasource.flint.region", "us-west-2");
+
+    DataSourceMetadata metadata = new DataSourceMetadata();
+    metadata.setName("spark");
+    metadata.setConnector(DataSourceType.SPARK);
+    metadata.setProperties(properties);
+
+    DataSource dataSource = new SparkStorageFactory(client, settings).createDataSource(metadata);
+    Assertions.assertTrue(dataSource.getStorageEngine() instanceof SparkStorageEngine);
+  }
+
+  @Test
+  void testSetSparkJars() {
+    HashMap<String, String> properties = new HashMap<>();
+    properties.put("spark.connector", "emr");
+    properties.put("spark.sql.application", "s3://spark/spark-sql-job.jar");
+    properties.put("emr.cluster", EMR_CLUSTER_ID);
+    properties.put("emr.auth.type", "awssigv4");
+    properties.put("emr.auth.access_key", "access_key");
+    properties.put("emr.auth.secret_key", "secret_key");
+    properties.put("emr.auth.region", "region");
+    properties.put("spark.datasource.flint.integration", "s3://spark/flint-spark-integration.jar");
+
     DataSourceMetadata metadata = new DataSourceMetadata();
     metadata.setName("spark");
     metadata.setConnector(DataSourceType.SPARK);
-    metadata.setProperties(new HashMap<>());
+    metadata.setProperties(properties);
 
     DataSource dataSource = new SparkStorageFactory(client, settings).createDataSource(metadata);
     Assertions.assertTrue(dataSource.getStorageEngine() instanceof SparkStorageEngine);
diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java
index d3487d65c1..39bd2eb199 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java
@@ -7,16 +7,13 @@
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
-import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
-import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project;
-import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation;
+import static org.opensearch.sql.spark.constants.TestConstants.QUERY;
 
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 import lombok.SneakyThrows;
 import org.junit.jupiter.api.Assertions;
@@ -25,10 +22,10 @@
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.sql.data.type.ExprType;
-import org.opensearch.sql.expression.NamedExpression;
 import org.opensearch.sql.planner.physical.PhysicalPlan;
 import org.opensearch.sql.spark.client.SparkClient;
 import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanBuilder;
+import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanOperator;
 import org.opensearch.sql.spark.request.SparkQueryRequest;
 import org.opensearch.sql.storage.read.TableScanBuilder;
 
@@ -51,7 +48,7 @@ void testUnsupportedOperation() {
 
   @Test
   void testCreateScanBuilderWithSqlTableFunction() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
-    sparkQueryRequest.setSql("select 1");
+    sparkQueryRequest.setSql(QUERY);
     SparkTable sparkTable = new SparkTable(client, sparkQueryRequest);
     TableScanBuilder tableScanBuilder = sparkTable.createScanBuilder();
@@ -75,13 +72,11 @@ void testGetFieldTypesFromSparkQueryRequest() {
   @Test
   void testImplementWithSqlFunction() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
-    sparkQueryRequest.setSql("select 1");
-    SparkTable sparkTable =
+    sparkQueryRequest.setSql(QUERY);
+    SparkTable sparkMetricTable =
         new SparkTable(client, sparkQueryRequest);
-    List<NamedExpression> finalProjectList = new ArrayList<>();
-    PhysicalPlan plan = sparkTable.implement(
-        project(relation("sql", sparkTable),
-            finalProjectList, null));
-    assertNull(plan);
+    PhysicalPlan plan = sparkMetricTable.implement(
+        new SparkSqlFunctionTableScanBuilder(client, sparkQueryRequest));
+    assertTrue(plan instanceof SparkSqlFunctionTableScanOperator);
   }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/utils/TestUtils.java b/spark/src/test/java/org/opensearch/sql/spark/utils/TestUtils.java
new file mode 100644
index 0000000000..0630a85096
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/utils/TestUtils.java
@@ -0,0 +1,26 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.utils;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class TestUtils {
+
+  /**
+   * Get Json document from the files in resources folder.
+   * @param filename filename.
+   * @return String.
+   * @throws IOException IOException.
+   */
+  public static String getJson(String filename) throws IOException {
+    ClassLoader classLoader = TestUtils.class.getClassLoader();
+    return new String(
+        Objects.requireNonNull(classLoader.getResourceAsStream(filename)).readAllBytes());
+  }
+
+}
+
diff --git a/spark/src/test/resources/all_data_type.json b/spark/src/test/resources/all_data_type.json
new file mode 100644
index 0000000000..a046912319
--- /dev/null
+++ b/spark/src/test/resources/all_data_type.json
@@ -0,0 +1,22 @@
+{
+  "data": {
+    "result": [
+      "{'boolean':true,'long':922337203,'integer':2147483647,'short':32767,'byte':127,'double':9223372036854.775807,'float':21474.83647,'timestamp':'2023-07-01 10:31:30','date':'2023-07-01 10:31:30','string':'ABC','char':'A'}"
+    ],
+    "schema": [
+      "{'column_name':'boolean','data_type':'boolean'}",
+      "{'column_name':'long','data_type':'long'}",
+      "{'column_name':'integer','data_type':'integer'}",
+      "{'column_name':'short','data_type':'short'}",
+      "{'column_name':'byte','data_type':'byte'}",
+      "{'column_name':'double','data_type':'double'}",
+      "{'column_name':'float','data_type':'float'}",
+      "{'column_name':'timestamp','data_type':'timestamp'}",
+      "{'column_name':'date','data_type':'date'}",
+      "{'column_name':'string','data_type':'string'}",
+      "{'column_name':'char','data_type':'char'}"
+    ],
+    "stepId": "s-123456789",
+    "applicationId": "application-abc"
+  }
+}
diff --git a/spark/src/test/resources/invalid_data_type.json b/spark/src/test/resources/invalid_data_type.json
new file mode 100644
index 0000000000..0eb08423c8
--- /dev/null
+++ b/spark/src/test/resources/invalid_data_type.json
@@ -0,0 +1,12 @@
+{
+  "data": {
+    "result": [
+      "{'struct_column':'struct_value'}"
+    ],
+    "schema": [
+      "{'column_name':'struct_column','data_type':'struct'}"
+    ],
+    "stepId": "s-123456789",
+    "applicationId": "application-abc"
+  }
+}
diff --git a/spark/src/test/resources/invalid_response.json b/spark/src/test/resources/invalid_response.json
new file mode 100644
index 0000000000..53222e0560
--- /dev/null
+++ b/spark/src/test/resources/invalid_response.json
@@ -0,0 +1,12 @@
+{
+  "content": {
+    "result": [
+      "{'1':1}"
+    ],
+    "schema": [
+      "{'column_name':'1','data_type':'integer'}"
+    ],
+    "stepId": "s-123456789",
+    "applicationId": "application-abc"
+  }
+}
diff --git a/spark/src/test/resources/select_query_response.json b/spark/src/test/resources/select_query_response.json
new file mode 100644
index 0000000000..24cb06b49e
--- /dev/null
+++ b/spark/src/test/resources/select_query_response.json
@@ -0,0 +1,12 @@
+{
+  "data": {
+    "result": [
+      "{'1':1}"
+    ],
+    "schema": [
+      "{'column_name':'1','data_type':'integer'}"
+    ],
+    "stepId": "s-123456789",
+    "applicationId": "application-abc"
+  }
+}