Skip to content

Commit

Permalink
Read extra Spark submit parameters from cluster settings (#2219)
Browse files Browse the repository at this point in the history
* Add default setting for Spark execution engine

Signed-off-by: Chen Dai <[email protected]>

* Pass extra parameters to Spark dispatcher

Signed-off-by: Chen Dai <[email protected]>

* Wrap read default setting file with previlege action

Signed-off-by: Chen Dai <[email protected]>

* Fix spotless format

Signed-off-by: Chen Dai <[email protected]>

* Use input stream to read default config file

Signed-off-by: Chen Dai <[email protected]>

* Add UT for dispatcher

Signed-off-by: Chen Dai <[email protected]>

* Add more UT

Signed-off-by: Chen Dai <[email protected]>

* Remove default config setting

Signed-off-by: Chen Dai <[email protected]>

* Fix spotless check in spark module

Signed-off-by: Chen Dai <[email protected]>

* Refactor test code

Signed-off-by: Chen Dai <[email protected]>

* Add more UT on config class

Signed-off-by: Chen Dai <[email protected]>

---------

Signed-off-by: Chen Dai <[email protected]>
  • Loading branch information
dai-chen authored Oct 5, 2023
1 parent 45da40f commit 492982c
Show file tree
Hide file tree
Showing 10 changed files with 251 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
import static org.mockito.Mockito.when;
import static org.opensearch.common.unit.TimeValue.timeValueMinutes;
import static org.opensearch.sql.opensearch.setting.LegacyOpenDistroSettings.legacySettings;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.METRICS_ROLLING_INTERVAL_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.METRICS_ROLLING_WINDOW_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.PPL_ENABLED_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.QUERY_MEMORY_LIMIT_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.QUERY_SIZE_LIMIT_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_ENGINE_CONFIG;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SQL_CURSOR_KEEP_ALIVE_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SQL_ENABLED_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SQL_SLOWLOG_SETTING;

import java.util.List;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -47,14 +56,13 @@ void getSettingValue() {
@Test
void getSettingValueWithPresetValuesInYml() {
when(clusterSettings.get(ClusterName.CLUSTER_NAME_SETTING)).thenReturn(ClusterName.DEFAULT);
when(clusterSettings.get(
(Setting<ByteSizeValue>) OpenSearchSettings.QUERY_MEMORY_LIMIT_SETTING))
when(clusterSettings.get((Setting<ByteSizeValue>) QUERY_MEMORY_LIMIT_SETTING))
.thenReturn(new ByteSizeValue(20));
when(clusterSettings.get(
not(
or(
eq(ClusterName.CLUSTER_NAME_SETTING),
eq((Setting<ByteSizeValue>) OpenSearchSettings.QUERY_MEMORY_LIMIT_SETTING)))))
eq((Setting<ByteSizeValue>) QUERY_MEMORY_LIMIT_SETTING)))))
.thenReturn(null);
OpenSearchSettings settings = new OpenSearchSettings(clusterSettings);
ByteSizeValue sizeValue = settings.getSettingValue(Settings.Key.QUERY_MEMORY_LIMIT);
Expand Down Expand Up @@ -150,21 +158,41 @@ public void updateLegacySettingsFallback() {
.put(LegacySettings.Key.METRICS_ROLLING_INTERVAL.getKeyValue(), 100L)
.build();

assertEquals(OpenSearchSettings.SQL_ENABLED_SETTING.get(settings), false);
assertEquals(OpenSearchSettings.SQL_SLOWLOG_SETTING.get(settings), 10);
assertEquals(
OpenSearchSettings.SQL_CURSOR_KEEP_ALIVE_SETTING.get(settings), timeValueMinutes(1));
assertEquals(OpenSearchSettings.PPL_ENABLED_SETTING.get(settings), true);
assertEquals(SQL_ENABLED_SETTING.get(settings), false);
assertEquals(SQL_SLOWLOG_SETTING.get(settings), 10);
assertEquals(SQL_CURSOR_KEEP_ALIVE_SETTING.get(settings), timeValueMinutes(1));
assertEquals(PPL_ENABLED_SETTING.get(settings), true);
assertEquals(
OpenSearchSettings.QUERY_MEMORY_LIMIT_SETTING.get(settings),
QUERY_MEMORY_LIMIT_SETTING.get(settings),
new ByteSizeValue((int) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.2)));
assertEquals(OpenSearchSettings.QUERY_SIZE_LIMIT_SETTING.get(settings), 100);
assertEquals(OpenSearchSettings.METRICS_ROLLING_WINDOW_SETTING.get(settings), 2000L);
assertEquals(OpenSearchSettings.METRICS_ROLLING_INTERVAL_SETTING.get(settings), 100L);
assertEquals(QUERY_SIZE_LIMIT_SETTING.get(settings), 100);
assertEquals(METRICS_ROLLING_WINDOW_SETTING.get(settings), 2000L);
assertEquals(METRICS_ROLLING_INTERVAL_SETTING.get(settings), 100L);
}

@Test
void legacySettingsShouldBeDeprecatedBeforeRemove() {
assertEquals(15, legacySettings().size());
}

@Test
void getSparkExecutionEngineConfigSetting() {
// Default is empty string
assertEquals(
"",
SPARK_EXECUTION_ENGINE_CONFIG.get(
org.opensearch.common.settings.Settings.builder().build()));

// Configurable at runtime
String sparkConfig =
"{\n"
+ " \"sparkSubmitParameters\": \"--conf spark.dynamicAllocation.enabled=false\"\n"
+ "}";
assertEquals(
sparkConfig,
SPARK_EXECUTION_ENGINE_CONFIG.get(
org.opensearch.common.settings.Settings.builder()
.put(SPARK_EXECUTION_ENGINE_CONFIG.getKey(), sparkConfig)
.build()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ public CreateAsyncQueryResponse createAsyncQuery(
createAsyncQueryRequest.getDatasource(),
createAsyncQueryRequest.getLang(),
sparkExecutionEngineConfig.getExecutionRoleARN(),
clusterName.value()));

clusterName.value(),
sparkExecutionEngineConfig.getSparkSubmitParameters()));
asyncQueryJobMetadataStorageService.storeJobMetadata(
new AsyncQueryJobMetadata(
sparkExecutionEngineConfig.getApplicationId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Supplier;
import lombok.AllArgsConstructor;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceType;
import org.opensearch.sql.datasources.auth.AuthenticationType;

/** Define Spark Submit Parameters. */
@AllArgsConstructor
@RequiredArgsConstructor
public class SparkSubmitParameters {
public static final String SPACE = " ";
Expand All @@ -32,10 +34,14 @@ public class SparkSubmitParameters {
private final String className;
private final Map<String, String> config;

/** Extra parameters to append finally */
private String extraParameters;

public static class Builder {

private final String className;
private final Map<String, String> config;
private String extraParameters;

private Builder() {
className = DEFAULT_CLASS_NAME;
Expand Down Expand Up @@ -130,8 +136,13 @@ public Builder structuredStreaming(Boolean isStructuredStreaming) {
return this;
}

public Builder extraParameters(String params) {
extraParameters = params;
return this;
}

public SparkSubmitParameters build() {
return new SparkSubmitParameters(className, config);
return new SparkSubmitParameters(className, config, extraParameters);
}
}

Expand All @@ -148,6 +159,10 @@ public String toString() {
stringBuilder.append(config.get(key));
stringBuilder.append(SPACE);
}

if (extraParameters != null) {
stringBuilder.append(extraParameters);
}
return stringBuilder.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ public class SparkExecutionEngineConfig {
private String region;
private String executionRoleARN;

/** Additional Spark submit parameters to append to request. */
private String sparkSubmitParameters;

public static SparkExecutionEngineConfig toSparkExecutionEngineConfig(String jsonString) {
return new Gson().fromJson(jsonString, SparkExecutionEngineConfig.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ private DispatchQueryResponse handleIndexQuery(
dataSourceService.getRawDataSourceMetadata(
dispatchQueryRequest.getDatasource()))
.structuredStreaming(indexDetails.getAutoRefresh())
.extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
.build()
.toString(),
tags,
Expand All @@ -170,6 +171,7 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
.dataSource(
dataSourceService.getRawDataSourceMetadata(
dispatchQueryRequest.getDatasource()))
.extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
.build()
.toString(),
tags,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,22 @@

package org.opensearch.sql.spark.dispatcher.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.rest.model.LangType;

@AllArgsConstructor
@Data
@RequiredArgsConstructor // required explicitly
public class DispatchQueryRequest {
private final String applicationId;
private final String query;
private final String datasource;
private final LangType langType;
private final String executionRoleARN;
private final String clusterName;

/** Optional extra Spark submit parameters to include in final request */
private String extraSparkSubmitParams;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.spark.asyncquery;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
Expand All @@ -20,6 +22,7 @@
import java.util.Optional;
import org.json.JSONObject;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand All @@ -43,11 +46,17 @@ public class AsyncQueryExecutorServiceImplTest {
@Mock private AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService;
@Mock private Settings settings;

@Test
void testCreateAsyncQuery() {
AsyncQueryExecutorServiceImpl jobExecutorService =
private AsyncQueryExecutorService jobExecutorService;

@BeforeEach
void setUp() {
jobExecutorService =
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
}

@Test
void testCreateAsyncQuery() {
CreateAsyncQueryRequest createAsyncQueryRequest =
new CreateAsyncQueryRequest(
"select * from my_glue.default.http_logs", "my_glue", LangType.SQL);
Expand Down Expand Up @@ -83,11 +92,36 @@ void testCreateAsyncQuery() {
Assertions.assertEquals(EMR_JOB_ID, createAsyncQueryResponse.getQueryId());
}

@Test
void testCreateAsyncQueryWithExtraSparkSubmitParameter() {
when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG))
.thenReturn(
"{"
+ "\"applicationId\": \"00fd775baqpu4g0p\","
+ "\"executionRoleARN\": \"arn:aws:iam::270824043731:role/emr-job-execution-role\","
+ "\"region\": \"eu-west-1\","
+ "\"sparkSubmitParameters\": \"--conf spark.dynamicAllocation.enabled=false\""
+ "}");
when(settings.getSettingValue(Settings.Key.CLUSTER_NAME))
.thenReturn(new ClusterName(TEST_CLUSTER_NAME));
when(sparkQueryDispatcher.dispatch(any()))
.thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null));

jobExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest(
"select * from my_glue.default.http_logs", "my_glue", LangType.SQL));

verify(sparkQueryDispatcher, times(1))
.dispatch(
argThat(
actualReq ->
actualReq
.getExtraSparkSubmitParams()
.equals("--conf spark.dynamicAllocation.enabled=false")));
}

@Test
void testGetAsyncQueryResultsWithJobNotFoundException() {
AsyncQueryExecutorServiceImpl jobExecutorService =
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
.thenReturn(Optional.empty());
AsyncQueryNotFoundException asyncQueryNotFoundException =
Expand All @@ -102,9 +136,6 @@ void testGetAsyncQueryResultsWithJobNotFoundException() {

@Test
void testGetAsyncQueryResultsWithInProgressJob() {
AsyncQueryExecutorServiceImpl jobExecutorService =
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
.thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
JSONObject jobResult = new JSONObject();
Expand All @@ -131,9 +162,6 @@ void testGetAsyncQueryResultsWithSuccessJob() throws IOException {
new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
.thenReturn(jobResult);

AsyncQueryExecutorServiceImpl jobExecutorService =
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
AsyncQueryExecutionResponse asyncQueryExecutionResponse =
jobExecutorService.getAsyncQueryResults(EMR_JOB_ID);

Expand Down Expand Up @@ -164,15 +192,11 @@ void testGetAsyncQueryResultsWithDisabledExecutionEngine() {

@Test
void testCancelJobWithJobNotFound() {
AsyncQueryExecutorService asyncQueryExecutorService =
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
.thenReturn(Optional.empty());
AsyncQueryNotFoundException asyncQueryNotFoundException =
Assertions.assertThrows(
AsyncQueryNotFoundException.class,
() -> asyncQueryExecutorService.cancelQuery(EMR_JOB_ID));
AsyncQueryNotFoundException.class, () -> jobExecutorService.cancelQuery(EMR_JOB_ID));
Assertions.assertEquals(
"QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage());
verifyNoInteractions(sparkQueryDispatcher);
Expand All @@ -181,15 +205,12 @@ void testCancelJobWithJobNotFound() {

@Test
void testCancelJob() {
AsyncQueryExecutorService asyncQueryExecutorService =
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
.thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
when(sparkQueryDispatcher.cancelJob(
new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
.thenReturn(EMR_JOB_ID);
String jobId = asyncQueryExecutorService.cancelQuery(EMR_JOB_ID);
String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID);
Assertions.assertEquals(EMR_JOB_ID, jobId);
verifyNoInteractions(settings);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.asyncquery.model;

import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import org.junit.jupiter.api.Test;

public class SparkSubmitParametersTest {

@Test
public void testBuildWithoutExtraParameters() {
String params = SparkSubmitParameters.Builder.builder().build().toString();

assertNotNull(params);
}

@Test
public void testBuildWithExtraParameters() {
String params =
SparkSubmitParameters.Builder.builder().extraParameters("--conf A=1").build().toString();

// Assert the conf is included with a space
assertTrue(params.endsWith(" --conf A=1"));
}
}
Loading

0 comments on commit 492982c

Please sign in to comment.