diff --git a/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java b/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java index 1b8eb8bd6c..830a19623a 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java @@ -30,6 +30,8 @@ public class MLPredictionOutput extends MLOutput { public static final String TASK_ID_FIELD = "task_id"; public static final String STATUS_FIELD = "status"; public static final String PREDICTION_RESULT_FIELD = "prediction_result"; + + // This field will be created for offline batch prediction tasks containing details of the batch job as outputted by the remote server. public static final String REMOTE_JOB_FIELD = "remote_job"; String taskId; diff --git a/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java b/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java index 857e92f5a3..49a1e9355b 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java @@ -9,7 +9,9 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Test; @@ -30,6 +32,7 @@ public class MLPredictionOutputTest { MLPredictionOutput output; + MLPredictionOutput outputWithRemoteJob; @Before public void setUp() { @@ -38,12 +41,17 @@ public void setUp() { rows.add(new Row(new ColumnValue[] { new IntValue(1) })); rows.add(new Row(new ColumnValue[] { new IntValue(2) })); DataFrame dataFrame = new DefaultDataFrame(columnMetas, rows); + Map remoteJob = new HashMap<>(); + remoteJob.put("status", "INPROGRESS"); + remoteJob.put("job_id", "testJobID"); output = MLPredictionOutput.builder().taskId("test_task_id").status("test_status").predictionResult(dataFrame).build(); + outputWithRemoteJob = new MLPredictionOutput("test_task_id", "test_status", remoteJob); } @Test public void toXContent() throws IOException { XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + XContentBuilder builderWithRemoteJob = MediaTypeRegistry.contentBuilder(XContentType.JSON); output.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); assertEquals( @@ -53,6 +61,12 @@ public void toXContent() throws IOException { + "\"value\":2}]}]}}", jsonStr ); + outputWithRemoteJob.toXContent(builderWithRemoteJob, ToXContent.EMPTY_PARAMS); + String jsonStr2 = builderWithRemoteJob.toString(); + assertEquals( + "{\"task_id\":\"test_task_id\",\"status\":\"test_status\",\"remote_job\":{\"job_id\":\"testJobID\",\"status\":\"INPROGRESS\"}}", + jsonStr2 + ); } @Test