Skip to content

Commit

Permalink
feat(controller): add list api for run (#2815)
Browse files Browse the repository at this point in the history
  • Loading branch information
anda-ren authored Oct 8, 2023
1 parent 41e1947 commit 824ae66
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 7 deletions.
33 changes: 31 additions & 2 deletions client/starwhale/base/client/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,6 +1337,26 @@ class Config:
data: TaskVo


class Status1(Enum):
pending = 'PENDING'
running = 'RUNNING'
finished = 'FINISHED'
failed = 'FAILED'


class RunVo(BaseModel):
class Config:
allow_population_by_field_name = True

id: Optional[int] = None
task_id: Optional[int] = Field(None, alias='taskId')
status: Optional[Status1] = None
ip: Optional[str] = None
start_time: Optional[int] = Field(None, alias='startTime')
finish_time: Optional[int] = Field(None, alias='finishTime')
failed_reason: Optional[str] = Field(None, alias='failedReason')


class EventVo(BaseModel):
class Config:
allow_population_by_field_name = True
Expand Down Expand Up @@ -1537,7 +1557,7 @@ class Config:
data: PageInfoDatasetVersionVo


class Status1(Enum):
class Status2(Enum):
created = 'CREATED'
ready = 'READY'
assigning = 'ASSIGNING'
Expand Down Expand Up @@ -1566,7 +1586,7 @@ class Config:
project_id: str = Field(..., alias='projectId')
task_id: str = Field(..., alias='taskId')
dataset_name: str = Field(..., alias='datasetName')
status: Status1
status: Status2
type: Type4
create_time: int = Field(..., alias='createTime')

Expand Down Expand Up @@ -2014,6 +2034,15 @@ class Config:
data: PageInfoTaskVo


class ResponseMessageListRunVo(BaseModel):
class Config:
allow_population_by_field_name = True

code: str
message: str
data: List[RunVo]


class Graph(BaseModel):
class Config:
allow_population_by_field_name = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import ai.starwhale.mlops.api.protocol.job.ModelServingStatusVo;
import ai.starwhale.mlops.api.protocol.job.ModelServingVo;
import ai.starwhale.mlops.api.protocol.job.RuntimeSuggestionVo;
import ai.starwhale.mlops.api.protocol.run.RunVo;
import ai.starwhale.mlops.api.protocol.task.TaskVo;
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.common.InvokerManager;
Expand All @@ -41,6 +42,7 @@
import ai.starwhale.mlops.domain.job.JobServiceForWeb;
import ai.starwhale.mlops.domain.job.ModelServingService;
import ai.starwhale.mlops.domain.job.RuntimeSuggestionService;
import ai.starwhale.mlops.domain.run.RunService;
import ai.starwhale.mlops.domain.task.TaskService;
import ai.starwhale.mlops.exception.SwProcessException;
import ai.starwhale.mlops.exception.SwProcessException.ErrorType;
Expand Down Expand Up @@ -85,6 +87,8 @@ public class JobController {
private final FeaturesProperties featuresProperties;
private final EventService eventService;

private final RunService runService;

public JobController(
JobServiceForWeb jobServiceForWeb,
TaskService taskService,
Expand All @@ -93,7 +97,8 @@ public JobController(
IdConverter idConvertor,
DagQuerier dagQuerier,
FeaturesProperties featuresProperties,
EventService eventService
EventService eventService,
RunService runService
) {
this.jobServiceForWeb = jobServiceForWeb;
this.taskService = taskService;
Expand All @@ -103,6 +108,7 @@ public JobController(
this.dagQuerier = dagQuerier;
this.featuresProperties = featuresProperties;
this.eventService = eventService;
this.runService = runService;
var actions = InvokerManager.<String, String>create()
.addInvoker("cancel", jobServiceForWeb::cancelJob);
if (featuresProperties.isJobPauseEnabled()) {
Expand Down Expand Up @@ -172,6 +178,18 @@ public ResponseEntity<ResponseMessage<TaskVo>> getTask(
return ResponseEntity.ok(Code.success.asResponse(taskService.getTask(taskUrl)));
}

@Operation(summary = "Get runs info")
@GetMapping(value = "/project/{projectUrl}/job/{jobUrl}/task/{taskId}/run",
produces = MediaType.APPLICATION_JSON_VALUE)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER', 'GUEST')")
public ResponseEntity<ResponseMessage<List<RunVo>>> getRuns(
@PathVariable String projectUrl,
@PathVariable String jobUrl,
@PathVariable Long taskId
) {
return ResponseEntity.ok(Code.success.asResponse(runService.runOfTask(taskId)));
}

@Operation(summary = "Create a new job")
@PostMapping(value = "/project/{projectUrl}/job", produces = MediaType.APPLICATION_JSON_VALUE)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER')")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.api.protocol.run;

import ai.starwhale.mlops.domain.run.bo.Run;
import ai.starwhale.mlops.domain.run.bo.RunStatus;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;


@Data
@AllArgsConstructor
@NoArgsConstructor
public class RunVo {

private Long id;
private Long taskId;
private RunStatus status;
private String ip;
private Long startTime;
private Long finishTime;
private String failedReason;

public RunVo(Run run) {
if (null == run) {
return;
}
this.id = run.getId();
this.taskId = run.getTaskId();
this.status = run.getStatus();
this.ip = run.getIp();
this.startTime = run.getStartTime();
this.finishTime = run.getFinishTime();
this.failedReason = run.getFailedReason();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
import ai.starwhale.mlops.domain.run.mapper.RunMapper;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Date;
import java.util.List;
import java.util.stream.Collectors;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;

@Service
public class RunDao {
Expand All @@ -39,6 +42,14 @@ public Run findById(Long id) {
return convertEntityToBo(runEntity);
}

public List<Run> findByTaskId(Long taskId) {
List<RunEntity> list = runMapper.list(taskId);
if (CollectionUtils.isEmpty(list)) {
return List.of();
}
return list.stream().map(this::convertEntityToBo).collect(Collectors.toList());
}

public Run convertEntityToBo(RunEntity runEntity) {
Run run = Run.builder()
.id(runEntity.getId())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.domain.run;

import ai.starwhale.mlops.api.protocol.run.RunVo;
import java.util.List;
import java.util.stream.Collectors;
import org.springframework.stereotype.Service;

@Service
public class RunService {

private final RunDao runDao;

public RunService(RunDao runDao) {
this.runDao = runDao;
}

public List<RunVo> runOfTask(Long taskId) {
return runDao.findByTaskId(taskId).stream().map(r -> new RunVo(r)).collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import ai.starwhale.mlops.domain.job.JobServiceForWeb;
import ai.starwhale.mlops.domain.job.ModelServingService;
import ai.starwhale.mlops.domain.job.RuntimeSuggestionService;
import ai.starwhale.mlops.domain.run.RunService;
import ai.starwhale.mlops.domain.task.TaskService;
import ai.starwhale.mlops.exception.api.StarwhaleApiException;
import ai.starwhale.mlops.schedule.impl.k8s.ResourceEventHolder;
Expand Down Expand Up @@ -96,8 +97,8 @@ public void setUp() {
new IdConverter(),
dagQuerier,
featuresProperties,
mock(EventService.class)
);
mock(EventService.class),
mock(RunService.class));
}

@Test
Expand Down Expand Up @@ -307,7 +308,8 @@ public void testJobPauseDisabled() {
new IdConverter(),
dagQuerier,
featuresProperties,
mock(EventService.class)
mock(EventService.class),
mock(RunService.class)
);
assertThrows(StarwhaleApiException.class,
() -> controller.action("", "job1", "pause"));
Expand All @@ -328,7 +330,8 @@ public void testJobResumeDisabled() {
new IdConverter(),
dagQuerier,
featuresProperties,
mock(EventService.class)
mock(EventService.class),
mock(RunService.class)
);
assertThrows(StarwhaleApiException.class,
() -> controller.action("", "job1", "resume"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ai.starwhale.mlops.domain.run.bo.Run;
import ai.starwhale.mlops.domain.run.mapper.RunMapper;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.List;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand All @@ -45,6 +46,13 @@ void findById() {
assertEquals(1L, run.getId());
}

@Test
void findByTaskId() {
when(runMapper.list(1L)).thenReturn(List.of(RunEntity.builder().id(2L).build()));
List<Run> runs = runDao.findByTaskId(1L);
assertEquals(2L, runs.get(0).getId());
}

@Test
void convertEntityToBo() {
RunEntity runEntity = RunEntity.builder().id(1L).build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.domain.run;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import ai.starwhale.mlops.api.protocol.run.RunVo;
import ai.starwhale.mlops.domain.run.bo.Run;
import ai.starwhale.mlops.domain.run.bo.RunSpec;
import ai.starwhale.mlops.domain.run.bo.RunStatus;
import java.util.List;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

class RunServiceTest {

private RunDao runDao;

private RunService runService;

@BeforeEach
void init() {
runDao = mock(RunDao.class);
runService = new RunService(runDao);
when(runDao.findByTaskId(any())).thenReturn(List.of(
Run.builder()
.id(1L)
.status(RunStatus.RUNNING)
.logDir("anu")
.runSpec(RunSpec.builder().build())
.ip("ip")
.taskId(1L)
.startTime(122334455L)
.finishTime(122334456L)
.failedReason("fr")
.build()
));
}

@Test
void runOfTask() {
List<RunVo> runVos = runService.runOfTask(1L);
Assertions.assertEquals(1, runVos.size());
Assertions.assertEquals(1L, runVos.get(0).getId());
Assertions.assertEquals(RunStatus.RUNNING, runVos.get(0).getStatus());
Assertions.assertEquals("ip", runVos.get(0).getIp());
Assertions.assertEquals(122334455L, runVos.get(0).getStartTime());
Assertions.assertEquals(122334456L, runVos.get(0).getFinishTime());
Assertions.assertEquals("fr", runVos.get(0).getFailedReason());
}
}

0 comments on commit 824ae66

Please sign in to comment.