diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index f2d205ffc6..99d4cc722e 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -7,9 +7,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.verify; @@ -77,6 +77,8 @@ import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.metrics.MetricsService; +import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; +import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; @@ -110,6 +112,7 @@ public class AsyncQueryCoreIntegTest { @Mock FlintIndexClient flintIndexClient; @Mock AsyncQueryRequestContext asyncQueryRequestContext; @Mock MetricsService metricsService; + @Mock SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider; // storage services @Mock AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService; @@ -134,6 +137,16 @@ public class AsyncQueryCoreIntegTest { public void setUp() { emrServerlessClientFactory = (accountId) -> new EmrServerlessClientImpl(awsemrServerless, metricsService); + SparkParameterComposerCollection collection = new SparkParameterComposerCollection(); + collection.register( + DataSourceType.S3GLUE, + (dataSourceMetadata, sparkSubmitParameters, dispatchQueryRequest, context) -> + sparkSubmitParameters.setConfigItem( + "key.from.datasource.composer", "value.from.datasource.composer")); + collection.register( + (sparkSubmitParameters, dispatchQueryRequest, context) -> + sparkSubmitParameters.setConfigItem( + "key.from.generic.composer", "value.from.generic.composer")); SessionManager sessionManager = new SessionManager( sessionStorageService, @@ -156,7 +169,8 @@ public void setUp() { indexDMLResultStorageService, flintIndexOpFactory, emrServerlessClientFactory, - metricsService); + metricsService, + new SparkSubmitParametersBuilderProvider(collection)); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); @@ -272,7 +286,11 @@ private void verifyStartJobRunCalled() { verify(awsemrServerless).startJobRun(startJobRunRequestArgumentCaptor.capture()); StartJobRunRequest startJobRunRequest = startJobRunRequestArgumentCaptor.getValue(); assertEquals(APPLICATION_ID, startJobRunRequest.getApplicationId()); - assertNotNull(startJobRunRequest.getJobDriver().getSparkSubmit().getSparkSubmitParameters()); + String submitParameters = + startJobRunRequest.getJobDriver().getSparkSubmit().getSparkSubmitParameters(); + assertTrue( + submitParameters.contains("key.from.datasource.composer=value.from.datasource.composer")); + assertTrue(submitParameters.contains("key.from.generic.composer=value.from.generic.composer")); } @Test