Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating JobExecutionResponseReader interface to add RequestContext #3062

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ protected JSONObject getResponseFromResultIndex(
AsyncQueryRequestContext asyncQueryRequestContext) {
// either empty json when the result is not available or data with status
// Fetch from Result Index
return jobExecutionResponseReader.getResultWithJobId(
asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
return jobExecutionResponseReader.getResultFromResultIndex(
asyncQueryJobMetadata, asyncQueryRequestContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ protected JSONObject getResponseFromResultIndex(
AsyncQueryRequestContext asyncQueryRequestContext) {
String queryId = asyncQueryJobMetadata.getQueryId();
return jobExecutionResponseReader.getResultWithQueryId(
queryId, asyncQueryJobMetadata.getResultIndex());
queryId, asyncQueryJobMetadata.getResultIndex(), asyncQueryRequestContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ protected JSONObject getResponseFromResultIndex(
AsyncQueryRequestContext asyncQueryRequestContext) {
String queryId = asyncQueryJobMetadata.getQueryId();
return jobExecutionResponseReader.getResultWithQueryId(
queryId, asyncQueryJobMetadata.getResultIndex());
queryId, asyncQueryJobMetadata.getResultIndex(), asyncQueryRequestContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@
package org.opensearch.sql.spark.response;

import org.json.JSONObject;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

/** Interface for reading job execution result */
public interface JobExecutionResponseReader {
/**
* Retrieves the job execution result based on the job ID.
*
* @param jobId The job ID.
* @param resultLocation The location identifier where the result is stored (optional).
* @param asyncQueryJobMetadata metadata will have jobId and resultLocation and other required
* params.
* @param asyncQueryRequestContext request context passed to AsyncQueryExecutorService
* @return A JSONObject containing the result data.
*/
JSONObject getResultWithJobId(String jobId, String resultLocation);
JSONObject getResultFromResultIndex(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext);

/**
* Retrieves the job execution result based on the query ID.
Expand All @@ -25,5 +30,6 @@ public interface JobExecutionResponseReader {
* @param resultLocation The location identifier where the result is stored (optional).
* @return A JSONObject containing the result data.
*/
JSONObject getResultWithQueryId(String queryId, String resultLocation);
JSONObject getResultWithQueryId(
String queryId, String resultLocation, AsyncQueryRequestContext asyncQueryRequestContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,8 @@ public void getResultOfInteractiveQuery() {
.sessionId(SESSION_ID)
.resultIndex(RESULT_INDEX));
JSONObject result = getValidExecutionResponse();
when(jobExecutionResponseReader.getResultWithQueryId(QUERY_ID, RESULT_INDEX))
when(jobExecutionResponseReader.getResultWithQueryId(
QUERY_ID, RESULT_INDEX, asyncQueryRequestContext))
.thenReturn(result);

AsyncQueryExecutionResponse response =
Expand All @@ -471,7 +472,8 @@ public void getResultOfIndexDMLQuery() {
.jobId(DROP_INDEX_JOB_ID)
.resultIndex(RESULT_INDEX));
JSONObject result = getValidExecutionResponse();
when(jobExecutionResponseReader.getResultWithQueryId(QUERY_ID, RESULT_INDEX))
when(jobExecutionResponseReader.getResultWithQueryId(
QUERY_ID, RESULT_INDEX, asyncQueryRequestContext))
.thenReturn(result);

AsyncQueryExecutionResponse response =
Expand All @@ -491,7 +493,18 @@ public void getResultOfRefreshQuery() {
.jobType(JobType.BATCH)
.resultIndex(RESULT_INDEX));
JSONObject result = getValidExecutionResponse();
when(jobExecutionResponseReader.getResultWithJobId(JOB_ID, RESULT_INDEX)).thenReturn(result);
when(jobExecutionResponseReader.getResultFromResultIndex(
AsyncQueryJobMetadata.builder()
.applicationId(APPLICATION_ID)
.queryId(QUERY_ID)
.jobId(JOB_ID)
.datasourceName(DATASOURCE_NAME)
.resultIndex(RESULT_INDEX)
.jobType(JobType.BATCH)
.metadata(ImmutableMap.of())
.build(),
asyncQueryRequestContext))
.thenReturn(result);

AsyncQueryExecutionResponse response =
asyncQueryExecutorService.getAsyncQueryResults(QUERY_ID, asyncQueryRequestContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,17 @@ void testGetQueryResponse() {
when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID))
.thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING)));
// simulate result index is not created yet
when(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null))
when(jobExecutionResponseReader.getResultFromResultIndex(
AsyncQueryJobMetadata.builder()
.jobId(EMR_JOB_ID)
.queryId(QUERY_ID)
.applicationId(EMRS_APPLICATION_ID)
.jobId(EMR_JOB_ID)
.jobType(JobType.INTERACTIVE)
.datasourceName(MY_GLUE)
.metadata(ImmutableMap.of())
.build(),
asyncQueryRequestContext))
.thenReturn(new JSONObject());

JSONObject result =
Expand All @@ -782,7 +792,7 @@ void testGetQueryResponseWithSession() {
doReturn(StatementState.WAITING).when(statement).getStatementState();
doReturn(new JSONObject())
.when(jobExecutionResponseReader)
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any());
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any(), eq(asyncQueryRequestContext));

JSONObject result =
sparkQueryDispatcher.getQueryResponse(
Expand All @@ -798,7 +808,7 @@ void testGetQueryResponseWithInvalidSession() {
doReturn(Optional.empty()).when(sessionManager).getSession(MOCK_SESSION_ID, MY_GLUE);
doReturn(new JSONObject())
.when(jobExecutionResponseReader)
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any());
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any(), eq(asyncQueryRequestContext));

IllegalArgumentException exception =
Assertions.assertThrows(
Expand All @@ -818,7 +828,7 @@ void testGetQueryResponseWithStatementNotExist() {
doReturn(Optional.empty()).when(session).get(any(), eq(asyncQueryRequestContext));
doReturn(new JSONObject())
.when(jobExecutionResponseReader)
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any());
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any(), eq(asyncQueryRequestContext));

IllegalArgumentException exception =
Assertions.assertThrows(
Expand All @@ -840,12 +850,25 @@ void testGetQueryResponseWithSuccess() {
resultMap.put(STATUS_FIELD, "SUCCESS");
resultMap.put(ERROR_FIELD, "");
queryResult.put(DATA_FIELD, resultMap);
when(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null)).thenReturn(queryResult);
AsyncQueryJobMetadata asyncQueryJobMetadata =
AsyncQueryJobMetadata.builder()
.queryId(QUERY_ID)
.applicationId(EMRS_APPLICATION_ID)
.jobId(EMR_JOB_ID)
.jobType(JobType.INTERACTIVE)
.datasourceName(MY_GLUE)
.metadata(ImmutableMap.of())
.jobId(EMR_JOB_ID)
.build();
when(jobExecutionResponseReader.getResultFromResultIndex(
asyncQueryJobMetadata, asyncQueryRequestContext))
.thenReturn(queryResult);

JSONObject result =
sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata(), asyncQueryRequestContext);

verify(jobExecutionResponseReader, times(1)).getResultWithJobId(EMR_JOB_ID, null);
verify(jobExecutionResponseReader, times(1))
.getResultFromResultIndex(asyncQueryJobMetadata, asyncQueryRequestContext);
assertEquals(
new HashSet<>(Arrays.asList(DATA_FIELD, STATUS_FIELD, ERROR_FIELD)), result.keySet());
JSONObject dataJson = new JSONObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

/** JobExecutionResponseReader implementation for reading response from OpenSearch index. */
public class OpenSearchJobExecutionResponseReader implements JobExecutionResponseReader {
Expand All @@ -32,12 +34,17 @@ public OpenSearchJobExecutionResponseReader(Client client) {
}

@Override
public JSONObject getResultWithJobId(String jobId, String resultLocation) {
return searchInSparkIndex(QueryBuilders.termQuery(JOB_ID_FIELD, jobId), resultLocation);
public JSONObject getResultFromResultIndex(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
return searchInSparkIndex(
QueryBuilders.termQuery(JOB_ID_FIELD, asyncQueryJobMetadata.getJobId()),
asyncQueryJobMetadata.getResultIndex());
}

@Override
public JSONObject getResultWithQueryId(String queryId, String resultLocation) {
public JSONObject getResultWithQueryId(
String queryId, String resultLocation, AsyncQueryRequestContext asyncQueryRequestContext) {
return searchInSparkIndex(QueryBuilders.termQuery("queryId", queryId), resultLocation);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
import org.opensearch.sql.protocol.response.format.ResponseFormatter;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.asyncquery.model.MockFlintSparkJob;
import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext;
Expand Down Expand Up @@ -428,12 +429,21 @@ private class AssertionHelper {
*/
new JobExecutionResponseReader() {
@Override
public JSONObject getResultWithJobId(String jobId, String resultIndex) {
return interaction.interact(new InteractionStep(emrClient, jobId, resultIndex));
public JSONObject getResultFromResultIndex(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
return interaction.interact(
new InteractionStep(
emrClient,
asyncQueryJobMetadata.getJobId(),
asyncQueryJobMetadata.getResultIndex()));
}

@Override
public JSONObject getResultWithQueryId(String queryId, String resultIndex) {
public JSONObject getResultWithQueryId(
String queryId,
String resultIndex,
AsyncQueryRequestContext asyncQueryRequestContext) {
return interaction.interact(new InteractionStep(emrClient, queryId, resultIndex));
}
});
Expand Down Expand Up @@ -501,7 +511,7 @@ private InteractionStep(LocalEMRSClient emrClient, String queryId, String result
/** Simulate PPL plugin search query_execution_result */
JSONObject pluginSearchQueryResult() {
return new OpenSearchJobExecutionResponseReader(client)
.getResultWithQueryId(queryId, resultIndex);
.getResultWithQueryId(queryId, resultIndex, null);
}

/** Simulate EMR-S bulk writes query_execution_result with refresh = wait_for */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;

@ExtendWith(MockitoExtension.class)
public class OpenSearchJobExecutionResponseReaderTest {
Expand All @@ -50,7 +51,11 @@ public void testGetResultFromOpensearchIndex() {
new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F));
Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_JOB_ID));

assertFalse(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null).isEmpty());
assertFalse(
jobExecutionResponseReader
.getResultFromResultIndex(
AsyncQueryJobMetadata.builder().jobId(EMR_JOB_ID).build(), null)
.isEmpty());
}

@Test
Expand All @@ -64,7 +69,11 @@ public void testGetResultFromCustomIndex() {
new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F));
Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_JOB_ID));

assertFalse(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, "foo").isEmpty());
assertFalse(
jobExecutionResponseReader
.getResultFromResultIndex(
AsyncQueryJobMetadata.builder().jobId(EMR_JOB_ID).resultIndex("foo").build(), null)
.isEmpty());
}

@Test
Expand All @@ -76,7 +85,9 @@ public void testInvalidSearchResponse() {
RuntimeException exception =
assertThrows(
RuntimeException.class,
() -> jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null));
() ->
jobExecutionResponseReader.getResultFromResultIndex(
AsyncQueryJobMetadata.builder().jobId(EMR_JOB_ID).build(), null));

Assertions.assertEquals(
"Fetching result from "
Expand All @@ -92,13 +103,18 @@ public void testSearchFailure() {

assertThrows(
RuntimeException.class,
() -> jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null));
() ->
jobExecutionResponseReader.getResultFromResultIndex(
AsyncQueryJobMetadata.builder().jobId(EMR_JOB_ID).build(), null));
}

@Test
public void testIndexNotFoundException() {
when(client.search(any())).thenThrow(IndexNotFoundException.class);

assertTrue(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, "foo").isEmpty());
assertTrue(
jobExecutionResponseReader
.getResultFromResultIndex(
AsyncQueryJobMetadata.builder().jobId(EMR_JOB_ID).resultIndex("foo").build(), null)
.isEmpty());
}
}
Loading