Skip to content

Commit

Permalink
read function Name from pretrained model
Browse files Browse the repository at this point in the history
Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Oct 18, 2023
1 parent e9e3834 commit e0b2487
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput regi
.url(modelZipFileUrl)
.deployModel(deployModel)
.modelNodeIds(modelNodeIds)
.modelGroupId(modelGroupId);
.modelGroupId(modelGroupId)
.functionName(FunctionName.from((String) config.get("model_task_type")));;
config.entrySet().forEach(entry -> {
switch (entry.getKey().toString()) {
case MLRegisterModelInput.MODEL_FORMAT_FIELD:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ private void uploadModel(MLRegisterModelInput registerModelInput, MLTask mlTask,

private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) {
String taskId = mlTask.getTaskId();
FunctionName functionName = mlTask.getFunctionName();
FunctionName functionName = registerModelInput.getFunctionName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
String modelName = registerModelInput.getModelName();
String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion;
Expand Down

0 comments on commit e0b2487

Please sign in to comment.