Skip to content

Commit

Permalink
Change JobExecutionResponseReader to an interface (#2693)
Browse files Browse the repository at this point in the history
* Change JobExecutionResponseReader to an interface

Signed-off-by: Tomoyuki Morita <[email protected]>

* Fix comment

Signed-off-by: Tomoyuki Morita <[email protected]>

---------

Signed-off-by: Tomoyuki Morita <[email protected]>
(cherry picked from commit 3dd1729)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed May 29, 2024
1 parent 1768eb6 commit 21d5c6f
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class BatchQueryHandler extends AsyncQueryHandler {
protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) {
// either empty json when the result is not available or data with status
// Fetch from Result Index
return jobExecutionResponseReader.getResultFromOpensearchIndex(
return jobExecutionResponseReader.getResultWithJobId(
asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,75 +5,25 @@

package org.opensearch.sql.spark.response;

import static org.opensearch.sql.datasource.model.DataSourceMetadata.DEFAULT_RESULT_INDEX;
import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD;
import static org.opensearch.sql.spark.data.constants.SparkConstants.JOB_ID_FIELD;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.json.JSONObject;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

public class JobExecutionResponseReader {
private final Client client;
private static final Logger LOG = LogManager.getLogger();

/** Interface for reading job execution result */
public interface JobExecutionResponseReader {
/**
* JobExecutionResponseReader for spark query.
* Retrieves the job execution result based on the job ID.
*
* @param client Opensearch client
* @param jobId The job ID.
* @param resultLocation The location identifier where the result is stored (optional).
* @return A JSONObject containing the result data.
*/
public JobExecutionResponseReader(Client client) {
this.client = client;
}

public JSONObject getResultFromOpensearchIndex(String jobId, String resultIndex) {
return searchInSparkIndex(QueryBuilders.termQuery(JOB_ID_FIELD, jobId), resultIndex);
}

public JSONObject getResultWithQueryId(String queryId, String resultIndex) {
return searchInSparkIndex(QueryBuilders.termQuery("queryId", queryId), resultIndex);
}
JSONObject getResultWithJobId(String jobId, String resultLocation);

private JSONObject searchInSparkIndex(QueryBuilder query, String resultIndex) {
SearchRequest searchRequest = new SearchRequest();
String searchResultIndex = resultIndex == null ? DEFAULT_RESULT_INDEX : resultIndex;
searchRequest.indices(searchResultIndex);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(query);
searchRequest.source(searchSourceBuilder);
ActionFuture<SearchResponse> searchResponseActionFuture;
JSONObject data = new JSONObject();
try {
searchResponseActionFuture = client.search(searchRequest);
} catch (IndexNotFoundException e) {
// if there is no result index (e.g., EMR-S hasn't created the index yet), we return empty
// json
LOG.info(resultIndex + " is not created yet.");
return data;
} catch (Exception e) {
throw new RuntimeException(e);
}
SearchResponse searchResponse = searchResponseActionFuture.actionGet();
if (searchResponse.status().getStatus() != 200) {
throw new RuntimeException(
"Fetching result from "
+ searchResultIndex
+ " index failed with status : "
+ searchResponse.status());
} else {
for (SearchHit searchHit : searchResponse.getHits().getHits()) {
data.put(DATA_FIELD, searchHit.getSourceAsMap());
}
return data;
}
}
/**
* Retrieves the job execution result based on the query ID.
*
* @param queryId The query ID.
* @param resultLocation The location identifier where the result is stored (optional).
* @return A JSONObject containing the result data.
*/
JSONObject getResultWithQueryId(String queryId, String resultLocation);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.response;

import static org.opensearch.sql.datasource.model.DataSourceMetadata.DEFAULT_RESULT_INDEX;
import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD;
import static org.opensearch.sql.spark.data.constants.SparkConstants.JOB_ID_FIELD;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.json.JSONObject;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

/** JobExecutionResponseReader implementation for reading response from OpenSearch index. */
public class OpenSearchJobExecutionResponseReader implements JobExecutionResponseReader {
private final Client client;
private static final Logger LOG = LogManager.getLogger();

public OpenSearchJobExecutionResponseReader(Client client) {
this.client = client;
}

@Override
public JSONObject getResultWithJobId(String jobId, String resultLocation) {
return searchInSparkIndex(QueryBuilders.termQuery(JOB_ID_FIELD, jobId), resultLocation);
}

@Override
public JSONObject getResultWithQueryId(String queryId, String resultLocation) {
return searchInSparkIndex(QueryBuilders.termQuery("queryId", queryId), resultLocation);
}

private JSONObject searchInSparkIndex(QueryBuilder query, String resultIndex) {
SearchRequest searchRequest = new SearchRequest();
String searchResultIndex = resultIndex == null ? DEFAULT_RESULT_INDEX : resultIndex;
searchRequest.indices(searchResultIndex);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(query);
searchRequest.source(searchSourceBuilder);
ActionFuture<SearchResponse> searchResponseActionFuture;
JSONObject data = new JSONObject();
try {
searchResponseActionFuture = client.search(searchRequest);
} catch (IndexNotFoundException e) {
// if there is no result index (e.g., EMR-S hasn't created the index yet), we return empty
// json
LOG.info(resultIndex + " is not created yet.");
return data;
} catch (Exception e) {
throw new RuntimeException(e);
}
SearchResponse searchResponse = searchResponseActionFuture.actionGet();
if (searchResponse.status().getStatus() != 200) {
throw new RuntimeException(
"Fetching result from "
+ searchResultIndex
+ " index failed with status : "
+ searchResponse.status());
} else {
for (SearchHit searchHit : searchResponse.getHits().getHits()) {
data.put(DATA_FIELD, searchHit.getSourceAsMap());
}
return data;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory;
import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader;

@RequiredArgsConstructor
public class AsyncExecutorServiceModule extends AbstractModule {
Expand Down Expand Up @@ -87,15 +88,15 @@ public SparkQueryDispatcher sparkQueryDispatcher(

@Provides
public QueryHandlerFactory queryhandlerFactory(
JobExecutionResponseReader jobExecutionResponseReader,
JobExecutionResponseReader openSearchJobExecutionResponseReader,
FlintIndexMetadataServiceImpl flintIndexMetadataReader,
SessionManager sessionManager,
DefaultLeaseManager defaultLeaseManager,
IndexDMLResultStorageService indexDMLResultStorageService,
FlintIndexOpFactory flintIndexOpFactory,
EMRServerlessClientFactory emrServerlessClientFactory) {
return new QueryHandlerFactory(
jobExecutionResponseReader,
openSearchJobExecutionResponseReader,
flintIndexMetadataReader,
sessionManager,
defaultLeaseManager,
Expand Down Expand Up @@ -172,7 +173,7 @@ public FlintIndexMetadataServiceImpl flintIndexMetadataReader(NodeClient client)

@Provides
public JobExecutionResponseReader jobExecutionResponseReader(NodeClient client) {
return new JobExecutionResponseReader(client);
return new OpenSearchJobExecutionResponseReader(client);
}

private void registerStateStoreMetrics(StateStore stateStore) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory;
import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader;
import org.opensearch.sql.storage.DataSourceFactory;
import org.opensearch.test.OpenSearchIntegTestCase;

Expand Down Expand Up @@ -225,7 +226,7 @@ private DataSourceServiceImpl createDataSourceService() {
protected AsyncQueryExecutorService createAsyncQueryExecutorService(
EMRServerlessClientFactory emrServerlessClientFactory) {
return createAsyncQueryExecutorService(
emrServerlessClientFactory, new JobExecutionResponseReader(client));
emrServerlessClientFactory, new OpenSearchJobExecutionResponseReader(client));
}

/** Pass a custom response reader which can mock interaction between PPL plugin and EMR-S job. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.opensearch.sql.spark.execution.statement.StatementState;
import org.opensearch.sql.spark.flint.FlintIndexType;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;
import org.opensearch.sql.spark.rest.model.LangType;
Expand Down Expand Up @@ -425,9 +426,9 @@ private class AssertionHelper {
* current interaction. Intercept both get methods for different query handler which
* will only call either of them.
*/
new JobExecutionResponseReader(client) {
new JobExecutionResponseReader() {
@Override
public JSONObject getResultFromOpensearchIndex(String jobId, String resultIndex) {
public JSONObject getResultWithJobId(String jobId, String resultIndex) {
return interaction.interact(new InteractionStep(emrClient, jobId, resultIndex));
}

Expand Down Expand Up @@ -497,7 +498,8 @@ private InteractionStep(LocalEMRSClient emrClient, String queryId, String result

/** Simulate PPL plugin search query_execution_result */
JSONObject pluginSearchQueryResult() {
return new JobExecutionResponseReader(client).getResultWithQueryId(queryId, resultIndex);
return new OpenSearchJobExecutionResponseReader(client)
.getResultWithQueryId(queryId, resultIndex);
}

/** Simulate EMR-S bulk writes query_execution_result with refresh = wait_for */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ void testGetQueryResponse() {
when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID))
.thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING)));
// simulate result index is not created yet
when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null))
when(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null))
.thenReturn(new JSONObject());

JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata());
Expand Down Expand Up @@ -978,12 +978,11 @@ void testGetQueryResponseWithSuccess() {
resultMap.put(STATUS_FIELD, "SUCCESS");
resultMap.put(ERROR_FIELD, "");
queryResult.put(DATA_FIELD, resultMap);
when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null))
.thenReturn(queryResult);
when(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null)).thenReturn(queryResult);

JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata());

verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID, null);
verify(jobExecutionResponseReader, times(1)).getResultWithJobId(EMR_JOB_ID, null);
Assertions.assertEquals(
new HashSet<>(Arrays.asList(DATA_FIELD, STATUS_FIELD, ERROR_FIELD)), result.keySet());
JSONObject dataJson = new JSONObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
Expand All @@ -30,12 +31,14 @@
import org.opensearch.search.SearchHits;

@ExtendWith(MockitoExtension.class)
public class AsyncQueryExecutionResponseReaderTest {
public class OpenSearchJobExecutionResponseReaderTest {
@Mock private Client client;
@Mock private SearchResponse searchResponse;
@Mock private SearchHit searchHit;
@Mock private ActionFuture<SearchResponse> searchResponseActionFuture;

@InjectMocks OpenSearchJobExecutionResponseReader jobExecutionResponseReader;

@Test
public void testGetResultFromOpensearchIndex() {
when(client.search(any())).thenReturn(searchResponseActionFuture);
Expand All @@ -46,9 +49,8 @@ public void testGetResultFromOpensearchIndex() {
new SearchHits(
new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F));
Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_JOB_ID));
JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
assertFalse(
jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null).isEmpty());

assertFalse(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null).isEmpty());
}

@Test
Expand All @@ -61,9 +63,8 @@ public void testGetResultFromCustomIndex() {
new SearchHits(
new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F));
Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_JOB_ID));
JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
assertFalse(
jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, "foo").isEmpty());

assertFalse(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, "foo").isEmpty());
}

@Test
Expand All @@ -72,11 +73,11 @@ public void testInvalidSearchResponse() {
when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
when(searchResponse.status()).thenReturn(RestStatus.NO_CONTENT);

JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
RuntimeException exception =
assertThrows(
RuntimeException.class,
() -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null));
() -> jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null));

Assertions.assertEquals(
"Fetching result from "
+ DEFAULT_RESULT_INDEX
Expand All @@ -88,17 +89,16 @@ public void testInvalidSearchResponse() {
@Test
public void testSearchFailure() {
when(client.search(any())).thenThrow(RuntimeException.class);
JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);

assertThrows(
RuntimeException.class,
() -> jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null));
() -> jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, null));
}

@Test
public void testIndexNotFoundException() {
when(client.search(any())).thenThrow(IndexNotFoundException.class);
JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
assertTrue(
jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, "foo").isEmpty());

assertTrue(jobExecutionResponseReader.getResultWithJobId(EMR_JOB_ID, "foo").isEmpty());
}
}

0 comments on commit 21d5c6f

Please sign in to comment.