less blocking
davidkyle committed Aug 19, 2024
1 parent 3555a91 commit 1a02a3c
Showing 3 changed files with 106 additions and 43 deletions.
ModelImporter.java
@@ -18,14 +18,16 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskCancelledException;
+ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
+ import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader;

import java.io.InputStream;
import java.net.URI;
import java.util.Objects;
- import java.util.concurrent.Semaphore;
+ import java.util.concurrent.ExecutorService;

import static org.elasticsearch.core.Strings.format;

@@ -40,15 +42,18 @@ class ModelImporter {
private final String modelId;
private final ModelPackageConfig config;
private final ModelDownloadTask task;
+ private final ExecutorService executorService;

- ModelImporter(Client client, String modelId, ModelPackageConfig packageConfig, ModelDownloadTask task) {
+ ModelImporter(Client client, String modelId, ModelPackageConfig packageConfig, ModelDownloadTask task, ThreadPool threadPool) {
this.client = client;
this.modelId = Objects.requireNonNull(modelId);
this.config = Objects.requireNonNull(packageConfig);
this.task = Objects.requireNonNull(task);
+ this.executorService = threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME);
}

public void doImport(ActionListener<AcknowledgedResponse> finalListener) {

long size = config.getSize();
// simple round up
int totalParts = (int) ((size + DEFAULT_CHUNK_SIZE - 1) / DEFAULT_CHUNK_SIZE);
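
The totalParts expression above is integer ceiling division: adding DEFAULT_CHUNK_SIZE - 1 before dividing rounds any partial final chunk up to a whole part. A minimal sketch of the idiom, with an assumed 1 MiB chunk size standing in for the constant defined elsewhere in ModelImporter:

// RoundUp.java - demonstrates the ceiling-division idiom used above.
public class RoundUp {
    // Assumed value for illustration; the real DEFAULT_CHUNK_SIZE is defined in ModelImporter.
    private static final long DEFAULT_CHUNK_SIZE = 1024 * 1024;

    static int totalParts(long size) {
        // Adding (chunk - 1) before dividing rounds any remainder up to a whole part.
        return (int) ((size + DEFAULT_CHUNK_SIZE - 1) / DEFAULT_CHUNK_SIZE);
    }

    public static void main(String[] args) {
        System.out.println(totalParts(1));                      // 1
        System.out.println(totalParts(DEFAULT_CHUNK_SIZE));     // 1
        System.out.println(totalParts(DEFAULT_CHUNK_SIZE + 1)); // 2
    }
}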
@@ -73,8 +78,7 @@ public void doImport(ActionListener<AcknowledgedResponse> finalListener) {
}

downloadParts(
-     new ModelLoaderUtils.InputStreamChunker(modelInputStream, DEFAULT_CHUNK_SIZE),
-     totalParts,
+     new ModelLoaderUtils.InputStreamChunker(modelInputStream, DEFAULT_CHUNK_SIZE, totalParts),
size,
vocabularyParts,
finalListener
@@ -83,46 +87,87 @@ public void doImport(ActionListener<AcknowledgedResponse> finalListener) {

    void downloadParts(
        ModelLoaderUtils.InputStreamChunker chunkIterator,
-       int totalParts,
        long size,
        @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts,
        ActionListener<AcknowledgedResponse> finalListener
    ) {
-       var requestLimiter = new Semaphore(MAX_IN_FLIGHT_REQUESTS);
-
-       try (var countingListener = new RefCountingListener(1, finalListener.map(ignored -> {
-           checkDownloadComplete(chunkIterator, totalParts);
+       var countingListener = new RefCountingListener(1, finalListener.map(ignored -> {
+           checkDownloadComplete(chunkIterator);
            return AcknowledgedResponse.TRUE;
-       }))) {
-           try {
-               // Uploading other artefacts of the model first, that way the model is last and a simple search can be used to check if the
-               // download is complete
-               if (vocabularyParts != null) {
-                   requestLimiter.acquire();
-                   uploadVocabulary(vocabularyParts, countingListener.acquire(r -> {
-                       requestLimiter.release();
-                       logger.debug(() -> format("[%s] imported model vocabulary [%s]", modelId, config.getVocabularyFile()));
-                   }));
-               }
-
-               for (int part = 0; part < totalParts; ++part) {
-                   if (countingListener.isFailing()) {
-                       break;
-                   }
-
-                   task.setProgress(totalParts, part);
-                   BytesArray definition = chunkIterator.next();
-
-                   if (task.isCancelled()) {
-                       throw new TaskCancelledException(format("task cancelled with reason [%s]", task.getReasonCancelled()));
-                   }
-
-                   requestLimiter.acquire();
-                   uploadPart(part, totalParts, size, definition, countingListener.acquire(r -> requestLimiter.release()));
-               }
-           } catch (Exception e) {
-               countingListener.acquire().onFailure(e);
-           }
-       }
-   }
+       }));
+       try {
+           // Uploading other artefacts of the model first, that way the model is last and a simple search can be used to check if the
+           // download is complete
+           if (vocabularyParts != null) {
+               uploadVocabulary(vocabularyParts, countingListener.acquire(r -> {
+                   logger.debug(() -> format("[%s] imported model vocabulary [%s]", modelId, config.getVocabularyFile()));
+               }));
+           }
+
+           for (int part = 0; part < MAX_IN_FLIGHT_REQUESTS; ++part) {
+               if (countingListener.isFailing()) {
+                   break;
+               }
+               task.setProgress(chunkIterator.getTotalParts(), chunkIterator.getCurrentPart().get());
+               BytesArray definition = chunkIterator.next();
+
+               if (task.isCancelled()) {
+                   throw new TaskCancelledException(format("task cancelled with reason [%s]", task.getReasonCancelled()));
+               }
+
+               uploadPart(
+                   part,
+                   chunkIterator.getTotalParts(),
+                   size,
+                   definition,
+                   countingListener.acquire(r -> executorService.execute(() -> doNextPart(size, chunkIterator, countingListener)))
+               );
+           }
+       } catch (Exception e) {
+           countingListener.acquire().onFailure(e);
+           countingListener.close();
+       }
+   }
+
+   public void doNextPart(long size, ModelLoaderUtils.InputStreamChunker chunkIterator, RefCountingListener countingListener) {
+       if (countingListener.isFailing()) {
+           countingListener.close();
+           return;
+       }
+
+       task.setProgress(chunkIterator.getTotalParts(), chunkIterator.getCurrentPart().get());
+       try {
+           logger.info("doing next part " + chunkIterator.getCurrentPart().get() + ", " + chunkIterator.getTotalParts());
+           BytesArray definition = chunkIterator.next();
+
+           if (task.isCancelled()) {
+               throw new TaskCancelledException(format("task cancelled with reason [%s]", task.getReasonCancelled()));
+           }
+
+           if (definition.length() == 0) {
+               // done
+               return;
+           }
+
+           boolean lastPart = chunkIterator.isFinalPart();
+
+           uploadPart(
+               chunkIterator.getCurrentPart().get(),
+               chunkIterator.getTotalParts(),
+               size,
+               definition,
+               countingListener.acquire(r -> {
+                   if (lastPart) {
+                       countingListener.close();
+                   } else {
+                       executorService.execute(() -> doNextPart(size, chunkIterator, countingListener));
+                   }
+               })
+           );
+       } catch (Exception e) {
+           countingListener.acquire().onFailure(e);
+           countingListener.close();
+       }
+   }
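
The hunk above is the heart of the commit: the Semaphore-throttled loop, which blocked a thread on requestLimiter.acquire() for every part, is replaced by an asynchronous chain. downloadParts now only seeds MAX_IN_FLIGHT_REQUESTS uploads, and each completed part schedules doNextPart on the dedicated download executor, so at most that many requests are in flight at once and no thread ever parks on a semaphore. A self-contained sketch of the same pattern, with plain JDK types standing in for the Elasticsearch client and listener machinery (all names below are illustrative):

// PipelineSketch.java - a minimal model of the new upload flow: seed the
// executor with MAX_IN_FLIGHT tasks, and let each finished part schedule
// the next one, so concurrency stays bounded without blocking any thread.
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

public class PipelineSketch {
    static final int MAX_IN_FLIGHT = 3;   // stands in for MAX_IN_FLIGHT_REQUESTS
    static final int TOTAL_PARTS = 10;

    static final ExecutorService executor = Executors.newFixedThreadPool(MAX_IN_FLIGHT);
    static final AtomicInteger nextPart = new AtomicInteger();
    static final CountDownLatch allDone = new CountDownLatch(TOTAL_PARTS);

    public static void main(String[] args) throws InterruptedException {
        // Mirrors downloadParts: start MAX_IN_FLIGHT pipelines, nothing more.
        for (int i = 0; i < MAX_IN_FLIGHT; i++) {
            executor.execute(PipelineSketch::doNextPart);
        }
        allDone.await();
        executor.shutdown();
    }

    // Mirrors doNextPart: claim a part, upload it, then re-schedule yourself.
    static void doNextPart() {
        int part = nextPart.getAndIncrement();
        if (part >= TOTAL_PARTS) {
            return; // this pipeline lane is finished
        }
        upload(part, () -> {
            executor.execute(PipelineSketch::doNextPart); // chain the next chunk
            allDone.countDown();
        });
    }

    // A stand-in for the async client call; invokes the listener inline.
    static void upload(int part, Runnable onSuccess) {
        System.out.println(Thread.currentThread().getName() + " uploaded part " + part);
        onSuccess.run();
    }
}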

@@ -157,7 +202,7 @@ private void uploadPart(
client.execute(PutTrainedModelDefinitionPartAction.INSTANCE, modelPartRequest, listener);
}

- private void checkDownloadComplete(ModelLoaderUtils.InputStreamChunker chunkIterator, int totalParts) {
+ private void checkDownloadComplete(ModelLoaderUtils.InputStreamChunker chunkIterator) {
if (config.getSha256().equals(chunkIterator.getSha256()) == false) {
String message = format(
"Model sha256 checksums do not match, expected [%s] but got [%s]",
@@ -178,6 +223,6 @@ private void checkDownloadComplete(ModelLoaderUtils.InputStreamChunker chunkIter
throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR);
}

- logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts));
+ logger.debug(format("finished importing model [%s] using [%d] parts", modelId, chunkIterator.getTotalParts()));
}
}
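
checkDownloadComplete verifies the download by comparing the hex-encoded SHA-256 of every byte the chunker read against the checksum in the package config (the x.equals(y) == false form is the codebase's preferred negation style). A sketch of what that comparison boils down to, using plain JDK types instead of the MessageDigests helper:

// ChecksumSketch.java - digest some bytes and compare hex strings.
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.HexFormat;

public class ChecksumSketch {
    public static void main(String[] args) throws NoSuchAlgorithmException {
        byte[] downloaded = "model bytes".getBytes(StandardCharsets.UTF_8);
        // In the real code the expected value comes from ModelPackageConfig;
        // here we compute it from the same bytes so the check passes.
        String expectedSha256 = HexFormat.of().formatHex(MessageDigest.getInstance("SHA-256").digest(downloaded));

        MessageDigest digest = MessageDigest.getInstance("SHA-256");
        digest.update(downloaded); // the chunker feeds this incrementally, chunk by chunk
        String actual = HexFormat.of().formatHex(digest.digest());

        if (actual.equals(expectedSha256) == false) {
            throw new IllegalStateException("checksum mismatch: expected " + expectedSha256 + " got " + actual);
        }
        System.out.println("checksum ok: " + actual);
    }
}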
ModelLoaderUtils.java
@@ -38,6 +38,8 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
+ import java.util.concurrent.atomic.AtomicInteger;
+ import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;

import static java.net.HttpURLConnection.HTTP_MOVED_PERM;
@@ -66,15 +68,19 @@ static class InputStreamChunker {
private final InputStream inputStream;
private final MessageDigest digestSha256 = MessageDigests.sha256();
private final int chunkSize;
+ private final int totalParts;
+ private final AtomicLong totalBytesRead = new AtomicLong();
+ private final AtomicInteger currentPart = new AtomicInteger(-1);

- private long totalBytesRead = 0;

- InputStreamChunker(InputStream inputStream, int chunkSize) {
+ InputStreamChunker(InputStream inputStream, int chunkSize, int totalParts) {
this.inputStream = inputStream;
this.chunkSize = chunkSize;
+ this.totalParts = totalParts;
}

public BytesArray next() throws IOException {
+ currentPart.incrementAndGet();

int bytesRead = 0;
byte[] buf = new byte[chunkSize];

@@ -87,17 +93,29 @@ public BytesArray next() throws IOException {
bytesRead += read;
}
digestSha256.update(buf, 0, bytesRead);
- totalBytesRead += bytesRead;
+ totalBytesRead.addAndGet(bytesRead);

return new BytesArray(buf, 0, bytesRead);
}

+ public boolean isFinalPart() {
+     return currentPart.get() == totalParts - 1;
+ }

public String getSha256() {
return MessageDigests.toHexString(digestSha256.digest());
}

public long getTotalBytesRead() {
- return totalBytesRead;
+ return totalBytesRead.get();
}

+ public int getTotalParts() {
+     return totalParts;
+ }
+
+ public AtomicInteger getCurrentPart() {
+     return currentPart;
+ }
}
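
A hypothetical caller of the reworked chunker, showing how the new totalParts, currentPart, and isFinalPart pieces fit together. This is a method-body snippet, not a full program: it assumes the caller can see the class and has imported java.io.ByteArrayInputStream and org.elasticsearch.common.bytes.BytesArray; the sizes are arbitrary.

byte[] data = new byte[2_500];
int chunkSize = 1_000;
int totalParts = (data.length + chunkSize - 1) / chunkSize; // 3 parts: 1000, 1000, 500

var chunker = new ModelLoaderUtils.InputStreamChunker(new ByteArrayInputStream(data), chunkSize, totalParts);

BytesArray part;
do {
    part = chunker.next(); // bumps currentPart and feeds the SHA-256 digest
    System.out.println("part " + chunker.getCurrentPart().get()
        + " of " + chunker.getTotalParts() + ": " + part.length() + " bytes");
} while (chunker.isFinalPart() == false && part.length() > 0);

assert chunker.getTotalBytesRead() == data.length;
String sha256 = chunker.getSha256(); // hex digest of every byte read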

ModelLoaderUtilsTests.java
@@ -80,14 +80,14 @@ public void testSha256AndSize() throws IOException {
assertEquals(64, expectedDigest.length());

int chunkSize = randomIntBetween(100, 10_000);
+ int totalParts = (bytes.length + chunkSize - 1) / chunkSize;

ModelLoaderUtils.InputStreamChunker inputStreamChunker = new ModelLoaderUtils.InputStreamChunker(
new ByteArrayInputStream(bytes),
-     chunkSize
+     chunkSize,
+     totalParts
);

- int totalParts = (bytes.length + chunkSize - 1) / chunkSize;

for (int part = 0; part < totalParts - 1; ++part) {
assertEquals(chunkSize, inputStreamChunker.next().length());
}
