From 78a959078b090d7c8b29f0aaffcf1da1e52b276c Mon Sep 17 00:00:00 2001 From: Tomoyuki MORITA Date: Tue, 30 Jul 2024 13:51:35 -0700 Subject: [PATCH] Add RequestContext parameter to verifyDataSourceAccessAndGetRawMetada method (#2866) * Add RequestContext parameter to verifyDataSourceAccessAndGetRawMetadata method Signed-off-by: Tomoyuki Morita * Add comments Signed-off-by: Tomoyuki Morita * Fix style Signed-off-by: Tomoyuki Morita --------- Signed-off-by: Tomoyuki Morita --- .../model/AsyncQueryRequestContext.java | 6 +- .../dispatcher/SparkQueryDispatcher.java | 2 +- .../asyncquery/AsyncQueryCoreIntegTest.java | 3 +- .../dispatcher/SparkQueryDispatcherTest.java | 72 ++++++++++++------- .../sql/datasource/DataSourceService.java | 5 +- .../sql/datasource/RequestContext.java | 15 ++++ .../sql/analysis/AnalyzerTestBase.java | 4 +- .../service/DataSourceServiceImpl.java | 4 +- .../service/DataSourceServiceImplTest.java | 8 ++- 9 files changed, 84 insertions(+), 35 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/datasource/RequestContext.java diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java index 56176faefb..d5a478d592 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryRequestContext.java @@ -5,7 +5,7 @@ package org.opensearch.sql.spark.asyncquery.model; +import org.opensearch.sql.datasource.RequestContext; + /** Context interface to provide additional request related information */ -public interface AsyncQueryRequestContext { - Object getAttribute(String name); -} +public interface AsyncQueryRequestContext extends RequestContext {} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 0e871f9ddc..0061ea7179 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -44,7 +44,7 @@ public DispatchQueryResponse dispatch( AsyncQueryRequestContext asyncQueryRequestContext) { DataSourceMetadata dataSourceMetadata = this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata( - dispatchQueryRequest.getDatasource()); + dispatchQueryRequest.getDatasource(), asyncQueryRequestContext); if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { String query = dispatchQueryRequest.getQuery(); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index 99d4cc722e..34ededc74d 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -512,7 +512,8 @@ private void givenFlintIndexMetadataExists(String indexName) { } private void givenValidDataSourceMetadataExist() { - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(DATASOURCE_NAME)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + DATASOURCE_NAME, asyncQueryRequestContext)) .thenReturn( new DataSourceMetadata.Builder() .setName(DATASOURCE_NAME) diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index f9a83ef9f6..a7a79c758e 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -180,7 +180,8 @@ void testDispatchSelectQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -223,7 +224,8 @@ void testDispatchSelectQueryWithLakeFormation() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithLakeFormation(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -255,7 +257,8 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -278,7 +281,8 @@ void testDispatchSelectQueryCreateNewSession() { doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any(), any()); when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -304,7 +308,8 @@ void testDispatchSelectQueryReuseSession() { when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID); when(session.isOperationalForDataSource(any())).thenReturn(true); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -324,7 +329,8 @@ void testDispatchSelectQueryFailedCreateSession() { doReturn(true).when(sessionManager).isEnabled(); doThrow(RuntimeException.class).when(sessionManager).createSession(any(), any()); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); Assertions.assertThrows( @@ -358,7 +364,8 @@ void testDispatchCreateAutoRefreshIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -393,7 +400,8 @@ void testDispatchCreateManualRefreshIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -426,7 +434,8 @@ void testDispatchWithPPLQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -450,7 +459,8 @@ void testDispatchWithSparkUDFQuery() { "CREATE TEMPORARY FUNCTION square AS 'org.apache.spark.sql.functions.expr(\"num * num\")'"); for (String query : udfQueries) { DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); IllegalArgumentException illegalArgumentException = @@ -489,7 +499,8 @@ void testInvalidSQLQueryDispatchToSpark() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -532,7 +543,8 @@ void testDispatchQueryWithoutATableAndDataSourceName() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -568,7 +580,8 @@ void testDispatchIndexQueryWithoutADatasourceName() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -589,8 +602,7 @@ void testDispatchMaterializedViewQuery() { tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText()); String query = - "CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH" - + " (auto_refresh = true)"; + "CREATE MATERIALIZED VIEW mv_1 AS select * from logs WITH" + " (auto_refresh = true)"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming"); StartJobRequest expected = new StartJobRequest( @@ -604,7 +616,8 @@ void testDispatchMaterializedViewQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -637,7 +650,8 @@ void testDispatchShowMVQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -670,7 +684,8 @@ void testRefreshIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -703,7 +718,8 @@ void testDispatchDescribeIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -739,7 +755,8 @@ void testDispatchAlterToAutoRefreshIndexQuery() { "query_execution_result_my_glue"); when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); DispatchQueryResponse dispatchQueryResponse = @@ -762,7 +779,8 @@ void testDispatchAlterToManualRefreshIndexQuery() { "ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH" + " (auto_refresh = false)"; DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); when(queryHandlerFactory.getIndexDMLHandler()) .thenReturn( @@ -785,7 +803,8 @@ void testDispatchDropIndexQuery() { String query = "DROP INDEX elb_and_requestUri ON my_glue.default.http_logs"; DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); when(queryHandlerFactory.getIndexDMLHandler()) .thenReturn( @@ -808,7 +827,8 @@ void testDispatchVacuumIndexQuery() { String query = "VACUUM INDEX elb_and_requestUri ON my_glue.default.http_logs"; DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_glue", asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); when(queryHandlerFactory.getIndexDMLHandler()) .thenReturn( @@ -824,7 +844,8 @@ void testDispatchVacuumIndexQuery() { @Test void testDispatchWithUnSupportedDataSourceType() { - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_prometheus")) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "my_prometheus", asyncQueryRequestContext)) .thenReturn(constructPrometheusDataSourceType()); String query = "select * from my_prometheus.default.http_logs"; @@ -1018,7 +1039,8 @@ void testGetQueryResponseWithSuccess() { void testDispatchQueryWithExtraSparkSubmitParameters() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); - when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE)) + when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + MY_GLUE, asyncQueryRequestContext)) .thenReturn(dataSourceMetadata); String extraParameters = "--conf spark.dynamicAllocation.enabled=false"; diff --git a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java index 6af5d19e5c..a8caa4719a 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java +++ b/core/src/main/java/org/opensearch/sql/datasource/DataSourceService.java @@ -82,6 +82,9 @@ public interface DataSourceService { * Specifically for addressing use cases in SparkQueryDispatcher. * * @param dataSourceName of the {@link DataSource} + * @param context request context used by the implementation. It is passed by async-query-core. + * refer {@link RequestContext} */ - DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName); + DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata( + String dataSourceName, RequestContext context); } diff --git a/core/src/main/java/org/opensearch/sql/datasource/RequestContext.java b/core/src/main/java/org/opensearch/sql/datasource/RequestContext.java new file mode 100644 index 0000000000..199930d340 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/datasource/RequestContext.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.datasource; + +/** + * Context interface to provide additional request related information. It is introduced to allow + * async-query-core library user to pass request context information to implementations of data + * accessors. + */ +public interface RequestContext { + Object getAttribute(String name); +} diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index b35cfbb5e1..0bf959a1b7 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -28,6 +28,7 @@ import org.opensearch.sql.config.TestConfig; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.RequestContext; import org.opensearch.sql.datasource.model.DataSource; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; @@ -236,7 +237,8 @@ public Boolean dataSourceExists(String dataSourceName) { } @Override - public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName) { + public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata( + String dataSourceName, RequestContext requestContext) { return null; } } diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java index 61f3c8cd5d..81b6432891 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/service/DataSourceServiceImpl.java @@ -11,6 +11,7 @@ import java.util.*; import java.util.stream.Collectors; import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.RequestContext; import org.opensearch.sql.datasource.model.DataSource; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceStatus; @@ -122,7 +123,8 @@ public Boolean dataSourceExists(String dataSourceName) { } @Override - public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName) { + public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata( + String dataSourceName, RequestContext requestContext) { DataSourceMetadata dataSourceMetadata = getRawDataSourceMetadata(dataSourceName); verifyDataSourceAccess(dataSourceMetadata); return dataSourceMetadata; diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java index 5a94945e5b..9a1022706f 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/service/DataSourceServiceImplTest.java @@ -36,6 +36,7 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.RequestContext; import org.opensearch.sql.datasource.model.DataSource; import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceStatus; @@ -52,6 +53,7 @@ class DataSourceServiceImplTest { @Mock private DataSourceFactory dataSourceFactory; @Mock private StorageEngine storageEngine; @Mock private DataSourceMetadataStorage dataSourceMetadataStorage; + @Mock private RequestContext requestContext; @Mock private DataSourceUserAuthorizationHelper dataSourceUserAuthorizationHelper; @@ -461,7 +463,9 @@ void testVerifyDataSourceAccessAndGetRawDataSourceMetadataWithDisabledData() { DatasourceDisabledException datasourceDisabledException = Assertions.assertThrows( DatasourceDisabledException.class, - () -> dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS")); + () -> + dataSourceService.verifyDataSourceAccessAndGetRawMetadata( + "testDS", requestContext)); Assertions.assertEquals( "Datasource testDS is disabled.", datasourceDisabledException.getMessage()); } @@ -484,7 +488,7 @@ void testVerifyDataSourceAccessAndGetRawDataSourceMetadata() { when(dataSourceMetadataStorage.getDataSourceMetadata("testDS")) .thenReturn(Optional.of(dataSourceMetadata)); DataSourceMetadata dataSourceMetadata1 = - dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS"); + dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS", requestContext); assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.uri")); assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type")); assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.username"));