diff --git a/models/actions/task.go b/models/actions/task.go index b62a0c351b99b..af74faf937e5f 100644 --- a/models/actions/task.go +++ b/models/actions/task.go @@ -341,7 +341,7 @@ func UpdateTask(ctx context.Context, task *ActionTask, cols ...string) error { // UpdateTaskByState updates the task by the state. // It will always update the task if the state is not final, even there is no change. // So it will update ActionTask.Updated to avoid the task being judged as a zombie task. -func UpdateTaskByState(ctx context.Context, state *runnerv1.TaskState) (*ActionTask, error) { +func UpdateTaskByState(ctx context.Context, runnerID int64, state *runnerv1.TaskState) (*ActionTask, error) { stepStates := map[int64]*runnerv1.StepState{} for _, v := range state.Steps { stepStates[v.Id] = v @@ -360,6 +360,8 @@ func UpdateTaskByState(ctx context.Context, state *runnerv1.TaskState) (*ActionT return nil, err } else if !has { return nil, util.ErrNotExist + } else if runnerID != task.RunnerID { + return nil, fmt.Errorf("invalid runner for task") } if task.Status.IsDone() { diff --git a/routers/api/actions/runner/runner.go b/routers/api/actions/runner/runner.go index d4078d8af22cf..8f365cc92670a 100644 --- a/routers/api/actions/runner/runner.go +++ b/routers/api/actions/runner/runner.go @@ -175,7 +175,9 @@ func (s *Service) UpdateTask( ctx context.Context, req *connect.Request[runnerv1.UpdateTaskRequest], ) (*connect.Response[runnerv1.UpdateTaskResponse], error) { - task, err := actions_model.UpdateTaskByState(ctx, req.Msg.State) + runner := GetRunner(ctx) + + task, err := actions_model.UpdateTaskByState(ctx, runner.ID, req.Msg.State) if err != nil { return nil, status.Errorf(codes.Internal, "update task: %v", err) } @@ -237,11 +239,15 @@ func (s *Service) UpdateLog( ctx context.Context, req *connect.Request[runnerv1.UpdateLogRequest], ) (*connect.Response[runnerv1.UpdateLogResponse], error) { + runner := GetRunner(ctx) + res := connect.NewResponse(&runnerv1.UpdateLogResponse{}) task, err := actions_model.GetTaskByID(ctx, req.Msg.TaskId) if err != nil { return nil, status.Errorf(codes.Internal, "get task: %v", err) + } else if runner.ID != task.RunnerID { + return nil, status.Errorf(codes.Internal, "invalid runner for task") } ack := task.LogLength