Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Oct 9, 2024
1 parent 17cbad0 commit 0bdc387
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ public class MLPredictionOutput extends MLOutput {
public static final String TASK_ID_FIELD = "task_id";
public static final String STATUS_FIELD = "status";
public static final String PREDICTION_RESULT_FIELD = "prediction_result";

// This field will be created for offline batch prediction tasks containing details of the batch job as outputted by the remote server.
public static final String REMOTE_JOB_FIELD = "remote_job";

String taskId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Test;
Expand All @@ -30,6 +32,7 @@
public class MLPredictionOutputTest {

MLPredictionOutput output;
MLPredictionOutput outputWithRemoteJob;

@Before
public void setUp() {
Expand All @@ -38,12 +41,17 @@ public void setUp() {
rows.add(new Row(new ColumnValue[] { new IntValue(1) }));
rows.add(new Row(new ColumnValue[] { new IntValue(2) }));
DataFrame dataFrame = new DefaultDataFrame(columnMetas, rows);
Map<String, Object> remoteJob = new HashMap<>();
remoteJob.put("status", "INPROGRESS");
remoteJob.put("job_id", "testJobID");
output = MLPredictionOutput.builder().taskId("test_task_id").status("test_status").predictionResult(dataFrame).build();
outputWithRemoteJob = new MLPredictionOutput("test_task_id", "test_status", remoteJob);
}

@Test
public void toXContent() throws IOException {
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
XContentBuilder builderWithRemoteJob = MediaTypeRegistry.contentBuilder(XContentType.JSON);
output.toXContent(builder, ToXContent.EMPTY_PARAMS);
String jsonStr = builder.toString();
assertEquals(
Expand All @@ -53,6 +61,12 @@ public void toXContent() throws IOException {
+ "\"value\":2}]}]}}",
jsonStr
);
outputWithRemoteJob.toXContent(builderWithRemoteJob, ToXContent.EMPTY_PARAMS);
String jsonStr2 = builderWithRemoteJob.toString();
assertEquals(
"{\"task_id\":\"test_task_id\",\"status\":\"test_status\",\"remote_job\":{\"job_id\":\"testJobID\",\"status\":\"INPROGRESS\"}}",
jsonStr2
);
}

@Test
Expand Down

0 comments on commit 0bdc387

Please sign in to comment.