Skip to content

Commit

Permalink
Updating JobExecutionResponseReader interface to add RequestContext
Browse files Browse the repository at this point in the history
  • Loading branch information
AMIT YADAV committed Oct 9, 2024
1 parent ac8678c commit a6c0644
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ protected JSONObject getResponseFromResultIndex(
AsyncQueryRequestContext asyncQueryRequestContext) {
// either empty json when the result is not available or data with status
// Fetch from Result Index
return jobExecutionResponseReader.getResultWithJobId(
asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
return jobExecutionResponseReader.getResultFromResultIndex(
asyncQueryJobMetadata, asyncQueryRequestContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ protected JSONObject getResponseFromResultIndex(
AsyncQueryRequestContext asyncQueryRequestContext) {
String queryId = asyncQueryJobMetadata.getQueryId();
return jobExecutionResponseReader.getResultWithQueryId(
queryId, asyncQueryJobMetadata.getResultIndex());
queryId, asyncQueryJobMetadata.getResultIndex(), asyncQueryRequestContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ protected JSONObject getResponseFromResultIndex(
AsyncQueryRequestContext asyncQueryRequestContext) {
String queryId = asyncQueryJobMetadata.getQueryId();
return jobExecutionResponseReader.getResultWithQueryId(
queryId, asyncQueryJobMetadata.getResultIndex());
queryId, asyncQueryJobMetadata.getResultIndex(), asyncQueryRequestContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@
package org.opensearch.sql.spark.response;

import org.json.JSONObject;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

/** Interface for reading job execution result */
public interface JobExecutionResponseReader {
/**
* Retrieves the job execution result based on the job ID.
*
* @param jobId The job ID.
* @param resultLocation The location identifier where the result is stored (optional).
* @param asyncQueryJobMetadata metadata will have jobId and resultLocation and other required
* params.
* @return A JSONObject containing the result data.
*/
JSONObject getResultWithJobId(String jobId, String resultLocation);
JSONObject getResultFromResultIndex(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext);

/**
* Retrieves the job execution result based on the query ID.
Expand All @@ -25,5 +29,6 @@ public interface JobExecutionResponseReader {
* @param resultLocation The location identifier where the result is stored (optional).
* @return A JSONObject containing the result data.
*/
JSONObject getResultWithQueryId(String queryId, String resultLocation);
JSONObject getResultWithQueryId(
String queryId, String resultLocation, AsyncQueryRequestContext asyncQueryRequestContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,8 @@ public void getResultOfInteractiveQuery() {
.sessionId(SESSION_ID)
.resultIndex(RESULT_INDEX));
JSONObject result = getValidExecutionResponse();
when(jobExecutionResponseReader.getResultWithQueryId(QUERY_ID, RESULT_INDEX))
when(jobExecutionResponseReader.getResultWithQueryId(
QUERY_ID, RESULT_INDEX, asyncQueryRequestContext))
.thenReturn(result);

AsyncQueryExecutionResponse response =
Expand All @@ -471,7 +472,8 @@ public void getResultOfIndexDMLQuery() {
.jobId(DROP_INDEX_JOB_ID)
.resultIndex(RESULT_INDEX));
JSONObject result = getValidExecutionResponse();
when(jobExecutionResponseReader.getResultWithQueryId(QUERY_ID, RESULT_INDEX))
when(jobExecutionResponseReader.getResultWithQueryId(
QUERY_ID, RESULT_INDEX, asyncQueryRequestContext))
.thenReturn(result);

AsyncQueryExecutionResponse response =
Expand All @@ -491,7 +493,18 @@ public void getResultOfRefreshQuery() {
.jobType(JobType.BATCH)
.resultIndex(RESULT_INDEX));
JSONObject result = getValidExecutionResponse();
when(jobExecutionResponseReader.getResultWithJobId(JOB_ID, RESULT_INDEX)).thenReturn(result);
when(jobExecutionResponseReader.getResultFromResultIndex(
AsyncQueryJobMetadata.builder()
.applicationId(APPLICATION_ID)
.queryId(QUERY_ID)
.jobId(JOB_ID)
.datasourceName(DATASOURCE_NAME)
.resultIndex(RESULT_INDEX)
.jobType(JobType.BATCH)
.metadata(ImmutableMap.of())
.build(),
asyncQueryRequestContext))
.thenReturn(result);

AsyncQueryExecutionResponse response =
asyncQueryExecutorService.getAsyncQueryResults(QUERY_ID, asyncQueryRequestContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,17 @@ void testGetQueryResponse() {
when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID))
.thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING)));
// simulate result index is not created yet
when(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null))
when(jobExecutionResponseReader.getResultFromResultIndex(
AsyncQueryJobMetadata.builder()
.jobId(EMR_JOB_ID)
.queryId(QUERY_ID)
.applicationId(EMRS_APPLICATION_ID)
.jobId(EMR_JOB_ID)
.jobType(JobType.INTERACTIVE)
.datasourceName(MY_GLUE)
.metadata(ImmutableMap.of())
.build(),
asyncQueryRequestContext))
.thenReturn(new JSONObject());

JSONObject result =
Expand All @@ -782,7 +792,7 @@ void testGetQueryResponseWithSession() {
doReturn(StatementState.WAITING).when(statement).getStatementState();
doReturn(new JSONObject())
.when(jobExecutionResponseReader)
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any());
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any(), eq(asyncQueryRequestContext));

JSONObject result =
sparkQueryDispatcher.getQueryResponse(
Expand All @@ -798,7 +808,7 @@ void testGetQueryResponseWithInvalidSession() {
doReturn(Optional.empty()).when(sessionManager).getSession(MOCK_SESSION_ID, MY_GLUE);
doReturn(new JSONObject())
.when(jobExecutionResponseReader)
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any());
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any(), eq(asyncQueryRequestContext));

IllegalArgumentException exception =
Assertions.assertThrows(
Expand All @@ -818,7 +828,7 @@ void testGetQueryResponseWithStatementNotExist() {
doReturn(Optional.empty()).when(session).get(any(), eq(asyncQueryRequestContext));
doReturn(new JSONObject())
.when(jobExecutionResponseReader)
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any());
.getResultWithQueryId(eq(MOCK_STATEMENT_ID), any(), eq(asyncQueryRequestContext));

IllegalArgumentException exception =
Assertions.assertThrows(
Expand All @@ -840,12 +850,25 @@ void testGetQueryResponseWithSuccess() {
resultMap.put(STATUS_FIELD, "SUCCESS");
resultMap.put(ERROR_FIELD, "");
queryResult.put(DATA_FIELD, resultMap);
when(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null)).thenReturn(queryResult);
AsyncQueryJobMetadata asyncQueryJobMetadata =
AsyncQueryJobMetadata.builder()
.queryId(QUERY_ID)
.applicationId(EMRS_APPLICATION_ID)
.jobId(EMR_JOB_ID)
.jobType(JobType.INTERACTIVE)
.datasourceName(MY_GLUE)
.metadata(ImmutableMap.of())
.jobId(EMR_JOB_ID)
.build();
when(jobExecutionResponseReader.getResultFromResultIndex(
asyncQueryJobMetadata, asyncQueryRequestContext))
.thenReturn(queryResult);

JSONObject result =
sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata(), asyncQueryRequestContext);

verify(jobExecutionResponseReader, times(1)).getResultWithJobId(EMR_JOB_ID, null);
verify(jobExecutionResponseReader, times(1))
.getResultFromResultIndex(asyncQueryJobMetadata, asyncQueryRequestContext);
assertEquals(
new HashSet<>(Arrays.asList(DATA_FIELD, STATUS_FIELD, ERROR_FIELD)), result.keySet());
JSONObject dataJson = new JSONObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;

/** JobExecutionResponseReader implementation for reading response from OpenSearch index. */
public class OpenSearchJobExecutionResponseReader implements JobExecutionResponseReader {
Expand All @@ -32,12 +34,12 @@ public OpenSearchJobExecutionResponseReader(Client client) {
}

@Override
public JSONObject getResultWithJobId(String jobId, String resultLocation) {
return searchInSparkIndex(QueryBuilders.termQuery(JOB_ID_FIELD, jobId), resultLocation);
public JSONObject getResultFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata, AsyncQueryRequestContext asyncQueryRequestContext) {
return searchInSparkIndex(QueryBuilders.termQuery(JOB_ID_FIELD, asyncQueryJobMetadata.getJobId()), asyncQueryJobMetadata.getResultIndex());
}

@Override
public JSONObject getResultWithQueryId(String queryId, String resultLocation) {
public JSONObject getResultWithQueryId(String queryId, String resultLocation, AsyncQueryRequestContext asyncQueryRequestContext) {
return searchInSparkIndex(QueryBuilders.termQuery("queryId", queryId), resultLocation);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
import org.opensearch.sql.protocol.response.format.ResponseFormatter;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.asyncquery.model.MockFlintSparkJob;
import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext;
Expand Down Expand Up @@ -428,13 +429,13 @@ private class AssertionHelper {
*/
new JobExecutionResponseReader() {
@Override
public JSONObject getResultWithJobId(String jobId, String resultIndex) {
return interaction.interact(new InteractionStep(emrClient, jobId, resultIndex));
public JSONObject getResultFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata, AsyncQueryRequestContext asyncQueryRequestContext) {
return interaction.interact(new InteractionStep(emrClient, asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex()));
}

@Override
public JSONObject getResultWithQueryId(String queryId, String resultIndex) {
return interaction.interact(new InteractionStep(emrClient, queryId, resultIndex));
public JSONObject getResultWithQueryId(String queryId, String resultIndex, AsyncQueryRequestContext asyncQueryRequestContext) {
return interaction.interact(new InteractionStep(emrClient, queryId, resultIndex));
}
});
this.createQueryResponse =
Expand Down Expand Up @@ -501,7 +502,7 @@ private InteractionStep(LocalEMRSClient emrClient, String queryId, String result
/** Simulate PPL plugin search query_execution_result */
JSONObject pluginSearchQueryResult() {
return new OpenSearchJobExecutionResponseReader(client)
.getResultWithQueryId(queryId, resultIndex);
.getResultWithQueryId(queryId, resultIndex, null);
}

/** Simulate EMR-S bulk writes query_execution_result with refresh = wait_for */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;

@ExtendWith(MockitoExtension.class)
public class OpenSearchJobExecutionResponseReaderTest {
Expand All @@ -50,7 +51,7 @@ public void testGetResultFromOpensearchIndex() {
new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F));
Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_JOB_ID));

assertFalse(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null).isEmpty());
assertFalse(jobExecutionResponseReader.getResultFromResultIndex(AsyncQueryJobMetadata.builder().jobId(EMR_JOB_ID).build(), null).isEmpty());
}

@Test
Expand All @@ -64,7 +65,7 @@ public void testGetResultFromCustomIndex() {
new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F));
Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_JOB_ID));

assertFalse(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, "foo").isEmpty());
assertFalse(jobExecutionResponseReader.getResultFromResultIndex(AsyncQueryJobMetadata.builder().jobId(EMR_JOB_ID).resultIndex("foo").build(), null).isEmpty());
}

@Test
Expand All @@ -76,7 +77,7 @@ public void testInvalidSearchResponse() {
RuntimeException exception =
assertThrows(
RuntimeException.class,
() -> jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null));
() -> jobExecutionResponseReader.getResultFromResultIndex(AsyncQueryJobMetadata.builder().jobId(EMR_JOB_ID).build(), null));

Assertions.assertEquals(
"Fetching result from "
Expand All @@ -92,13 +93,13 @@ public void testSearchFailure() {

assertThrows(
RuntimeException.class,
() -> jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null));
() -> jobExecutionResponseReader.getResultFromResultIndex(AsyncQueryJobMetadata.builder().jobId(EMR_JOB_ID).build(), null));
}

@Test
public void testIndexNotFoundException() {
when(client.search(any())).thenThrow(IndexNotFoundException.class);

assertTrue(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, "foo").isEmpty());
assertTrue(jobExecutionResponseReader.getResultFromResultIndex(AsyncQueryJobMetadata.builder().jobId(EMR_JOB_ID).resultIndex("foo").build(), null).isEmpty());
}
}

0 comments on commit a6c0644

Please sign in to comment.