Skip to content

Commit

Permalink
Fix no response issue in functional test
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Jan 26, 2024
1 parent 5e2d98b commit d398d48
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;

Expand Down Expand Up @@ -85,6 +86,7 @@ public void onStream(Publisher<ByteBuffer> stream) {
subscription.request(Long.MAX_VALUE);
}
@Override public void onError(Throwable t) {
countDownLatch.getCountDownLatch().countDown();
log.error("Error on receiving response body from remote: {}", t instanceof NullPointerException ? "NullPointerException" : t.getMessage(), t);
errorMsg.add("Error on receiving response body from remote: " + (t instanceof NullPointerException ? "NullPointerException" : t.getMessage()));
if (countDownLatch.getCountDownLatch().getCount() == 0) {
Expand All @@ -96,15 +98,16 @@ public void onStream(Publisher<ByteBuffer> stream) {

@Override
public void onComplete() {
countDownLatch.getCountDownLatch().countDown();
try {
String fullResponseBody = responseBody.toString();
processResponse(statusCode, fullResponseBody, parameters, tensorOutputs);
countDownLatch.getCountDownLatch().countDown();
if (countDownLatch.getCountDownLatch().getCount() == 0) {
log.debug("All responses received, calling action listener to return final results.");
actionListener.onResponse(reOrderTensorResponses(tensorOutputs));
}
} catch (Throwable e) {
countDownLatch.getCountDownLatch().countDown();
log.error("Error on processing response from remote: {}", e instanceof NullPointerException ? "NullPointerException" : e.getMessage(), e);
errorMsg.add("Error on receiving response from remote: " + (e instanceof NullPointerException ? "NullPointerException" : e.getMessage()));
if (countDownLatch.getCountDownLatch().getCount() == 0) {
Expand Down Expand Up @@ -142,7 +145,8 @@ private void processResponse(Integer statusCode, String body, Map<String, String

private List<ModelTensors> reOrderTensorResponses(Map<Integer, ModelTensors> tensorOutputs) {
List<ModelTensors> modelTensors = new ArrayList<>();
for (Map.Entry<Integer, ModelTensors> entry : tensorOutputs.entrySet()) {
TreeMap<Integer, ModelTensors> sortedMap = new TreeMap<>(tensorOutputs);
for (Map.Entry<Integer, ModelTensors> entry : sortedMap.entrySet()) {
modelTensors.add(entry.getKey(), entry.getValue());
}
return modelTensors;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,24 @@ default void executePredict(MLInput mlInput, ActionListener<MLTaskResponse> acti
ActionListener<List<ModelTensors>> tensorActionListener = ActionListener.wrap(r -> {
actionListener.onResponse(new MLTaskResponse(new ModelTensorOutput(r)));
}, actionListener::onFailure);
Map<Integer, ModelTensors> modelTensorsQueue = new ConcurrentHashMap<>();
Map<Integer, ModelTensors> modelTensors = new ConcurrentHashMap<>();
if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
Tuple<Integer, Integer> calculatedChunkSize = calculateChunkSize(textDocsInputDataSet);
CountDownLatch countDownLatch = new CountDownLatch(calculatedChunkSize.v1());
int sequence = 0;
for (int processedDocs = 0; processedDocs < calculatedChunkSize.v1(); processedDocs = processedDocs + calculatedChunkSize.v2()) {
for (int processedDocs = 0; processedDocs < textDocsInputDataSet.getDocs().size(); processedDocs += calculatedChunkSize.v2()) {
List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
preparePayloadAndInvokeRemoteModel(
MLInput
.builder()
.algorithm(FunctionName.TEXT_EMBEDDING)
.inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
.build(),
modelTensorsQueue, new WrappedCountDownLatch(sequence++, countDownLatch) , tensorActionListener);
modelTensors, new WrappedCountDownLatch(sequence++, countDownLatch) , tensorActionListener);
}
} else {
preparePayloadAndInvokeRemoteModel(mlInput, modelTensorsQueue, new WrappedCountDownLatch(0, new CountDownLatch(1)), tensorActionListener);
preparePayloadAndInvokeRemoteModel(mlInput, modelTensors, new WrappedCountDownLatch(0, new CountDownLatch(1)), tensorActionListener);
}
}

Expand Down

0 comments on commit d398d48

Please sign in to comment.