Skip to content

Commit

Permalink
feat(controller): add run info into taskVo (#2853)
Browse files Browse the repository at this point in the history
  • Loading branch information
anda-ren authored Oct 16, 2023
1 parent 85e592b commit 62e4890
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 28 deletions.
49 changes: 25 additions & 24 deletions client/starwhale/base/client/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,6 +1341,26 @@ class Config:
data: JobVo


class Status1(Enum):
pending = 'PENDING'
running = 'RUNNING'
finished = 'FINISHED'
failed = 'FAILED'


class RunVo(BaseModel):
class Config:
allow_population_by_field_name = True

id: Optional[int] = None
task_id: Optional[int] = Field(None, alias='taskId')
status: Optional[Status1] = None
ip: Optional[str] = None
start_time: Optional[int] = Field(None, alias='startTime')
finish_time: Optional[int] = Field(None, alias='finishTime')
failed_reason: Optional[str] = Field(None, alias='failedReason')


class TaskStatus(Enum):
created = 'CREATED'
ready = 'READY'
Expand Down Expand Up @@ -1370,6 +1390,7 @@ class Config:
step_name: str = Field(..., alias='stepName')
exposed_links: Optional[List[ExposedLinkVo]] = Field(None, alias='exposedLinks')
failed_reason: Optional[str] = Field(None, alias='failedReason')
runs: Optional[List[RunVo]] = None


class ResponseMessageTaskVo(BaseModel):
Expand All @@ -1381,24 +1402,13 @@ class Config:
data: TaskVo


class Status1(Enum):
pending = 'PENDING'
running = 'RUNNING'
finished = 'FINISHED'
failed = 'FAILED'


class RunVo(BaseModel):
class ResponseMessageListRunVo(BaseModel):
class Config:
allow_population_by_field_name = True

id: Optional[int] = None
task_id: Optional[int] = Field(None, alias='taskId')
status: Optional[Status1] = None
ip: Optional[str] = None
start_time: Optional[int] = Field(None, alias='startTime')
finish_time: Optional[int] = Field(None, alias='finishTime')
failed_reason: Optional[str] = Field(None, alias='failedReason')
code: str
message: str
data: List[RunVo]


class EventVo(BaseModel):
Expand Down Expand Up @@ -2078,15 +2088,6 @@ class Config:
data: PageInfoTaskVo


class ResponseMessageListRunVo(BaseModel):
class Config:
allow_population_by_field_name = True

code: str
message: str
data: List[RunVo]


class Graph(BaseModel):
class Config:
allow_population_by_field_name = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package ai.starwhale.mlops.api.protocol.task;

import ai.starwhale.mlops.api.protocol.job.ExposedLinkVo;
import ai.starwhale.mlops.api.protocol.run.RunVo;
import ai.starwhale.mlops.domain.task.status.TaskStatus;
import io.swagger.v3.oas.annotations.media.Schema;
import java.io.Serializable;
Expand Down Expand Up @@ -58,4 +59,6 @@ public class TaskVo implements Serializable {
private List<ExposedLinkVo> exposedLinks;

private String failedReason;

private List<RunVo> runs;
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import ai.starwhale.mlops.domain.job.step.ExposedType;
import ai.starwhale.mlops.domain.job.step.mapper.StepMapper;
import ai.starwhale.mlops.domain.job.step.po.StepEntity;
import ai.starwhale.mlops.domain.run.RunService;
import ai.starwhale.mlops.domain.system.resourcepool.bo.ResourcePool;
import ai.starwhale.mlops.domain.task.po.TaskEntity;
import ai.starwhale.mlops.domain.task.status.TaskStatus;
Expand All @@ -45,18 +46,22 @@ public class TaskConverter {
private final WebServerInTask webServerInTask;
private final JobSpecParser jobSpecParser;

private final RunService runService;


public TaskConverter(
IdConverter idConvertor, StepMapper stepMapper,
@Value("${sw.task.dev-port}") int devPort,
WebServerInTask webServerInTask,
JobSpecParser jobSpecParser
JobSpecParser jobSpecParser,
RunService runService
) {
this.idConvertor = idConvertor;
this.stepMapper = stepMapper;
this.devPort = devPort;
this.webServerInTask = webServerInTask;
this.jobSpecParser = jobSpecParser;
this.runService = runService;
}

public TaskVo convert(TaskEntity entity) {
Expand Down Expand Up @@ -110,6 +115,7 @@ public TaskVo convert(TaskEntity entity) {
.startedTime(entity.getStartedTime() == null ? null : entity.getStartedTime().getTime())
.resourcePool(pool)
.failedReason(entity.getFailedReason())
.runs(runService.runOfTask(entity.getId()))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import ai.starwhale.mlops.api.protocol.job.ExposedLinkVo;
import ai.starwhale.mlops.api.protocol.run.RunVo;
import ai.starwhale.mlops.api.protocol.task.TaskVo;
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.common.proxy.WebServerInTask;
Expand All @@ -31,11 +33,13 @@
import ai.starwhale.mlops.domain.job.step.ExposedType;
import ai.starwhale.mlops.domain.job.step.mapper.StepMapper;
import ai.starwhale.mlops.domain.job.step.po.StepEntity;
import ai.starwhale.mlops.domain.run.RunService;
import ai.starwhale.mlops.domain.task.converter.TaskConverter;
import ai.starwhale.mlops.domain.task.po.TaskEntity;
import ai.starwhale.mlops.domain.task.status.TaskStatus;
import ai.starwhale.mlops.exception.SwProcessException;
import java.util.Date;
import java.util.List;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -47,6 +51,8 @@ public class TaskConverterTest {
private WebServerInTask webServerInTask;
private JobSpecParser jobSpecParser;

private RunService runService;

@BeforeEach
public void setup() {
stepMapper = mock(StepMapper.class);
Expand All @@ -57,7 +63,15 @@ public void setup() {
}
});
webServerInTask = mock(WebServerInTask.class);
taskConvertor = new TaskConverter(new IdConverter(), stepMapper, 8000, webServerInTask, jobSpecParser);
runService = mock(RunService.class);
when(runService.runOfTask(any())).thenReturn(List.of(new RunVo() {
{
setId(3L);
}
}));
taskConvertor = new TaskConverter(
new IdConverter(), stepMapper, 8000, webServerInTask, jobSpecParser, runService
);
}


Expand Down Expand Up @@ -105,6 +119,7 @@ public void testValidData() {
Assertions.assertEquals(taskVo.getStartedTime(), taskEntity.getStartedTime().getTime());
Assertions.assertEquals(taskVo.getStepName(), "ppl");
Assertions.assertEquals(taskVo.getRetryNum(), taskEntity.getRetryNum());
Assertions.assertEquals(3L, taskVo.getRuns().get(0).getId());

var expectedExposedLink = ExposedLinkVo.builder()
.type(ExposedType.DEV_MODE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import ai.starwhale.mlops.domain.job.spec.JobSpecParser;
import ai.starwhale.mlops.domain.job.step.mapper.StepMapper;
import ai.starwhale.mlops.domain.job.step.po.StepEntity;
import ai.starwhale.mlops.domain.run.RunService;
import ai.starwhale.mlops.domain.system.resourcepool.bo.ResourcePool;
import ai.starwhale.mlops.domain.task.converter.TaskConverter;
import ai.starwhale.mlops.domain.task.mapper.TaskMapper;
Expand Down Expand Up @@ -68,8 +69,8 @@ public void setup() {
stepMapper,
8000,
mock(WebServerInTask.class),
mock(JobSpecParser.class)
);
mock(JobSpecParser.class),
mock(RunService.class));
taskMapper = mock(TaskMapper.class);
storageAccessService = mock(StorageAccessService.class);
jobDao = mock(JobDao.class);
Expand Down Expand Up @@ -99,9 +100,11 @@ public void testListTaskWithJobResourcePool() {
assertThat(taskVoPageInfo.getList(), containsInAnyOrder(
TaskVo.builder().id("1").startedTime(startedTime.getTime()).finishedTime(finishedTime.getTime())
.uuid("uuid1")
.runs(List.of())
.taskStatus(TaskStatus.RUNNING).resourcePool("a").stepName("ppl").build(),
TaskVo.builder().id("2").startedTime(startedTime.getTime()).finishedTime(finishedTime.getTime())
.uuid("uuid2")
.runs(List.of())
.taskStatus(TaskStatus.SUCCESS).resourcePool("a").stepName("ppl").build()));
}

Expand Down Expand Up @@ -136,9 +139,11 @@ public void testListTaskWithStepResourcePool() throws IOException {
assertThat(taskVoPageInfo.getList(), containsInAnyOrder(
TaskVo.builder().id("1").startedTime(startedTime.getTime()).finishedTime(finishedTime.getTime())
.uuid("uuid1")
.runs(List.of())
.taskStatus(TaskStatus.RUNNING).resourcePool("job from step").stepName("ppl").build(),
TaskVo.builder().id("2").startedTime(startedTime.getTime()).finishedTime(finishedTime.getTime())
.uuid("uuid2")
.runs(List.of())
.taskStatus(TaskStatus.SUCCESS).resourcePool("job from step").stepName("ppl").build()));
}

Expand All @@ -162,6 +167,7 @@ public void testGetTask() {
.finishedTime(finishedTime.getTime())
.taskStatus(TaskStatus.RUNNING)
.uuid("uuid1")
.runs(List.of())
.stepName("ppl")
.resourcePool("")
.build();
Expand Down

0 comments on commit 62e4890

Please sign in to comment.