Skip to content

Commit

Permalink
Execution event cluster validation (flyteorg#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrina Rogan authored Jan 5, 2022
1 parent 59d33f0 commit 81ac07f
Show file tree
Hide file tree
Showing 19 changed files with 397 additions and 107 deletions.
4 changes: 4 additions & 0 deletions pkg/common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ const (
)

const MaxResponseStatusBytes = 32000

// DefaultProducerID is used in older versions of propeller which hard code this producer id.
// See https://github.com/flyteorg/flytepropeller/blob/eaf084934de5d630cd4c11aae15ecae780cc787e/pkg/controller/nodes/task/transformer.go#L114
const DefaultProducerID = "propeller"
15 changes: 15 additions & 0 deletions pkg/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,18 @@ func NewAlreadyInTerminalStateError(ctx context.Context, errorMsg string, curPha
}
return statusErr
}

func NewIncompatibleClusterError(ctx context.Context, errorMsg, curCluster string) FlyteAdminError {
statusErr, transformationErr := NewFlyteAdminError(codes.FailedPrecondition, errorMsg).WithDetails(&admin.EventFailureReason{
Reason: &admin.EventFailureReason_IncompatibleCluster{
IncompatibleCluster: &admin.EventErrorIncompatibleCluster{
Cluster: curCluster,
},
},
})
if transformationErr != nil {
logger.Panicf(ctx, "Failed to wrap grpc status in type 'Error': %v", transformationErr)
return NewFlyteAdminErrorf(codes.FailedPrecondition, errorMsg)
}
return statusErr
}
16 changes: 16 additions & 0 deletions pkg/errors/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,19 @@ func TestGrpcStatusError(t *testing.T) {
_, ok = details.GetReason().(*admin.EventFailureReason_AlreadyInTerminalState)
assert.True(t, ok)
}

func TestNewIncompatibleClusterError(t *testing.T) {
errorMsg := "foo"
cluster := "C1"
statusErr := NewIncompatibleClusterError(context.Background(), errorMsg, cluster)
assert.NotNil(t, statusErr)
s, ok := status.FromError(statusErr)
assert.True(t, ok)
assert.Equal(t, codes.FailedPrecondition, s.Code())
assert.Equal(t, errorMsg, s.Message())

details, ok := s.Details()[0].(*admin.EventFailureReason)
assert.True(t, ok)
_, ok = details.GetReason().(*admin.EventFailureReason_IncompatibleCluster)
assert.True(t, ok)
}
8 changes: 7 additions & 1 deletion pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,8 @@ func (m *ExecutionManager) CreateWorkflowEvent(ctx context.Context, request admi
}

wfExecPhase := core.WorkflowExecution_Phase(core.WorkflowExecution_Phase_value[executionModel.Phase])
if wfExecPhase == request.Event.Phase {
// Subsequent queued events announcing a cluster reassignment are permitted.
if wfExecPhase == request.Event.Phase && request.Event.Phase != core.WorkflowExecution_QUEUED {
logger.Debugf(ctx, "This phase %s was already recorded for workflow execution %v",
wfExecPhase.String(), request.Event.ExecutionId)
return nil, errors.NewFlyteAdminErrorf(codes.AlreadyExists,
Expand All @@ -1188,6 +1189,11 @@ func (m *ExecutionManager) CreateWorkflowEvent(ctx context.Context, request admi
curPhase := wfExecPhase.String()
errorMsg := fmt.Sprintf("Invalid phase change from %s to %s for workflow execution %v", curPhase, request.Event.Phase.String(), request.Event.ExecutionId)
return nil, errors.NewAlreadyInTerminalStateError(ctx, errorMsg, curPhase)
} else if wfExecPhase == core.WorkflowExecution_RUNNING && request.Event.Phase == core.WorkflowExecution_QUEUED {
// Cannot go back in time from RUNNING -> QUEUED
return nil, errors.NewFlyteAdminErrorf(codes.FailedPrecondition,
"Cannot go from %s to %s for workflow execution %v",
wfExecPhase.String(), request.Event.Phase.String(), request.Event.ExecutionId)
}

err = transformers.UpdateExecutionModelState(ctx, executionModel, request, m.config.ApplicationConfiguration().GetRemoteDataConfig().InlineEventDataPolicy, m.storageClient)
Expand Down
37 changes: 37 additions & 0 deletions pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1340,6 +1340,7 @@ func TestCreateWorkflowEvent(t *testing.T) {
OutputResult: &event.WorkflowExecutionEvent_Error{
Error: &executionError,
},
ProducerId: testCluster,
},
}
mockDbEventWriter := &eventWriterMocks.WorkflowExecutionEventWriter{}
Expand Down Expand Up @@ -1391,6 +1392,40 @@ func TestCreateWorkflowEvent_TerminalState(t *testing.T) {
assert.True(t, ok)
}

func TestCreateWorkflowEvent_NoRunningToQueued(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
executionGetFunc := func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{
ExecutionKey: models.ExecutionKey{
Project: "project",
Domain: "domain",
Name: "name",
},
Spec: specBytes,
Phase: core.WorkflowExecution_RUNNING.String(),
}, nil
}

repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc)
updateExecutionFunc := func(context context.Context, execution models.Execution) error {
return nil
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetUpdateCallback(updateExecutionFunc)
execManager := NewExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})

resp, err := execManager.CreateWorkflowEvent(context.Background(), admin.WorkflowExecutionEventRequest{
RequestId: "1",
Event: &event.WorkflowExecutionEvent{
ExecutionId: &executionIdentifier,
Phase: core.WorkflowExecution_QUEUED,
},
})
assert.Nil(t, resp)
assert.NotNil(t, err)
adminError := err.(flyteAdminErrors.FlyteAdminError)
assert.Equal(t, adminError.Code(), codes.FailedPrecondition)
}

func TestCreateWorkflowEvent_StartedRunning(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
occurredAt := time.Now().UTC()
Expand Down Expand Up @@ -1427,6 +1462,7 @@ func TestCreateWorkflowEvent_StartedRunning(t *testing.T) {
ExecutionId: &executionIdentifier,
OccurredAt: occurredAtTimestamp,
Phase: core.WorkflowExecution_RUNNING,
ProducerId: testCluster,
},
}
mockDbEventWriter := &eventWriterMocks.WorkflowExecutionEventWriter{}
Expand Down Expand Up @@ -1641,6 +1677,7 @@ func TestCreateWorkflowEvent_DatabaseUpdateError(t *testing.T) {
OutputResult: &event.WorkflowExecutionEvent_Error{
Error: &executionError,
},
ProducerId: testCluster,
},
})
assert.Nil(t, resp)
Expand Down
41 changes: 22 additions & 19 deletions pkg/manager/impl/node_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,6 @@ func getNodeExecutionContext(ctx context.Context, identifier *core.NodeExecution

func (m *NodeExecutionManager) createNodeExecutionWithEvent(
ctx context.Context, request *admin.NodeExecutionEventRequest, dynamicWorkflowRemoteClosureReference string) error {

executionID := request.Event.Id.ExecutionId
workflowExecutionExists, err := m.db.ExecutionRepo().Exists(ctx, repoInterfaces.Identifier{
Project: executionID.Project,
Domain: executionID.Domain,
Name: executionID.Name,
})
if err != nil || !workflowExecutionExists {
m.metrics.MissingWorkflowExecution.Inc()
logger.Debugf(ctx, "Failed to find existing execution with id [%+v] with err: %v", executionID, err)
if err != nil {
if ferr, ok := err.(errors.FlyteAdminError); ok {
return errors.NewFlyteAdminErrorf(ferr.Code(),
"Failed to get existing execution id: [%+v] with err: %v", executionID, err)
}
}
return fmt.Errorf("failed to get existing execution id: [%+v]", executionID)
}

var parentTaskExecutionID *uint
if request.Event.ParentTaskMetadata != nil {
taskExecutionModel, err := util.GetTaskExecutionModel(ctx, m.db, request.Event.ParentTaskMetadata.Id)
Expand Down Expand Up @@ -233,6 +214,28 @@ func (m *NodeExecutionManager) CreateNodeEvent(ctx context.Context, request admi
logger.Debugf(ctx, "Received node execution event for Node Exec Id [%+v] transitioning to phase [%v], w/ Metadata [%v]",
request.Event.Id, request.Event.Phase, request.Event.ParentTaskMetadata)

executionID := request.Event.Id.ExecutionId
workflowExecution, err := m.db.ExecutionRepo().Get(ctx, repoInterfaces.Identifier{
Project: executionID.Project,
Domain: executionID.Domain,
Name: executionID.Name,
})
if err != nil {
m.metrics.MissingWorkflowExecution.Inc()
logger.Debugf(ctx, "Failed to find existing execution with id [%+v] with err: %v", executionID, err)
if err != nil {
if ferr, ok := err.(errors.FlyteAdminError); ok {
return nil, errors.NewFlyteAdminErrorf(ferr.Code(),
"Failed to get existing execution id: [%+v] with err: %v", executionID, err)
}
}
return nil, fmt.Errorf("failed to get existing execution id: [%+v]", executionID)
}

if err := validation.ValidateCluster(ctx, workflowExecution.Cluster, request.Event.ProducerId); err != nil {
return nil, err
}

var dynamicWorkflowRemoteClosureReference string
if request.Event.GetTaskNodeMetadata() != nil && request.Event.GetTaskNodeMetadata().DynamicWorkflow != nil {
dynamicWorkflowRemoteClosureDataReference, err := m.uploadDynamicWorkflowClosure(
Expand Down
28 changes: 11 additions & 17 deletions pkg/manager/impl/node_execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ func addGetExecutionCallback(t *testing.T, repository repositories.RepositoryInt
Domain: "domain",
Name: "name",
},
Cluster: "propeller",
}, nil
})
}
Expand Down Expand Up @@ -225,26 +226,19 @@ func TestCreateNodeEvent_MissingExecution(t *testing.T) {
expectedErr := flyteAdminErrors.NewFlyteAdminErrorf(codes.Internal, "expected error")
repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetCallback(
func(ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) {
return models.NodeExecution{}, flyteAdminErrors.NewFlyteAdminError(codes.NotFound, "foo")
return models.NodeExecution{}, expectedErr
})
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).ExistsFunction =
func(ctx context.Context, input interfaces.Identifier) (bool, error) {
return false, expectedErr
}
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(
func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{}, expectedErr
})

repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{}, expectedErr
})
nodeExecManager := NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, &mockPublisher, &eventWriterMocks.NodeExecutionEventWriter{})
resp, err := nodeExecManager.CreateNodeEvent(context.Background(), request)
assert.EqualError(t, err, "Failed to get existing execution id: [project:\"project\""+
" domain:\"domain\" name:\"name\" ] with err: expected error")
assert.Nil(t, resp)

repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).ExistsFunction =
func(ctx context.Context, input interfaces.Identifier) (bool, error) {
return false, nil
}
nodeExecManager = NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, &mockPublisher, &eventWriterMocks.NodeExecutionEventWriter{})
resp, err = nodeExecManager.CreateNodeEvent(context.Background(), request)
assert.EqualError(t, err, "failed to get existing execution id: [project:\"project\""+
" domain:\"domain\" name:\"name\" ]")
assert.EqualError(t, err, "Failed to get existing execution id: [project:\"project\" domain:\"domain\" name:\"name\" ] with err: expected error")
assert.Nil(t, resp)
}

Expand Down
4 changes: 4 additions & 0 deletions pkg/manager/impl/task_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ func (m *TaskExecutionManager) CreateTaskExecutionEvent(ctx context.Context, req
return nil, err
}

if err := validation.ValidateClusterForExecutionID(ctx, m.db, request.Event.ParentNodeExecutionId.ExecutionId, request.Event.ProducerId); err != nil {
return nil, err
}

// Get the parent node execution, if none found a MissingEntityError will be returned
nodeExecutionID := request.Event.ParentNodeExecutionId
taskExecutionID := core.TaskExecutionIdentifier{
Expand Down
5 changes: 5 additions & 0 deletions pkg/manager/impl/task_execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ func addGetWorkflowExecutionCallback(repository repositories.RepositoryInterface
Domain: sampleNodeExecID.ExecutionId.Domain,
Name: sampleNodeExecID.ExecutionId.Name,
},
Cluster: "propeller",
}, nil
},
)
Expand Down Expand Up @@ -299,6 +300,7 @@ func TestCreateTaskEvent_Update(t *testing.T) {
func TestCreateTaskEvent_MissingExecution(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
expectedErr := flyteAdminErrors.NewFlyteAdminErrorf(codes.Internal, "expected error")
addGetWorkflowExecutionCallback(repository)
repository.TaskExecutionRepo().(*repositoryMocks.MockTaskExecutionRepo).SetGetCallback(
func(ctx context.Context, input interfaces.GetTaskExecutionInput) (models.TaskExecution, error) {
return models.TaskExecution{}, flyteAdminErrors.NewFlyteAdminError(codes.NotFound, "foo")
Expand Down Expand Up @@ -327,6 +329,7 @@ func TestCreateTaskEvent_MissingExecution(t *testing.T) {

func TestCreateTaskEvent_CreateDatabaseError(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
addGetWorkflowExecutionCallback(repository)
repository.TaskExecutionRepo().(*repositoryMocks.MockTaskExecutionRepo).SetGetCallback(
func(ctx context.Context, input interfaces.GetTaskExecutionInput) (models.TaskExecution, error) {
return models.TaskExecution{}, flyteAdminErrors.NewFlyteAdminError(codes.NotFound, "foo")
Expand All @@ -345,6 +348,7 @@ func TestCreateTaskEvent_CreateDatabaseError(t *testing.T) {

func TestCreateTaskEvent_UpdateDatabaseError(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
addGetWorkflowExecutionCallback(repository)
addGetNodeExecutionCallback(repository)

repository.TaskExecutionRepo().(*repositoryMocks.MockTaskExecutionRepo).SetGetCallback(
Expand Down Expand Up @@ -385,6 +389,7 @@ func TestCreateTaskEvent_UpdateDatabaseError(t *testing.T) {

func TestCreateTaskEvent_UpdateTerminalEventError(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
addGetWorkflowExecutionCallback(repository)
repository.TaskExecutionRepo().(*repositoryMocks.MockTaskExecutionRepo).SetGetCallback(
func(ctx context.Context, input interfaces.GetTaskExecutionInput) (models.TaskExecution, error) {
return models.TaskExecution{
Expand Down
42 changes: 42 additions & 0 deletions pkg/manager/impl/validation/shared_execution.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package validation

import (
"context"
"fmt"

"github.com/flyteorg/flyteadmin/pkg/common"
"github.com/flyteorg/flyteadmin/pkg/errors"
"github.com/flyteorg/flyteadmin/pkg/repositories"
repoInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytestdlib/logger"
)

// ValidateClusterForExecutionID validates that the execution denoted by executionId is recorded as executing on `cluster`.
func ValidateClusterForExecutionID(ctx context.Context, db repositories.RepositoryInterface, executionID *core.WorkflowExecutionIdentifier, cluster string) error {
workflowExecution, err := db.ExecutionRepo().Get(ctx, repoInterfaces.Identifier{
Project: executionID.Project,
Domain: executionID.Domain,
Name: executionID.Name,
})
if err != nil {
logger.Debugf(ctx, "Failed to find existing execution with id [%+v] with err: %v", executionID, err)
return err
}
return ValidateCluster(ctx, workflowExecution.Cluster, cluster)
}

// ValidateClusterForExecution validates that the execution is recorded as executing on `cluster`.
func ValidateCluster(ctx context.Context, recordedCluster, cluster string) error {
// DefaultProducerID is used in older versions of propeller which hard code this producer id.
// See https://github.com/flyteorg/flytepropeller/blob/eaf084934de5d630cd4c11aae15ecae780cc787e/pkg/controller/nodes/task/transformer.go#L114
if len(cluster) == 0 || cluster == common.DefaultProducerID {
return nil
}
if recordedCluster != cluster {
errorMsg := fmt.Sprintf("Cluster/producer from event [%s] does not match existing workflow execution cluster: [%s]",
recordedCluster, cluster)
return errors.NewIncompatibleClusterError(ctx, errorMsg, recordedCluster)
}
return nil
}
55 changes: 55 additions & 0 deletions pkg/manager/impl/validation/shared_execution_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package validation

import (
"context"
"testing"

"github.com/flyteorg/flyteadmin/pkg/common"
"github.com/flyteorg/flyteadmin/pkg/errors"
"github.com/flyteorg/flyteadmin/pkg/repositories/interfaces"
repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks"
"github.com/flyteorg/flyteadmin/pkg/repositories/models"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
)

var testCluster = "C1"

var testExecID = &core.WorkflowExecutionIdentifier{
Project: "p",
Domain: "d",
Name: "n",
}

func TestValidateClusterForExecutionID(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{
Cluster: testCluster,
}, nil
})
assert.NoError(t, ValidateClusterForExecutionID(context.TODO(), repository, testExecID, testCluster))
assert.NoError(t, ValidateClusterForExecutionID(context.TODO(), repository, testExecID, common.DefaultProducerID))
}

func TestValidateCluster_Nonmatching(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{
Cluster: "C2",
}, nil
})
err := ValidateClusterForExecutionID(context.TODO(), repository, testExecID, testCluster)
assert.Equal(t, codes.FailedPrecondition, err.(errors.FlyteAdminError).Code())
}

func TestValidateCluster_NoExecution(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
expectedErr := errors.NewFlyteAdminError(codes.Internal, "foo")
repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) {
return models.Execution{}, expectedErr
})
err := ValidateClusterForExecutionID(context.TODO(), repository, testExecID, testCluster)
assert.Equal(t, expectedErr, err)
}
Loading

0 comments on commit 81ac07f

Please sign in to comment.