Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.11] Read extra Spark submit parameters from cluster settings #2236

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
import static org.mockito.Mockito.when;
import static org.opensearch.common.unit.TimeValue.timeValueMinutes;
import static org.opensearch.sql.opensearch.setting.LegacyOpenDistroSettings.legacySettings;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.METRICS_ROLLING_INTERVAL_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.METRICS_ROLLING_WINDOW_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.PPL_ENABLED_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.QUERY_MEMORY_LIMIT_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.QUERY_SIZE_LIMIT_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_ENGINE_CONFIG;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SQL_CURSOR_KEEP_ALIVE_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SQL_ENABLED_SETTING;
import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SQL_SLOWLOG_SETTING;

import java.util.List;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -47,14 +56,13 @@ void getSettingValue() {
@Test
void getSettingValueWithPresetValuesInYml() {
when(clusterSettings.get(ClusterName.CLUSTER_NAME_SETTING)).thenReturn(ClusterName.DEFAULT);
when(clusterSettings.get(
(Setting<ByteSizeValue>) OpenSearchSettings.QUERY_MEMORY_LIMIT_SETTING))
when(clusterSettings.get((Setting<ByteSizeValue>) QUERY_MEMORY_LIMIT_SETTING))
.thenReturn(new ByteSizeValue(20));
when(clusterSettings.get(
not(
or(
eq(ClusterName.CLUSTER_NAME_SETTING),
eq((Setting<ByteSizeValue>) OpenSearchSettings.QUERY_MEMORY_LIMIT_SETTING)))))
eq((Setting<ByteSizeValue>) QUERY_MEMORY_LIMIT_SETTING)))))
.thenReturn(null);
OpenSearchSettings settings = new OpenSearchSettings(clusterSettings);
ByteSizeValue sizeValue = settings.getSettingValue(Settings.Key.QUERY_MEMORY_LIMIT);
Expand Down Expand Up @@ -150,21 +158,41 @@ public void updateLegacySettingsFallback() {
.put(LegacySettings.Key.METRICS_ROLLING_INTERVAL.getKeyValue(), 100L)
.build();

assertEquals(OpenSearchSettings.SQL_ENABLED_SETTING.get(settings), false);
assertEquals(OpenSearchSettings.SQL_SLOWLOG_SETTING.get(settings), 10);
assertEquals(
OpenSearchSettings.SQL_CURSOR_KEEP_ALIVE_SETTING.get(settings), timeValueMinutes(1));
assertEquals(OpenSearchSettings.PPL_ENABLED_SETTING.get(settings), true);
assertEquals(SQL_ENABLED_SETTING.get(settings), false);
assertEquals(SQL_SLOWLOG_SETTING.get(settings), 10);
assertEquals(SQL_CURSOR_KEEP_ALIVE_SETTING.get(settings), timeValueMinutes(1));
assertEquals(PPL_ENABLED_SETTING.get(settings), true);
assertEquals(
OpenSearchSettings.QUERY_MEMORY_LIMIT_SETTING.get(settings),
QUERY_MEMORY_LIMIT_SETTING.get(settings),
new ByteSizeValue((int) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.2)));
assertEquals(OpenSearchSettings.QUERY_SIZE_LIMIT_SETTING.get(settings), 100);
assertEquals(OpenSearchSettings.METRICS_ROLLING_WINDOW_SETTING.get(settings), 2000L);
assertEquals(OpenSearchSettings.METRICS_ROLLING_INTERVAL_SETTING.get(settings), 100L);
assertEquals(QUERY_SIZE_LIMIT_SETTING.get(settings), 100);
assertEquals(METRICS_ROLLING_WINDOW_SETTING.get(settings), 2000L);
assertEquals(METRICS_ROLLING_INTERVAL_SETTING.get(settings), 100L);
}

@Test
void legacySettingsShouldBeDeprecatedBeforeRemove() {
assertEquals(15, legacySettings().size());
}

@Test
void getSparkExecutionEngineConfigSetting() {
// Default is empty string
assertEquals(
"",
SPARK_EXECUTION_ENGINE_CONFIG.get(
org.opensearch.common.settings.Settings.builder().build()));

// Configurable at runtime
String sparkConfig =
"{\n"
+ " \"sparkSubmitParameters\": \"--conf spark.dynamicAllocation.enabled=false\"\n"
+ "}";
assertEquals(
sparkConfig,
SPARK_EXECUTION_ENGINE_CONFIG.get(
org.opensearch.common.settings.Settings.builder()
.put(SPARK_EXECUTION_ENGINE_CONFIG.getKey(), sparkConfig)
.build()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ public CreateAsyncQueryResponse createAsyncQuery(
createAsyncQueryRequest.getDatasource(),
createAsyncQueryRequest.getLang(),
sparkExecutionEngineConfig.getExecutionRoleARN(),
clusterName.value()));

clusterName.value(),
sparkExecutionEngineConfig.getSparkSubmitParameters()));
asyncQueryJobMetadataStorageService.storeJobMetadata(
new AsyncQueryJobMetadata(
sparkExecutionEngineConfig.getApplicationId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Supplier;
import lombok.AllArgsConstructor;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceType;
import org.opensearch.sql.datasources.auth.AuthenticationType;

/** Define Spark Submit Parameters. */
@AllArgsConstructor
@RequiredArgsConstructor
public class SparkSubmitParameters {
public static final String SPACE = " ";
Expand All @@ -32,10 +34,14 @@ public class SparkSubmitParameters {
private final String className;
private final Map<String, String> config;

/** Extra parameters to append finally */
private String extraParameters;

public static class Builder {

private final String className;
private final Map<String, String> config;
private String extraParameters;

private Builder() {
className = DEFAULT_CLASS_NAME;
Expand Down Expand Up @@ -130,8 +136,13 @@ public Builder structuredStreaming(Boolean isStructuredStreaming) {
return this;
}

public Builder extraParameters(String params) {
extraParameters = params;
return this;
}

public SparkSubmitParameters build() {
return new SparkSubmitParameters(className, config);
return new SparkSubmitParameters(className, config, extraParameters);
}
}

Expand All @@ -148,6 +159,10 @@ public String toString() {
stringBuilder.append(config.get(key));
stringBuilder.append(SPACE);
}

if (extraParameters != null) {
stringBuilder.append(extraParameters);
}
return stringBuilder.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ public class SparkExecutionEngineConfig {
private String region;
private String executionRoleARN;

/** Additional Spark submit parameters to append to request. */
private String sparkSubmitParameters;

public static SparkExecutionEngineConfig toSparkExecutionEngineConfig(String jsonString) {
return new Gson().fromJson(jsonString, SparkExecutionEngineConfig.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ private DispatchQueryResponse handleIndexQuery(
dataSourceService.getRawDataSourceMetadata(
dispatchQueryRequest.getDatasource()))
.structuredStreaming(indexDetails.getAutoRefresh())
.extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
.build()
.toString(),
tags,
Expand All @@ -170,6 +171,7 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
.dataSource(
dataSourceService.getRawDataSourceMetadata(
dispatchQueryRequest.getDatasource()))
.extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
.build()
.toString(),
tags,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,22 @@

package org.opensearch.sql.spark.dispatcher.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.rest.model.LangType;

@AllArgsConstructor
@Data
@RequiredArgsConstructor // required explicitly
public class DispatchQueryRequest {
private final String applicationId;
private final String query;
private final String datasource;
private final LangType langType;
private final String executionRoleARN;
private final String clusterName;

/** Optional extra Spark submit parameters to include in final request */
private String extraSparkSubmitParams;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.spark.asyncquery;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
Expand All @@ -20,6 +22,7 @@
import java.util.Optional;
import org.json.JSONObject;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand All @@ -43,11 +46,17 @@ public class AsyncQueryExecutorServiceImplTest {
@Mock private AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService;
@Mock private Settings settings;

@Test
void testCreateAsyncQuery() {
AsyncQueryExecutorServiceImpl jobExecutorService =
private AsyncQueryExecutorService jobExecutorService;

@BeforeEach
void setUp() {
jobExecutorService =
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
}

@Test
void testCreateAsyncQuery() {
CreateAsyncQueryRequest createAsyncQueryRequest =
new CreateAsyncQueryRequest(
"select * from my_glue.default.http_logs", "my_glue", LangType.SQL);
Expand Down Expand Up @@ -83,11 +92,36 @@ void testCreateAsyncQuery() {
Assertions.assertEquals(EMR_JOB_ID, createAsyncQueryResponse.getQueryId());
}

@Test
void testCreateAsyncQueryWithExtraSparkSubmitParameter() {
when(settings.getSettingValue(Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG))
.thenReturn(
"{"
+ "\"applicationId\": \"00fd775baqpu4g0p\","
+ "\"executionRoleARN\": \"arn:aws:iam::270824043731:role/emr-job-execution-role\","
+ "\"region\": \"eu-west-1\","
+ "\"sparkSubmitParameters\": \"--conf spark.dynamicAllocation.enabled=false\""
+ "}");
when(settings.getSettingValue(Settings.Key.CLUSTER_NAME))
.thenReturn(new ClusterName(TEST_CLUSTER_NAME));
when(sparkQueryDispatcher.dispatch(any()))
.thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null));

jobExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest(
"select * from my_glue.default.http_logs", "my_glue", LangType.SQL));

verify(sparkQueryDispatcher, times(1))
.dispatch(
argThat(
actualReq ->
actualReq
.getExtraSparkSubmitParams()
.equals("--conf spark.dynamicAllocation.enabled=false")));
}

@Test
void testGetAsyncQueryResultsWithJobNotFoundException() {
AsyncQueryExecutorServiceImpl jobExecutorService =
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
.thenReturn(Optional.empty());
AsyncQueryNotFoundException asyncQueryNotFoundException =
Expand All @@ -102,9 +136,6 @@ void testGetAsyncQueryResultsWithJobNotFoundException() {

@Test
void testGetAsyncQueryResultsWithInProgressJob() {
AsyncQueryExecutorServiceImpl jobExecutorService =
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
.thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
JSONObject jobResult = new JSONObject();
Expand All @@ -131,9 +162,6 @@ void testGetAsyncQueryResultsWithSuccessJob() throws IOException {
new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
.thenReturn(jobResult);

AsyncQueryExecutorServiceImpl jobExecutorService =
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
AsyncQueryExecutionResponse asyncQueryExecutionResponse =
jobExecutorService.getAsyncQueryResults(EMR_JOB_ID);

Expand Down Expand Up @@ -164,15 +192,11 @@ void testGetAsyncQueryResultsWithDisabledExecutionEngine() {

@Test
void testCancelJobWithJobNotFound() {
AsyncQueryExecutorService asyncQueryExecutorService =
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
.thenReturn(Optional.empty());
AsyncQueryNotFoundException asyncQueryNotFoundException =
Assertions.assertThrows(
AsyncQueryNotFoundException.class,
() -> asyncQueryExecutorService.cancelQuery(EMR_JOB_ID));
AsyncQueryNotFoundException.class, () -> jobExecutorService.cancelQuery(EMR_JOB_ID));
Assertions.assertEquals(
"QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage());
verifyNoInteractions(sparkQueryDispatcher);
Expand All @@ -181,15 +205,12 @@ void testCancelJobWithJobNotFound() {

@Test
void testCancelJob() {
AsyncQueryExecutorService asyncQueryExecutorService =
new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, settings);
when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
.thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
when(sparkQueryDispatcher.cancelJob(
new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
.thenReturn(EMR_JOB_ID);
String jobId = asyncQueryExecutorService.cancelQuery(EMR_JOB_ID);
String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID);
Assertions.assertEquals(EMR_JOB_ID, jobId);
verifyNoInteractions(settings);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.asyncquery.model;

import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import org.junit.jupiter.api.Test;

public class SparkSubmitParametersTest {

@Test
public void testBuildWithoutExtraParameters() {
String params = SparkSubmitParameters.Builder.builder().build().toString();

assertNotNull(params);
}

@Test
public void testBuildWithExtraParameters() {
String params =
SparkSubmitParameters.Builder.builder().extraParameters("--conf A=1").build().toString();

// Assert the conf is included with a space
assertTrue(params.endsWith(" --conf A=1"));
}
}
Loading