Skip to content

Commit

Permalink
Add UT
Browse files Browse the repository at this point in the history
  • Loading branch information
noCharger committed Sep 3, 2024
1 parent 7edf0d3 commit 6d23a46
Show file tree
Hide file tree
Showing 13 changed files with 49 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import java.util.List;
import java.util.Optional;
import lombok.AllArgsConstructor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.json.JSONObject;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
Expand All @@ -37,13 +35,10 @@ public class AsyncQueryExecutorServiceImpl implements AsyncQueryExecutorService
private SparkQueryDispatcher sparkQueryDispatcher;
private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;

private static final Logger LOGGER = LogManager.getLogger(AsyncQueryExecutorServiceImpl.class);

@Override
public CreateAsyncQueryResponse createAsyncQuery(
CreateAsyncQueryRequest createAsyncQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext) {
LOGGER.info("CreateAsyncQueryRequest: " + createAsyncQueryRequest.getQuery());
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(asyncQueryRequestContext);
DispatchQueryResponse dispatchQueryResponse =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import com.amazonaws.services.emrserverless.model.GetJobRunResult;
import java.util.Map;
import lombok.RequiredArgsConstructor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.json.JSONObject;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
Expand Down Expand Up @@ -42,8 +40,6 @@ public class BatchQueryHandler extends AsyncQueryHandler {
protected final MetricsService metricsService;
protected final SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider;

private static final Logger LOGGER = LogManager.getLogger(BatchQueryHandler.class);

@Override
protected JSONObject getResponseFromResultIndex(
AsyncQueryJobMetadata asyncQueryJobMetadata,
Expand Down Expand Up @@ -106,9 +102,6 @@ public DispatchQueryResponse submit(
tags,
false,
dataSourceMetadata.getResultIndex());

LOGGER.info("Submit batch query: " + dispatchQueryRequest.getQuery());

String jobId = emrServerlessClient.startJobRun(startJobRequest);
metricsService.incrementNumericalMetric(EMR_BATCH_QUERY_JOBS_CREATION_COUNT);
return DispatchQueryResponse.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
package org.opensearch.sql.spark.dispatcher;

import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
Expand All @@ -31,8 +29,6 @@
* index, and new job is submitted to Spark.
*/
public class RefreshQueryHandler extends BatchQueryHandler {
private static final Logger LOGGER = LogManager.getLogger(RefreshQueryHandler.class);

private final FlintIndexMetadataService flintIndexMetadataService;
private final FlintIndexOpFactory flintIndexOpFactory;

Expand Down Expand Up @@ -78,8 +74,6 @@ public DispatchQueryResponse submit(
DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) {
leaseManager.borrow(new LeaseRequest(JobType.BATCH, dispatchQueryRequest.getDatasource()));

LOGGER.info("Submit refresh query: " + dispatchQueryRequest.getQuery());

DispatchQueryResponse resp = super.submit(dispatchQueryRequest, context);
DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();
return DispatchQueryResponse.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ public void setUp() {
flintIndexStateModelService,
flintIndexClient,
flintIndexMetadataService,
emrServerlessClientFactory);
emrServerlessClientFactory,
asyncQueryScheduler);
QueryHandlerFactory queryHandlerFactory =
new QueryHandlerFactory(
jobExecutionResponseReader,
Expand Down Expand Up @@ -519,6 +520,7 @@ private void givenFlintIndexMetadataExists(String indexName) {
.appId(APPLICATION_ID)
.jobId(JOB_ID)
.opensearchIndexName(indexName)
.flintIndexOptions(new FlintIndexOptions())
.build()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.sql.spark.flint.FlintIndexClient;
import org.opensearch.sql.spark.flint.FlintIndexMetadataService;
import org.opensearch.sql.spark.flint.FlintIndexStateModelService;
import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler;

@ExtendWith(MockitoExtension.class)
class FlintIndexOpFactoryTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.sql.spark.flint.FlintIndexState;
import org.opensearch.sql.spark.flint.FlintIndexStateModel;
import org.opensearch.sql.spark.flint.FlintIndexStateModelService;
import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler;

@ExtendWith(MockitoExtension.class)
class FlintIndexOpVacuumTest {
Expand Down
2 changes: 1 addition & 1 deletion async-query/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ jacocoTestCoverageVerification {
// ignore because XContext IOException
'org.opensearch.sql.spark.execution.statestore.StateStore',
'org.opensearch.sql.spark.rest.*',
'org.opensearch.sql.spark.scheduler.OpenSearchRefreshIndexJobRequestParser',
'org.opensearch.sql.spark.scheduler.OpenSearchScheduleQueryJobRequestParser',
'org.opensearch.sql.spark.transport.model.*'
]
limit {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@

package org.opensearch.sql.spark.config;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;

import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.StringUtils;
import org.opensearch.sql.common.setting.Settings;

/** Load SparkExecutionEngineConfigClusterSetting from settings with privilege check. */
Expand All @@ -15,20 +20,17 @@ public class SparkExecutionEngineConfigClusterSettingLoader {
private final Settings settings;

public Optional<SparkExecutionEngineConfigClusterSetting> load() {
// String sparkExecutionEngineConfigSettingString =
// this.settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG);
SparkExecutionEngineConfigClusterSetting setting =
new SparkExecutionEngineConfigClusterSetting("test", "test", "test", "test", "test");
return Optional.of(setting);
// if (!StringUtils.isBlank(sparkExecutionEngineConfigSettingString)) {
// return Optional.of(
// AccessController.doPrivileged(
// (PrivilegedAction<SparkExecutionEngineConfigClusterSetting>)
// () ->
// SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig(
// sparkExecutionEngineConfigSettingString)));
// } else {
// return Optional.empty();
// }
String sparkExecutionEngineConfigSettingString =
this.settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG);
if (!StringUtils.isBlank(sparkExecutionEngineConfigSettingString)) {
return Optional.of(
AccessController.doPrivileged(
(PrivilegedAction<SparkExecutionEngineConfigClusterSetting>)
() ->
SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig(
sparkExecutionEngineConfigSettingString)));
} else {
return Optional.empty();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import org.opensearch.index.engine.DocumentMissingException;
import org.opensearch.index.engine.VersionConflictEngineException;
import org.opensearch.jobscheduler.spi.ScheduledJobRunner;
import org.opensearch.sql.spark.scheduler.job.AsyncQueryScheduledQueryJob;
import org.opensearch.sql.spark.scheduler.job.ScheduledAsyncQueryJob;
import org.opensearch.sql.spark.scheduler.model.AsyncQuerySchedulerRequest;
import org.opensearch.sql.spark.scheduler.model.OpenSearchScheduleQueryJobRequest;

Expand Down Expand Up @@ -197,6 +197,6 @@ private void assertIndexExists() {

/** Returns the job runner instance for the scheduler. */
public static ScheduledJobRunner getJobRunner() {
return AsyncQueryScheduledQueryJob.getJobRunnerInstance();
return ScheduledAsyncQueryJob.getJobRunnerInstance();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.opensearch.threadpool.ThreadPool;

/**
* The job runner class for scheduling refresh index query.
* The job runner class for scheduling async query.
*
* <p>The job runner should be a singleton class if it uses OpenSearch client or other objects
* passed from OpenSearch. Because when registering the job runner to JobScheduler plugin,
Expand All @@ -33,12 +33,12 @@
* and using singleton job runner to ensure we register a usable job runner instance to JobScheduler
* plugin.
*/
public class AsyncQueryScheduledQueryJob implements ScheduledJobRunner {
private static final Logger LOGGER = LogManager.getLogger(AsyncQueryScheduledQueryJob.class);
public class ScheduledAsyncQueryJob implements ScheduledJobRunner {
private static final Logger LOGGER = LogManager.getLogger(ScheduledAsyncQueryJob.class);

public static AsyncQueryScheduledQueryJob INSTANCE = new AsyncQueryScheduledQueryJob();
public static ScheduledAsyncQueryJob INSTANCE = new ScheduledAsyncQueryJob();

public static AsyncQueryScheduledQueryJob getJobRunnerInstance() {
public static ScheduledAsyncQueryJob getJobRunnerInstance() {
return INSTANCE;
}

Expand All @@ -47,7 +47,7 @@ public static AsyncQueryScheduledQueryJob getJobRunnerInstance() {
private Client client;
private AsyncQueryExecutorService asyncQueryExecutorService;

private AsyncQueryScheduledQueryJob() {
private ScheduledAsyncQueryJob() {
// Singleton class, use getJobRunnerInstance method instead of constructor
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader;
import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler;
import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler;
import org.opensearch.sql.storage.DataSourceFactory;
import org.opensearch.test.OpenSearchIntegTestCase;
Expand All @@ -125,6 +126,7 @@ public class AsyncQueryExecutorServiceSpec extends OpenSearchIntegTestCase {
protected StateStore stateStore;
protected SessionStorageService sessionStorageService;
protected StatementStorageService statementStorageService;
protected AsyncQueryScheduler asyncQueryScheduler;
protected AsyncQueryRequestContext asyncQueryRequestContext;
protected SessionIdProvider sessionIdProvider = new DatasourceEmbeddedSessionIdProvider();

Expand Down Expand Up @@ -205,6 +207,7 @@ public void setup() {
new OpenSearchSessionStorageService(stateStore, new SessionModelXContentSerializer());
statementStorageService =
new OpenSearchStatementStorageService(stateStore, new StatementModelXContentSerializer());
asyncQueryScheduler = new OpenSearchAsyncQueryScheduler(client, clusterService);
}

protected FlintIndexOpFactory getFlintIndexOpFactory(
Expand All @@ -213,7 +216,8 @@ protected FlintIndexOpFactory getFlintIndexOpFactory(
flintIndexStateModelService,
flintIndexClient,
flintIndexMetadataService,
emrServerlessClientFactory);
emrServerlessClientFactory,
asyncQueryScheduler);
}

@After
Expand Down Expand Up @@ -299,7 +303,8 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService(
flintIndexStateModelService,
flintIndexClient,
new FlintIndexMetadataServiceImpl(client),
emrServerlessClientFactory),
emrServerlessClientFactory,
asyncQueryScheduler),
emrServerlessClientFactory,
new OpenSearchMetricsService(),
new OpenSearchAsyncQueryScheduler(client, clusterService),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import org.opensearch.sql.spark.scheduler.model.OpenSearchScheduleQueryJobRequest;
import org.opensearch.threadpool.ThreadPool;

public class AsyncQueryScheduledQueryJobTest {
public class ScheduledAsyncQueryJobTest {

@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private ClusterService clusterService;
Expand All @@ -48,14 +48,14 @@ public class AsyncQueryScheduledQueryJobTest {

@Mock private JobExecutionContext context;

private AsyncQueryScheduledQueryJob jobRunner;
private ScheduledAsyncQueryJob jobRunner;

private AsyncQueryScheduledQueryJob spyJobRunner;
private ScheduledAsyncQueryJob spyJobRunner;

@BeforeEach
public void setup() {
MockitoAnnotations.openMocks(this);
jobRunner = AsyncQueryScheduledQueryJob.getJobRunnerInstance();
jobRunner = ScheduledAsyncQueryJob.getJobRunnerInstance();
jobRunner.setClient(null);
jobRunner.setClusterService(null);
jobRunner.setThreadPool(null);
Expand Down Expand Up @@ -98,7 +98,7 @@ public void testRunJobWithCorrectParameter() {

@Test
public void testRunJobWithIncorrectParameter() {
jobRunner = AsyncQueryScheduledQueryJob.getJobRunnerInstance();
jobRunner = ScheduledAsyncQueryJob.getJobRunnerInstance();
jobRunner.setClusterService(clusterService);
jobRunner.setThreadPool(threadPool);
jobRunner.setClient(client);
Expand Down Expand Up @@ -163,9 +163,9 @@ public void testRunJobWithUninitializedServices() {

@Test
public void testGetJobRunnerInstanceMultipleCalls() {
AsyncQueryScheduledQueryJob instance1 = AsyncQueryScheduledQueryJob.getJobRunnerInstance();
AsyncQueryScheduledQueryJob instance2 = AsyncQueryScheduledQueryJob.getJobRunnerInstance();
AsyncQueryScheduledQueryJob instance3 = AsyncQueryScheduledQueryJob.getJobRunnerInstance();
ScheduledAsyncQueryJob instance1 = ScheduledAsyncQueryJob.getJobRunnerInstance();
ScheduledAsyncQueryJob instance2 = ScheduledAsyncQueryJob.getJobRunnerInstance();
ScheduledAsyncQueryJob instance3 = ScheduledAsyncQueryJob.getJobRunnerInstance();

assertSame(instance1, instance2);
assertSame(instance2, instance3);
Expand Down
6 changes: 3 additions & 3 deletions plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction;
import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler;
import org.opensearch.sql.spark.scheduler.OpenSearchScheduleQueryJobRequestParser;
import org.opensearch.sql.spark.scheduler.job.AsyncQueryScheduledQueryJob;
import org.opensearch.sql.spark.scheduler.job.ScheduledAsyncQueryJob;
import org.opensearch.sql.spark.storage.SparkStorageFactory;
import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction;
import org.opensearch.sql.spark.transport.TransportCreateAsyncQueryRequestAction;
Expand Down Expand Up @@ -247,7 +247,7 @@ public Collection<Object> createComponents(
injector.getInstance(FlintIndexOpFactory.class));
AsyncQueryExecutorService asyncQueryExecutorService =
injector.getInstance(AsyncQueryExecutorService.class);
AsyncQueryScheduledQueryJob.getJobRunnerInstance()
ScheduledAsyncQueryJob.getJobRunnerInstance()
.loadJobResource(client, clusterService, threadPool, asyncQueryExecutorService);

return ImmutableList.of(
Expand All @@ -266,7 +266,7 @@ public String getJobIndex() {

@Override
public ScheduledJobRunner getJobRunner() {
return AsyncQueryScheduledQueryJob.getJobRunnerInstance();
return ScheduledAsyncQueryJob.getJobRunnerInstance();
}

@Override
Expand Down

0 comments on commit 6d23a46

Please sign in to comment.