Commit 97f959e: Pass executor service

davidkyle committed Apr 8, 2021
1 parent 5ed3b66 commit 97f959e
Showing 6 changed files with 81 additions and 35 deletions.
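The commit threads an `ExecutorService` through `ChunkedTrainedModelRestorer` and `PyTorchStateStreamer` as a constructor argument instead of letting each class resolve the ML utility thread pool internally; callers now pass `client.threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME)` at the construction site. A minimal sketch of the injection pattern, using simplified stand-in names rather than the actual Elasticsearch types:

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer;

// Simplified stand-in for the restorer: all async work runs on an executor
// that the caller supplies, which makes the threading explicit and testable.
class Restorer {

    private final ExecutorService executorService;

    Restorer(ExecutorService executorService) {
        this.executorService = executorService;
    }

    void restore(Consumer<String> chunkConsumer) {
        executorService.execute(() -> chunkConsumer.accept("definition-chunk-0"));
    }

    public static void main(String[] args) {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        new Restorer(pool).restore(chunk -> System.out.println("restored " + chunk));
        pool.shutdown();
    }
}
```

Passing the executor in keeps the threading policy at the call site and lets tests substitute a different pool without touching the class under test.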
==== File 1 of 6 ====
@@ -17,6 +17,7 @@
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.ml.inference.persistence.ChunkedTrainedModelRestorer;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
@@ -43,7 +44,8 @@ public void testRestoreWithMultipleSearches() throws IOException, InterruptedException {
putModelDefinitions(expectedDocs, InferenceIndexConstants.LATEST_INDEX_NAME, 0);


ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(modelId, client(), xContentRegistry());
ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(modelId, client(),
client().threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME), xContentRegistry());
restorer.setSearchSize(5);
List<TrainedModelDefinitionDoc> actualDocs = new ArrayList<>();

@@ -53,7 +55,10 @@ public void testRestoreWithMultipleSearches() throws IOException, InterruptedException {
restorer.restoreModelDefinition(
actualDocs::add,
success -> latch.countDown(),
failure -> {exceptionHolder.set(failure); latch.countDown();});
failure -> {
exceptionHolder.set(failure);
latch.countDown();
});

latch.await();

@@ -86,7 +91,8 @@ public void testRestoreWithDocumentsInMultipleIndices() throws IOException, InterruptedException {
putModelDefinitions(expectedDocs.subList(0, splitPoint), index1, 0);
putModelDefinitions(expectedDocs.subList(splitPoint, numDocs), index2, splitPoint);

ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(modelId, client(), xContentRegistry());
ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(modelId, client(),
client().threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME), xContentRegistry());
restorer.setSearchSize(10);
restorer.setSearchIndex("foo-*");

@@ -97,13 +103,20 @@ public void testRestoreWithDocumentsInMultipleIndices() throws IOException, InterruptedException {
restorer.restoreModelDefinition(
actualDocs::add,
success -> latch.countDown(),
failure -> {exceptionHolder.set(failure); latch.countDown();});
failure -> {
exceptionHolder.set(failure);
latch.countDown();
});

latch.await();

assertNull(exceptionHolder.get());
// TODO this fails because the results are sorted by index first
assertEquals(actualDocs, expectedDocs);
// The results are sorted by index first rather than doc_num
// TODO is this the behaviour we want?
List<TrainedModelDefinitionDoc> reorderedDocs = new ArrayList<>();
reorderedDocs.addAll(expectedDocs.subList(splitPoint, numDocs));
reorderedDocs.addAll(expectedDocs.subList(0, splitPoint));
assertEquals(actualDocs, reorderedDocs);
}

private List<TrainedModelDefinitionDoc> createModelDefinitionDocs(List<String> compressedDefinitions, String modelId) {
@@ -127,8 +140,7 @@ private List<TrainedModelDefinitionDoc> createModelDefinitionDocs(List<String> compressedDefinitions, String modelId) {

private void putModelDefinitions(List<TrainedModelDefinitionDoc> docs, String index, int startingDocNum) throws IOException {
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
for (int i = 0; i < docs.size(); i++) {
TrainedModelDefinitionDoc doc = docs.get(i);
for (TrainedModelDefinitionDoc doc : docs) {
try (XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) {
IndexRequestBuilder indexRequestBuilder = client().prepareIndex(index)
.setSource(xContentBuilder)
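Both test files in this commit wait on the async restore with the same idiom: a `CountDownLatch` released from both callbacks, plus an `AtomicReference<Exception>` to carry a failure back to the test thread. A self-contained sketch of the idiom; `AsyncLoader` here is a hypothetical stand-in for `restoreModelDefinition`:

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

// Hypothetical async API standing in for restoreModelDefinition(...).
interface AsyncLoader {
    void load(Runnable onSuccess, Consumer<Exception> onFailure);
}

class LatchIdiom {

    // Blocks until one of the callbacks fires, then rethrows any captured failure.
    static void loadAndWait(AsyncLoader loader) throws Exception {
        CountDownLatch latch = new CountDownLatch(1);
        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

        loader.load(
            latch::countDown,                      // success path releases the latch
            failure -> {
                exceptionHolder.set(failure);      // failure path captures first...
                latch.countDown();                 // ...then releases the latch
            });

        latch.await();
        if (exceptionHolder.get() != null) {
            throw exceptionHolder.get();
        }
    }
}
```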
==== File 2 of 6 ====
@@ -17,6 +17,7 @@
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer;
@@ -48,7 +49,8 @@ public void testRestoreState() throws IOException, InterruptedException {
putModelDefinition(docs);

ByteArrayOutputStream outputStream = new ByteArrayOutputStream(modelSize);
PyTorchStateStreamer stateStreamer = new PyTorchStateStreamer(client(),xContentRegistry());
PyTorchStateStreamer stateStreamer = new PyTorchStateStreamer(client(),
client().threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME), xContentRegistry());

AtomicReference<Boolean> onSuccess = new AtomicReference<>();
AtomicReference<Exception> onFailure = new AtomicReference<>();
==== File 3 of 6 ====
@@ -63,12 +63,12 @@ public void startDeployment(TrainedModelDeploymentTask task) {
private void doStartDeployment(TrainedModelDeploymentTask task) {
logger.debug("[{}] Starting model deployment", task.getModelId());

ProcessContext processContext = new ProcessContext(task.getModelId());
ProcessContext processContext = new ProcessContext(task.getModelId(), executorServiceForProcess);

if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
throw ExceptionsHelper.serverError("[{}] Could not create process as one already exists", task.getModelId());
}

ActionListener<Boolean> modelLoadedListener = ActionListener.wrap(
success -> {
executorServiceForProcess.execute(() -> processContext.resultProcessor.process(processContext.process.get()));
@@ -80,9 +80,7 @@ private void doStartDeployment(TrainedModelDeploymentTask task) {
task::markAsFailed
));
},
e -> {
failTask(task, e);
}
e -> failTask(task, e)
);

processContext.startProcess();
@@ -151,10 +149,10 @@ class ProcessContext {
private final PyTorchResultProcessor resultProcessor;
private final PyTorchStateStreamer stateStreamer;

ProcessContext(String modelId) {
ProcessContext(String modelId, ExecutorService executorService) {
this.modelId = Objects.requireNonNull(modelId);
resultProcessor = new PyTorchResultProcessor(modelId);
this.stateStreamer = new PyTorchStateStreamer(client, xContentRegistry);
this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
}

synchronized void startProcess() {
@@ -181,7 +179,7 @@ private Consumer<String> onProcessCrash() {
}

void loadModel(ActionListener<Boolean> listener) {
process.get().loadModel(modelId, new PyTorchStateStreamer(client, xContentRegistry), listener);
process.get().loadModel(modelId, stateStreamer, listener);
}
}

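Alongside the executor plumbing, `loadModel` now reuses the `stateStreamer` field created in the `ProcessContext` constructor instead of building a fresh `PyTorchStateStreamer` on every call. Since `PyTorchStateStreamer` exposes `cancel()` (see the last file below), sharing one instance is what would let a cancellation reach the stream in flight. A simplified sketch of the difference, using stand-in types rather than the actual classes:

```java
// Stand-in types, not the actual Elasticsearch classes: the point is that
// cancel() can only stop the stream if it reaches the instance doing the streaming.
class Streamer {
    private volatile boolean cancelled;

    void stream() {
        if (!cancelled) {
            System.out.println("streaming state...");
        }
    }

    void cancel() {
        cancelled = true;
    }
}

class Context {
    private final Streamer stateStreamer = new Streamer();

    void loadModel() {
        stateStreamer.stream();   // after this commit: the shared field is used
        // before: new Streamer().stream(); a cancel() below could never stop it
    }

    void cancel() {
        stateStreamer.cancel();   // reaches the same instance loadModel() uses
    }
}
```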
==== File 4 of 6 ====
@@ -31,17 +31,24 @@
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;

import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;

import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;

/**
* This is a use once class it has internal state to track progress
* Searches for and emits {@link TrainedModelDefinitionDoc}s in
* order based on the {@code doc_num}.
*
* This is a one-use class; it has internal state to track progress
* and cannot be used again to load another model.
*
* Defaults to searching in {@link InferenceIndexConstants#INDEX_PATTERN}
* if a different index is not set.
*/
public class ChunkedTrainedModelRestorer {

@@ -51,13 +58,18 @@ public class ChunkedTrainedModelRestorer {

private final Client client;
private final NamedXContentRegistry xContentRegistry;
private final ExecutorService executorService;
private final String modelId;
private String index = InferenceIndexConstants.INDEX_PATTERN;
private int searchSize = 10;
private int numDocsWritten = 0;

public ChunkedTrainedModelRestorer(String modelId, Client client,
public ChunkedTrainedModelRestorer(String modelId,
Client client,
ExecutorService executorService,
NamedXContentRegistry xContentRegistry) {
this.client = client;
this.executorService = executorService;
this.xContentRegistry = xContentRegistry;
this.modelId = modelId;
}
@@ -66,9 +78,16 @@ public void setSearchSize(int searchSize) {
if (searchSize > MAX_NUM_DEFINITION_DOCS) {
throw new IllegalArgumentException("search size [" + searchSize + "] cannot be bigger than [" + MAX_NUM_DEFINITION_DOCS + "]");
}
if (searchSize <= 0) {
throw new IllegalArgumentException("search size [" + searchSize + "] must be greater than 0");
}
this.searchSize = searchSize;
}

public void setSearchIndex(String indexNameOrPattern) {
this.index = indexNameOrPattern;
}

public int getNumDocsWritten() {
return numDocsWritten;
}
@@ -77,10 +96,9 @@ public void restoreModelDefinition(CheckedConsumer<TrainedModelDefinitionDoc, IOException> modelConsumer,
Consumer<Boolean> successConsumer,
Consumer<Exception> errorConsumer) {

SearchRequest searchRequest = buildSearch(client, modelId, searchSize);
SearchRequest searchRequest = buildSearch(client, modelId, index, searchSize);

client.threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() ->
doSearch(searchRequest, modelConsumer, successConsumer, errorConsumer));
executorService.execute(() -> doSearch(searchRequest, modelConsumer, successConsumer, errorConsumer));
}

private void doSearch(SearchRequest searchRequest,
@@ -96,9 +114,18 @@ private void doSearch(SearchRequest searchRequest,
return;
}

// Set lastNum to a non-zero value to prevent an infinite loop of
// search-after requests in the absolute worst case where everything
// has gone wrong.
// Docs are numbered 0..N; we must have seen at least this many docs so far.
int lastNum = numDocsWritten - 1;
for (SearchHit hit : searchResponse.getHits().getHits()) {
try {
modelConsumer.accept(parseModelDefinitionDocLenientlyFromSource(hit.getSourceRef(), modelId, xContentRegistry));
TrainedModelDefinitionDoc doc =
parseModelDefinitionDocLenientlyFromSource(hit.getSourceRef(), modelId, xContentRegistry);
lastNum = doc.getDocNum();
modelConsumer.accept(doc);
} catch (IOException e) {
logger.error(new ParameterizedMessage("[{}] error writing model definition", modelId), e);
errorConsumer.accept(e);
@@ -114,9 +141,9 @@ private void doSearch(SearchRequest searchRequest,
} else {
// search again with after
SearchHit lastHit = searchResponse.getHits().getAt(searchResponse.getHits().getHits().length - 1);
SearchRequestBuilder searchRequestBuilder = buildSearchBuilder(client, modelId, searchSize);
searchRequestBuilder.searchAfter(new Object[]{lastHit.getIndex(), lastHit.getId()});
client.threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() ->
SearchRequestBuilder searchRequestBuilder = buildSearchBuilder(client, modelId, index, searchSize);
searchRequestBuilder.searchAfter(new Object[]{lastHit.getIndex(), lastNum});
executorService.execute(() ->
doSearch(searchRequestBuilder.request(), modelConsumer, successConsumer, errorConsumer));
}
},
@@ -131,8 +158,8 @@ private void doSearch(SearchRequest searchRequest,
));
}

private static SearchRequestBuilder buildSearchBuilder(Client client, String modelId, int searchSize) {
return client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
private static SearchRequestBuilder buildSearchBuilder(Client client, String modelId, String index, int searchSize) {
return client.prepareSearch(index)
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
.boolQuery()
.filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
@@ -148,8 +175,8 @@ private static SearchRequestBuilder buildSearchBuilder(Client client, String modelId, int searchSize) {
.unmappedType("long"));
}

public static SearchRequest buildSearch(Client client, String modelId, int searchSize) {
return buildSearchBuilder(client, modelId, searchSize).request();
public static SearchRequest buildSearch(Client client, String modelId, String index, int searchSize) {
return buildSearchBuilder(client, modelId, index, searchSize).request();
}

public static TrainedModelDefinitionDoc parseModelDefinitionDocLenientlyFromSource(BytesReference source,
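The second fix in this file is the search-after cursor. `buildSearchBuilder` sorts by `_index` and then by `doc_num`, so the values given to `searchAfter` must be an (index, doc_num) pair; the old code supplied `lastHit.getId()`, the document `_id`, as the second value, while the new code tracks `lastNum` from the parsed docs. This index-first sort is also why the multi-index test above now expects the documents regrouped by index. An in-memory sketch of why the cursor must mirror the sort keys (plain Java, no Elasticsearch dependency):

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

// In-memory model of the pagination: results sort by (index, docNum), and each
// page resumes strictly after the last cursor, so the cursor must be built from
// the same two sort keys; an _id in the second slot would not compare correctly.
class SearchAfterSketch {

    record Doc(String index, int docNum) {}

    static List<Doc> nextPage(List<Doc> all, String afterIndex, int afterNum, int size) {
        return all.stream()
            .sorted(Comparator.comparing(Doc::index).thenComparingInt(Doc::docNum))
            .filter(d -> d.index().compareTo(afterIndex) > 0
                || (d.index().equals(afterIndex) && d.docNum() > afterNum))
            .limit(size)
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Doc> docs = List.of(
            new Doc("foo-2", 0), new Doc("foo-1", 3),
            new Doc("foo-1", 0), new Doc("foo-2", 1));
        System.out.println(nextPage(docs, "foo-1", 0, 2)); // [foo-1/3, foo-2/0]
    }
}
```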
==== File 5 of 6 ====
@@ -81,6 +81,7 @@
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;

import java.io.IOException;
@@ -405,7 +406,10 @@ public void getTrainedModelForInference(final String modelId, final ActionListener
}

List<TrainedModelDefinitionDoc> docs = new ArrayList<>();
ChunkedTrainedModelRestorer modelRestorer = new ChunkedTrainedModelRestorer(modelId, client, xContentRegistry);
ChunkedTrainedModelRestorer modelRestorer =
new ChunkedTrainedModelRestorer(modelId, client,
client.threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME), xContentRegistry);

// TODO how could we stream in the model definition WHILE parsing it?
// This would reduce the overall memory usage as we won't have to load the whole compressed string
// XContentParser supports streams.
==== File 6 of 6 ====
@@ -23,6 +23,7 @@
import java.util.Base64;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.ExecutorService;

import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;

@@ -38,13 +39,15 @@ public class PyTorchStateStreamer {
private static final Logger logger = LogManager.getLogger(PyTorchStateStreamer.class);

private final OriginSettingClient client;
private final ExecutorService executorService;
private final NamedXContentRegistry xContentRegistry;
private volatile boolean isCancelled;
private boolean modelSizeWritten = false;

public PyTorchStateStreamer(Client client, NamedXContentRegistry xContentRegistry) {
public PyTorchStateStreamer(Client client, ExecutorService executorService, NamedXContentRegistry xContentRegistry) {
this.client = new OriginSettingClient(Objects.requireNonNull(client), ML_ORIGIN);
this.xContentRegistry = xContentRegistry;
this.executorService = Objects.requireNonNull(executorService);
this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
}

/**
@@ -63,7 +66,7 @@ public void cancel() {
* @param listener error and success listener
*/
public void writeStateToStream(String modelId, OutputStream restoreStream, ActionListener<Boolean> listener) {
ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(modelId, client, xContentRegistry);
ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(modelId, client, executorService, xContentRegistry);
// TODO cancel loading
restorer.restoreModelDefinition(doc -> writeChunk(doc, restoreStream), listener::onResponse, listener::onFailure);

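The streamer restores into an `OutputStream` by adapting the doc-at-a-time callback from `restoreModelDefinition`: each `TrainedModelDefinitionDoc` carries a chunk of the compressed definition, and `writeChunk` appends it to the restore stream in `doc_num` order. A rough, self-contained sketch of that shape; `writeChunkTo` and the base64 payload handling are simplifications, not the actual implementation:

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Base64;
import java.util.List;

// Sketch of the streaming adapter: each restored doc carries a base64 chunk
// that is decoded and appended to the process's restore stream in doc order.
class StreamingSketch {

    static void writeChunkTo(String base64Definition, OutputStream out) throws IOException {
        out.write(Base64.getDecoder().decode(base64Definition));
    }

    public static void main(String[] args) throws IOException {
        List<String> chunks = List.of(
            Base64.getEncoder().encodeToString("part-one ".getBytes()),
            Base64.getEncoder().encodeToString("part-two".getBytes()));

        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        for (String chunk : chunks) {
            writeChunkTo(chunk, outputStream);
        }
        System.out.println(outputStream.toString()); // part-one part-two
    }
}
```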
