From ec5be98802e5e024d7a2e1b4edf2559db5882528 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Fri, 29 Sep 2023 10:18:31 -0700 Subject: [PATCH] fix text docs input unescaped error; enable deploy remote model (#1407) Signed-off-by: Yaliang Wu (cherry picked from commit 6e0d949335c6238a28a77d31abe569897d2d88ad) --- .../ml/engine/algorithms/remote/ConnectorUtils.java | 11 +++++++++-- .../java/org/opensearch/ml/model/MLModelManager.java | 6 +++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index ac3f8a7eda..c481725057 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -81,7 +81,6 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto return inputData; } private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDataSet inputDataSet, Connector connector, Map parameters, ScriptService scriptService) { - List docs = new ArrayList<>(inputDataSet.getDocs()); Optional predictAction = connector.findPredictAction(); if (predictAction.isEmpty()) { throw new IllegalArgumentException("no predict action found"); @@ -89,9 +88,17 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat String preProcessFunction = predictAction.get().getPreProcessFunction(); preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction; if (MLPreProcessFunction.contains(preProcessFunction)) { - Map buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(docs); + Map buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(inputDataSet.getDocs()); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build(); } else { + List docs = new ArrayList<>(); + for (String doc : inputDataSet.getDocs()) { + if (doc != null) { + docs.add(gson.toJson(doc)); + } else { + docs.add(null); + } + } if (preProcessFunction.contains("${parameters")) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); preProcessFunction = substitutor.replace(preProcessFunction); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index a2d962b444..090b82443f 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -519,9 +519,9 @@ private void indexRemoteModel( mlTask.setModelId(modelId); log.info("create new model meta doc {} for upload task {}", modelId, taskId); mlTaskManager.updateMLTask(taskId, ImmutableMap.of(MODEL_ID_FIELD, modelId, STATE_FIELD, COMPLETED), 5000, true); - // if (registerModelInput.isDeployModel()) { - // deployModelAfterRegistering(registerModelInput, modelId); - // } + if (registerModelInput.isDeployModel()) { + deployModelAfterRegistering(registerModelInput, modelId); + } listener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name(), modelId)); }, e -> { log.error("Failed to index model meta doc", e);