Skip to content

Commit

Permalink
Inject services
Browse files Browse the repository at this point in the history
  • Loading branch information
noCharger committed Sep 1, 2024
1 parent 0d8b0c0 commit 2bc3d89
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.sql.spark.metrics.MetricsService;
import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler;

@RequiredArgsConstructor
public class QueryHandlerFactory {
Expand All @@ -27,6 +28,7 @@ public class QueryHandlerFactory {
private final FlintIndexOpFactory flintIndexOpFactory;
private final EMRServerlessClientFactory emrServerlessClientFactory;
private final MetricsService metricsService;
private final AsyncQueryScheduler asyncQueryScheduler;
protected final SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider;

public RefreshQueryHandler getRefreshQueryHandler(String accountId) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
package org.opensearch.sql.spark.scheduler;

public interface AsyncQueryScheduler {}
import java.io.IOException;
import org.opensearch.sql.spark.scheduler.model.AsyncQuerySchedulerRequest;

public interface AsyncQueryScheduler {
void scheduleJob(AsyncQuerySchedulerRequest asyncQuerySchedulerRequest);

void unscheduleJob(String jobId) throws IOException;

void updateJob(AsyncQuerySchedulerRequest asyncQuerySchedulerRequest) throws IOException;

void removeJob(String jobId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import lombok.RequiredArgsConstructor;
import org.apache.commons.io.IOUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -37,9 +38,9 @@
import org.opensearch.sql.spark.scheduler.job.AsyncQueryScheduledQueryJob;
import org.opensearch.sql.spark.scheduler.model.AsyncQuerySchedulerRequest;
import org.opensearch.sql.spark.scheduler.model.OpenSearchScheduleQueryJobRequest;
import org.opensearch.threadpool.ThreadPool;

/** Scheduler class for managing asynchronous query jobs. */
@RequiredArgsConstructor
public class OpenSearchAsyncQueryScheduler implements AsyncQueryScheduler {
public static final String SCHEDULER_INDEX_NAME = ".async-query-scheduler";
public static final String SCHEDULER_PLUGIN_JOB_TYPE = "async-query-scheduler";
Expand All @@ -49,19 +50,8 @@ public class OpenSearchAsyncQueryScheduler implements AsyncQueryScheduler {
"async-query-scheduler-index-settings.yml";
private static final Logger LOG = LogManager.getLogger();

private Client client;
private ClusterService clusterService;

/** Loads job resources, setting up required services and job runner instance. */
public void loadJobResource(Client client, ClusterService clusterService, ThreadPool threadPool) {
this.client = client;
this.clusterService = clusterService;
AsyncQueryScheduledQueryJob scheduledQueryJob =
AsyncQueryScheduledQueryJob.getJobRunnerInstance();
scheduledQueryJob.setClusterService(clusterService);
scheduledQueryJob.setThreadPool(threadPool);
scheduledQueryJob.setClient(client);
}
private final Client client;
private final ClusterService clusterService;

/** Schedules a new job by indexing it into the job index. */
public void scheduleJob(AsyncQuerySchedulerRequest asyncQuerySchedulerRequest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ private AsyncQueryScheduledQueryJob() {
// Singleton class, use getJobRunnerInstance method instead of constructor
}

/** Loads job resources, setting up required services and job runner instance. */
public void loadJobResource(
Client client,
ClusterService clusterService,
ThreadPool threadPool,
AsyncQueryExecutorService asyncQueryExecutorService) {
this.client = client;
this.clusterService = clusterService;
this.threadPool = threadPool;
this.asyncQueryExecutorService = asyncQueryExecutorService;
}

public void setClusterService(ClusterService clusterService) {
this.clusterService = clusterService;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader;
import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler;
import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler;

@RequiredArgsConstructor
public class AsyncExecutorServiceModule extends AbstractModule {
Expand Down Expand Up @@ -118,6 +120,7 @@ public QueryHandlerFactory queryhandlerFactory(
FlintIndexOpFactory flintIndexOpFactory,
EMRServerlessClientFactory emrServerlessClientFactory,
MetricsService metricsService,
AsyncQueryScheduler asyncQueryScheduler,
SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider) {
return new QueryHandlerFactory(
openSearchJobExecutionResponseReader,
Expand All @@ -128,6 +131,7 @@ public QueryHandlerFactory queryhandlerFactory(
flintIndexOpFactory,
emrServerlessClientFactory,
metricsService,
asyncQueryScheduler,
sparkSubmitParametersBuilderProvider);
}

Expand Down Expand Up @@ -245,6 +249,14 @@ public SessionConfigSupplier sessionConfigSupplier(Settings settings) {
return new OpenSearchSessionConfigSupplier(settings);
}

@Provides
@Singleton
public AsyncQueryScheduler asyncQueryScheduler(NodeClient client, ClusterService clusterService) {
OpenSearchAsyncQueryScheduler scheduler =
new OpenSearchAsyncQueryScheduler(client, clusterService);
return scheduler;
}

private void registerStateStoreMetrics(StateStore stateStore) {
GaugeMetric<Long> activeSessionMetric =
new GaugeMetric<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader;
import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler;
import org.opensearch.sql.storage.DataSourceFactory;
import org.opensearch.test.OpenSearchIntegTestCase;

Expand Down Expand Up @@ -301,6 +302,7 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService(
emrServerlessClientFactory),
emrServerlessClientFactory,
new OpenSearchMetricsService(),
new OpenSearchAsyncQueryScheduler(client, clusterService),
sparkSubmitParametersBuilderProvider);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import org.opensearch.index.engine.VersionConflictEngineException;
import org.opensearch.jobscheduler.spi.ScheduledJobRunner;
import org.opensearch.sql.spark.scheduler.model.OpenSearchScheduleQueryJobRequest;
import org.opensearch.threadpool.ThreadPool;

public class OpenSearchAsyncQuerySchedulerTest {

Expand All @@ -57,9 +56,6 @@ public class OpenSearchAsyncQuerySchedulerTest {
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private ClusterService clusterService;

@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private ThreadPool threadPool;

@Mock private ActionFuture<IndexResponse> indexResponseActionFuture;

@Mock private ActionFuture<UpdateResponse> updateResponseActionFuture;
Expand All @@ -77,8 +73,7 @@ public class OpenSearchAsyncQuerySchedulerTest {
@BeforeEach
public void setup() {
MockitoAnnotations.openMocks(this);
scheduler = new OpenSearchAsyncQueryScheduler();
scheduler.loadJobResource(client, clusterService, threadPool);
scheduler = new OpenSearchAsyncQueryScheduler(client, clusterService);
}

@Test
Expand Down
12 changes: 6 additions & 6 deletions plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,6 @@ public Collection<Object> createComponents(
this.client = (NodeClient) client;
this.dataSourceService = createDataSourceService();
dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata());
this.asyncQueryScheduler = new OpenSearchAsyncQueryScheduler();
this.asyncQueryScheduler.loadJobResource(client, clusterService, threadPool);
LocalClusterState.state().setClusterService(clusterService);
LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings);
LocalClusterState.state().setClient(client);
Expand Down Expand Up @@ -247,11 +245,13 @@ public Collection<Object> createComponents(
dataSourceService,
injector.getInstance(FlintIndexMetadataServiceImpl.class),
injector.getInstance(FlintIndexOpFactory.class));
AsyncQueryExecutorService asyncQueryExecutorService =
injector.getInstance(AsyncQueryExecutorService.class);
AsyncQueryScheduledQueryJob.getJobRunnerInstance()
.loadJobResource(client, clusterService, threadPool, asyncQueryExecutorService);

return ImmutableList.of(
dataSourceService,
injector.getInstance(AsyncQueryExecutorService.class),
clusterManagerEventListener,
pluginSettings);
dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings);
}

@Override
Expand Down

0 comments on commit 2bc3d89

Please sign in to comment.