Skip to content

Commit

Permalink
return model id in registering remote model
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Sep 13, 2023
1 parent f5c20fb commit d21b032
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,46 @@
@Getter
public class MLRegisterModelResponse extends ActionResponse implements ToXContentObject {
public static final String TASK_ID_FIELD = "task_id";
public static final String MODEL_ID_FIELD = "model_id";
public static final String STATUS_FIELD = "status";

private String taskId;
private String status;
private String modelId;

public MLRegisterModelResponse(StreamInput in) throws IOException {
super(in);
this.taskId = in.readString();
this.status = in.readString();
this.modelId = in.readOptionalString();
}

public MLRegisterModelResponse(String taskId, String status) {
this.taskId = taskId;
this.status= status;
}

public MLRegisterModelResponse(String taskId, String status, String modelId) {
this.taskId = taskId;
this.status= status;
this.modelId = modelId;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(taskId);
out.writeString(status);
out.writeOptionalString(modelId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(TASK_ID_FIELD, taskId);
builder.field(STATUS_FIELD, status);
if (modelId != null) {
builder.field(MODEL_ID_FIELD, modelId);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,27 @@ public class MLRegisterModelResponseTest {

private String taskId;
private String status;
private String modelId;

@Before
public void setUp() throws Exception {
taskId = "test_id";
status = "test";
modelId = "model_id";
}

@Test
public void writeTo_Success() throws IOException {
// Setup
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
MLRegisterModelResponse response = new MLRegisterModelResponse(taskId, status);
MLRegisterModelResponse response = new MLRegisterModelResponse(taskId, status, modelId);
// Run the test
response.writeTo(bytesStreamOutput);
MLRegisterModelResponse parsedResponse = new MLRegisterModelResponse(bytesStreamOutput.bytes().streamInput());
// Verify the results
assertEquals(response.getTaskId(), parsedResponse.getTaskId());
assertEquals(response.getStatus(), parsedResponse.getStatus());
assertEquals(response.getModelId(), parsedResponse.getModelId());
}

@Test
Expand All @@ -49,4 +52,18 @@ public void testToXContent() throws IOException {
assertEquals("{\"task_id\":\"test_id\"," +
"\"status\":\"test\"}", jsonStr);
}

@Test
public void testToXContent_withModelId() throws IOException {
// Setup
MLRegisterModelResponse response = new MLRegisterModelResponse(taskId, status, modelId);
// Run the test
XContentBuilder builder = XContentFactory.jsonBuilder();
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
// Verify the results
assertEquals("{\"task_id\":\"test_id\"," +
"\"status\":\"test\"," + "\"model_id\":\"model_id\"}", jsonStr);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
throw new IllegalArgumentException("URL can't match trusted url regex");
}
}
System.out.println("registering the model");
boolean isAsync = registerModelInput.getFunctionName() != FunctionName.REMOTE;
MLTask mlTask = MLTask
.builder()
Expand All @@ -249,8 +250,8 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
mlTask.setTaskId(taskId);
mlModelManager.registerMLModel(registerModelInput, mlTask);
listener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name()));
System.out.println("mlModelManager calls registerMLRemoteModel");
mlModelManager.registerMLRemoteModel(registerModelInput, mlTask, listener);
}, e -> {
logException("Failed to register model", e, log);
listener.onFailure(e);
Expand Down
Loading

0 comments on commit d21b032

Please sign in to comment.