Skip to content

Commit

Permalink
Add RequestContext parameter to verifyDataSourceAccessAndGetRawMetada…
Browse files Browse the repository at this point in the history
…ta method

Signed-off-by: Tomoyuki Morita <[email protected]>
  • Loading branch information
ykmr1224 committed Jul 29, 2024
1 parent a5ede64 commit 8bdc3ac
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.sql.spark.asyncquery.model;

import org.opensearch.sql.datasource.RequestContext;

/** Context interface to provide additional request related information */
public interface AsyncQueryRequestContext {
Object getAttribute(String name);
}
public interface AsyncQueryRequestContext extends RequestContext {}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public DispatchQueryResponse dispatch(
AsyncQueryRequestContext asyncQueryRequestContext) {
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource());
dispatchQueryRequest.getDatasource(), asyncQueryRequestContext);

if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) {
String query = dispatchQueryRequest.getQuery();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,8 @@ private void givenFlintIndexMetadataExists(String indexName) {
}

private void givenValidDataSourceMetadataExist() {
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(DATASOURCE_NAME))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
DATASOURCE_NAME, asyncQueryRequestContext))
.thenReturn(
new DataSourceMetadata.Builder()
.setName(DATASOURCE_NAME)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ void testDispatchSelectQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -223,7 +224,8 @@ void testDispatchSelectQueryWithLakeFormation() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithLakeFormation();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -255,7 +257,8 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -278,7 +281,8 @@ void testDispatchSelectQueryCreateNewSession() {
doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any(), any());
when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -304,7 +308,8 @@ void testDispatchSelectQueryReuseSession() {
when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
when(session.isOperationalForDataSource(any())).thenReturn(true);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -324,7 +329,8 @@ void testDispatchSelectQueryFailedCreateSession() {
doReturn(true).when(sessionManager).isEnabled();
doThrow(RuntimeException.class).when(sessionManager).createSession(any(), any());
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

Assertions.assertThrows(
Expand Down Expand Up @@ -358,7 +364,8 @@ void testDispatchCreateAutoRefreshIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -393,7 +400,8 @@ void testDispatchCreateManualRefreshIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -426,7 +434,8 @@ void testDispatchWithPPLQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -450,7 +459,8 @@ void testDispatchWithSparkUDFQuery() {
"CREATE TEMPORARY FUNCTION square AS 'org.apache.spark.sql.functions.expr(\"num * num\")'");
for (String query : udfQueries) {
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

IllegalArgumentException illegalArgumentException =
Expand Down Expand Up @@ -489,7 +499,8 @@ void testInvalidSQLQueryDispatchToSpark() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -532,7 +543,8 @@ void testDispatchQueryWithoutATableAndDataSourceName() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -568,7 +580,8 @@ void testDispatchIndexQueryWithoutADatasourceName() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -589,8 +602,7 @@ void testDispatchMaterializedViewQuery() {
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText());
String query =
"CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH"
+ " (auto_refresh = true)";
"CREATE MATERIALIZED VIEW mv_1 AS select * from logs WITH" + " (auto_refresh = true)";
String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming");
StartJobRequest expected =
new StartJobRequest(
Expand All @@ -604,7 +616,8 @@ void testDispatchMaterializedViewQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -637,7 +650,8 @@ void testDispatchShowMVQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -670,7 +684,8 @@ void testRefreshIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -703,7 +718,8 @@ void testDispatchDescribeIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -739,7 +755,8 @@ void testDispatchAlterToAutoRefreshIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -762,7 +779,8 @@ void testDispatchAlterToManualRefreshIndexQuery() {
"ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH"
+ " (auto_refresh = false)";
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);
when(queryHandlerFactory.getIndexDMLHandler())
.thenReturn(
Expand All @@ -785,7 +803,8 @@ void testDispatchDropIndexQuery() {

String query = "DROP INDEX elb_and_requestUri ON my_glue.default.http_logs";
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);
when(queryHandlerFactory.getIndexDMLHandler())
.thenReturn(
Expand All @@ -808,7 +827,8 @@ void testDispatchVacuumIndexQuery() {

String query = "VACUUM INDEX elb_and_requestUri ON my_glue.default.http_logs";
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);
when(queryHandlerFactory.getIndexDMLHandler())
.thenReturn(
Expand All @@ -824,7 +844,8 @@ void testDispatchVacuumIndexQuery() {

@Test
void testDispatchWithUnSupportedDataSourceType() {
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_prometheus"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_prometheus", asyncQueryRequestContext))
.thenReturn(constructPrometheusDataSourceType());
String query = "select * from my_prometheus.default.http_logs";

Expand Down Expand Up @@ -1018,7 +1039,8 @@ void testGetQueryResponseWithSuccess() {
void testDispatchQueryWithExtraSparkSubmitParameters() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

String extraParameters = "--conf spark.dynamicAllocation.enabled=false";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,6 @@ public interface DataSourceService {
*
* @param dataSourceName of the {@link DataSource}
*/
DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName);
DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(
String dataSourceName, RequestContext context);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.datasource;

/**
* Context interface to provide additional request related information. It is introduced to allow
* async-query-core library user to pass request context information to implementations of data
* accessors.
*/
public interface RequestContext {
Object getAttribute(String name);
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.sql.config.TestConfig;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.RequestContext;
import org.opensearch.sql.datasource.model.DataSource;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceType;
Expand Down Expand Up @@ -236,7 +237,8 @@ public Boolean dataSourceExists(String dataSourceName) {
}

@Override
public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName) {
public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(
String dataSourceName, RequestContext requestContext) {
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.*;
import java.util.stream.Collectors;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.RequestContext;
import org.opensearch.sql.datasource.model.DataSource;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceStatus;
Expand Down Expand Up @@ -122,7 +123,8 @@ public Boolean dataSourceExists(String dataSourceName) {
}

@Override
public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName) {
public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(
String dataSourceName, RequestContext requestContext) {
DataSourceMetadata dataSourceMetadata = getRawDataSourceMetadata(dataSourceName);
verifyDataSourceAccess(dataSourceMetadata);
return dataSourceMetadata;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.RequestContext;
import org.opensearch.sql.datasource.model.DataSource;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceStatus;
Expand All @@ -52,6 +53,7 @@ class DataSourceServiceImplTest {
@Mock private DataSourceFactory dataSourceFactory;
@Mock private StorageEngine storageEngine;
@Mock private DataSourceMetadataStorage dataSourceMetadataStorage;
@Mock private RequestContext requestContext;

@Mock private DataSourceUserAuthorizationHelper dataSourceUserAuthorizationHelper;

Expand Down Expand Up @@ -461,7 +463,9 @@ void testVerifyDataSourceAccessAndGetRawDataSourceMetadataWithDisabledData() {
DatasourceDisabledException datasourceDisabledException =
Assertions.assertThrows(
DatasourceDisabledException.class,
() -> dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS"));
() ->
dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"testDS", requestContext));
Assertions.assertEquals(
"Datasource testDS is disabled.", datasourceDisabledException.getMessage());
}
Expand All @@ -484,7 +488,7 @@ void testVerifyDataSourceAccessAndGetRawDataSourceMetadata() {
when(dataSourceMetadataStorage.getDataSourceMetadata("testDS"))
.thenReturn(Optional.of(dataSourceMetadata));
DataSourceMetadata dataSourceMetadata1 =
dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS");
dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS", requestContext);
assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.uri"));
assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type"));
assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.username"));
Expand Down

0 comments on commit 8bdc3ac

Please sign in to comment.