diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java index c170040718..3564fa9552 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProvider.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.dispatcher; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.utils.IDUtils; @@ -12,7 +13,9 @@ public class DatasourceEmbeddedQueryIdProvider implements QueryIdProvider { @Override - public String getQueryId(DispatchQueryRequest dispatchQueryRequest) { + public String getQueryId( + DispatchQueryRequest dispatchQueryRequest, + AsyncQueryRequestContext asyncQueryRequestContext) { return IDUtils.encode(dispatchQueryRequest.getDatasource()); } } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java index 2167eb6b7a..a108ca1209 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/QueryIdProvider.java @@ -5,9 +5,11 @@ package org.opensearch.sql.spark.dispatcher; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; /** Interface for extension point to specify queryId. Called when new query is executed. */ public interface QueryIdProvider { - String getQueryId(DispatchQueryRequest dispatchQueryRequest); + String getQueryId( + DispatchQueryRequest dispatchQueryRequest, AsyncQueryRequestContext asyncQueryRequestContext); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 0061ea7179..a424db4c34 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -69,7 +69,8 @@ private DispatchQueryResponse handleFlintExtensionQuery( DataSourceMetadata dataSourceMetadata) { IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest); DispatchQueryContext context = - getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + getDefaultDispatchContextBuilder( + dispatchQueryRequest, dataSourceMetadata, asyncQueryRequestContext) .indexQueryDetails(indexQueryDetails) .asyncQueryRequestContext(asyncQueryRequestContext) .build(); @@ -84,7 +85,8 @@ private DispatchQueryResponse handleDefaultQuery( DataSourceMetadata dataSourceMetadata) { DispatchQueryContext context = - getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) + getDefaultDispatchContextBuilder( + dispatchQueryRequest, dataSourceMetadata, asyncQueryRequestContext) .asyncQueryRequestContext(asyncQueryRequestContext) .build(); @@ -93,11 +95,13 @@ private DispatchQueryResponse handleDefaultQuery( } private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder( - DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) { + DispatchQueryRequest dispatchQueryRequest, + DataSourceMetadata dataSourceMetadata, + AsyncQueryRequestContext asyncQueryRequestContext) { return DispatchQueryContext.builder() .dataSourceMetadata(dataSourceMetadata) .tags(getDefaultTagsForJobSubmission(dispatchQueryRequest)) - .queryId(queryIdProvider.getQueryId(dispatchQueryRequest)); + .queryId(queryIdProvider.getQueryId(dispatchQueryRequest, asyncQueryRequestContext)); } private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery( diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index 34ededc74d..d82d3bdab7 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -185,7 +185,7 @@ public void setUp() { public void createDropIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); String indexName = "flint_datasource_name_table_name_index_name_index"; givenFlintIndexMetadataExists(indexName); givenCancelJobRunSucceed(); @@ -209,7 +209,7 @@ public void createDropIndexQuery() { public void createVacuumIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); String indexName = "flint_datasource_name_table_name_index_name_index"; givenFlintIndexMetadataExists(indexName); @@ -231,7 +231,7 @@ public void createVacuumIndexQuery() { public void createAlterIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); String indexName = "flint_datasource_name_table_name_index_name_index"; givenFlintIndexMetadataExists(indexName); givenCancelJobRunSucceed(); @@ -261,7 +261,7 @@ public void createAlterIndexQuery() { public void createStreamingQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); when(awsemrServerless.startJobRun(any())) .thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID)); @@ -297,7 +297,7 @@ private void verifyStartJobRunCalled() { public void createCreateIndexQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); when(awsemrServerless.startJobRun(any())) .thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID)); @@ -321,7 +321,7 @@ public void createCreateIndexQuery() { public void createRefreshQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); when(awsemrServerless.startJobRun(any())) .thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID)); @@ -344,7 +344,7 @@ public void createInteractiveQuery() { givenSparkExecutionEngineConfigIsSupplied(); givenValidDataSourceMetadataExist(); givenSessionExists(); - when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); when(sessionIdProvider.getSessionId(any())).thenReturn(SESSION_ID); givenSessionExists(); // called twice when(awsemrServerless.startJobRun(any())) @@ -538,7 +538,8 @@ private void givenGetJobRunReturnJobRunWithState(String state) { } private void verifyGetQueryIdCalled() { - verify(queryIdProvider).getQueryId(dispatchQueryRequestArgumentCaptor.capture()); + verify(queryIdProvider) + .getQueryId(dispatchQueryRequestArgumentCaptor.capture(), eq(asyncQueryRequestContext)); DispatchQueryRequest dispatchQueryRequest = dispatchQueryRequestArgumentCaptor.getValue(); assertEquals(ACCOUNT_ID, dispatchQueryRequest.getAccountId()); assertEquals(APPLICATION_ID, dispatchQueryRequest.getApplicationId()); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProviderTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProviderTest.java new file mode 100644 index 0000000000..7f1c92dff3 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/DatasourceEmbeddedQueryIdProviderTest.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.Mockito.verifyNoInteractions; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; + +@ExtendWith(MockitoExtension.class) +class DatasourceEmbeddedQueryIdProviderTest { + @Mock AsyncQueryRequestContext asyncQueryRequestContext; + + DatasourceEmbeddedQueryIdProvider datasourceEmbeddedQueryIdProvider = + new DatasourceEmbeddedQueryIdProvider(); + + @Test + public void test() { + String queryId = + datasourceEmbeddedQueryIdProvider.getQueryId( + DispatchQueryRequest.builder().datasource("DATASOURCE").build(), + asyncQueryRequestContext); + + assertNotNull(queryId); + verifyNoInteractions(asyncQueryRequestContext); + } +}