Skip to content

Commit

Permalink
Implement UpdateWorkflow (flyteorg#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Apr 2, 2020
1 parent 092d7f8 commit 21e9602
Show file tree
Hide file tree
Showing 16 changed files with 238 additions and 2 deletions.
2 changes: 1 addition & 1 deletion flyteadmin/flyteadmin_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ queues:
- dynamic: "default"
attributes:
- defaultclusters
workflowConfigs
workflowConfigs:
- project: "my_queue_1"
domain: "production"
workflowName: "my_workflow_1"
Expand Down
2 changes: 1 addition & 1 deletion flyteadmin/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ require (
github.com/jmespath/go-jmespath v0.3.0 // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/lib/pq v1.3.0
github.com/lyft/flyteidl v0.17.20
github.com/lyft/flyteidl v0.17.23
github.com/lyft/flytepropeller v0.2.13
github.com/lyft/flytestdlib v0.3.2
github.com/magiconair/properties v1.8.1
Expand Down
3 changes: 3 additions & 0 deletions flyteadmin/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,9 @@ github.com/lyft/flyteidl v0.17.8 h1:/bZS1K3FO45EMamNrs4Eo6WYQf1TO5bNyNTIUO6cXM0=
github.com/lyft/flyteidl v0.17.8/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteidl v0.17.20 h1:SYhu5BRyc81fQQeCvn1pt8Nhd2BBM7JOmDnvUMwGHj4=
github.com/lyft/flyteidl v0.17.20/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteidl v0.17.23-0.20200401223233-5fcbfe070fad h1:T5lJx1on3Qy981L19GdMwHHiAuMCWXLxHQCFk+JW4B0=
github.com/lyft/flyteidl v0.17.23-0.20200401223233-5fcbfe070fad/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteidl v0.17.23/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteplugins v0.3.10/go.mod h1:FOSo04q4EheU6lm0oZFvfYAWgjrum/BDUK+mUT7qDFA=
github.com/lyft/flyteplugins v0.3.11/go.mod h1:FOSo04q4EheU6lm0oZFvfYAWgjrum/BDUK+mUT7qDFA=
github.com/lyft/flytepropeller v0.1.30 h1:g55bD3aMMba4WDiBE7SLFEElutPdkEtoFQkgN59OX+M=
Expand Down
20 changes: 20 additions & 0 deletions flyteadmin/pkg/manager/impl/workflow_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,26 @@ func (w *WorkflowManager) ListWorkflowIdentifiers(ctx context.Context, request a

}

func (w *WorkflowManager) UpdateWorkflow(ctx context.Context, request admin.WorkflowUpdateRequest) (
*admin.WorkflowUpdateResponse, error) {
if err := validation.ValidateIdentifier(request.Id, common.Workflow); err != nil {
logger.Debugf(ctx, "invalid identifier [%+v]: %v", request.Id, err)
return nil, err
}
ctx = getWorkflowContext(ctx, request.Id)
workflowModel, err := util.GetWorkflowModel(ctx, w.db, *request.Id)
if err != nil {
return nil, err
}
stateInt := int32(request.State)
workflowModel.State = &stateInt
err = w.db.WorkflowRepo().Update(ctx, workflowModel)
if err != nil {
return nil, err
}
return &admin.WorkflowUpdateResponse{}, nil
}

func NewWorkflowManager(
db repositories.RepositoryInterface,
config runtimeInterfaces.Configuration,
Expand Down
37 changes: 37 additions & 0 deletions flyteadmin/pkg/manager/impl/workflow_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,40 @@ func TestWorkflowManager_ListWorkflowIdentifiers(t *testing.T) {
assert.Equal(t, nameValue, entity.Name)
}
}

func TestUpdateWorkflow(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
workflowGetFunc := func(input interfaces.GetResourceInput) (models.Workflow, error) {
return models.Workflow{
BaseModel: models.BaseModel{
CreatedAt: testutils.MockCreatedAtValue,
},
WorkflowKey: models.WorkflowKey{
Project: input.Project,
Domain: input.Domain,
Name: input.Name,
Version: input.Version,
},
TypedInterface: testutils.GetWorkflowRequestInterfaceBytes(),
RemoteClosureIdentifier: remoteClosureIdentifier,
}, nil
}
repository.WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetGetCallback(workflowGetFunc)

updateFuncCalled := false
workflowUpdatefunc := func(input models.Workflow) error {
updateFuncCalled = true
assert.Equal(t, admin.WorkflowState_WORKFLOW_ARCHIVED, admin.WorkflowState(*input.State))
return nil
}
repository.WorkflowRepo().(*repositoryMocks.MockWorkflowRepo).SetUpdateCallback(workflowUpdatefunc)
workflowManager := NewWorkflowManager(
repository, getMockWorkflowConfigProvider(), getMockWorkflowCompiler(), commonMocks.GetMockStorageClient(), storagePrefix,
mockScope.NewTestScope())
_, err := workflowManager.UpdateWorkflow(context.Background(), admin.WorkflowUpdateRequest{
Id: &workflowIdentifier,
State: admin.WorkflowState_WORKFLOW_ARCHIVED,
})
assert.NoError(t, err)
assert.True(t, updateFuncCalled)
}
1 change: 1 addition & 0 deletions flyteadmin/pkg/manager/interfaces/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ type WorkflowInterface interface {
ListWorkflows(ctx context.Context, request admin.ResourceListRequest) (*admin.WorkflowList, error)
ListWorkflowIdentifiers(ctx context.Context, request admin.NamedEntityIdentifierListRequest) (
*admin.NamedEntityIdentifierList, error)
UpdateWorkflow(ctx context.Context, request admin.WorkflowUpdateRequest) (*admin.WorkflowUpdateResponse, error)
}
5 changes: 5 additions & 0 deletions flyteadmin/pkg/manager/mocks/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,8 @@ func (r *MockWorkflowManager) ListWorkflowIdentifiers(ctx context.Context, reque
*admin.NamedEntityIdentifierList, error) {
return nil, nil
}

func (r *MockWorkflowManager) UpdateWorkflow(ctx context.Context, request admin.WorkflowUpdateRequest) (
*admin.WorkflowUpdateResponse, error) {
return nil, nil
}
20 changes: 20 additions & 0 deletions flyteadmin/pkg/repositories/config/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,24 @@ var Migrations = []*gormigrate.Migration{
return tx.Exec("ALTER TABLE tasks DROP COLUMN IF EXISTS type").Error
},
},
// Add state to workflow model
{
ID: "2020-04-01-workflow-state",
Migrate: func(tx *gorm.DB) error {
return tx.AutoMigrate(&models.Workflow{}).Error
},
Rollback: func(tx *gorm.DB) error {
return tx.Table("workflows").DropColumn("state").Error
},
},
// Set default state value for workflow model
{
ID: "2020-04-01-workflow-state-default",
Migrate: func(tx *gorm.DB) error {
return tx.Exec("UPDATE workflows SET state = 0").Error
},
Rollback: func(tx *gorm.DB) error {
return tx.Exec("UPDATE workflows set state = NULL").Error
},
},
}
10 changes: 10 additions & 0 deletions flyteadmin/pkg/repositories/gormimpl/workflow_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ func (r *WorkflowRepo) ListIdentifiers(ctx context.Context, input interfaces.Lis
}, nil
}

func (r *WorkflowRepo) Update(ctx context.Context, input models.Workflow) error {
timer := r.metrics.UpdateDuration.Start()
tx := r.db.Model(&input).Updates(input)
timer.Stop()
if err := tx.Error; err != nil {
return r.errorTransformer.ToFlyteAdminError(err)
}
return nil
}

// Returns an instance of WorkflowRepoInterface
func NewWorkflowRepo(
db *gorm.DB, errorTransformer errors.ErrorTransformer, scope promutils.Scope) interfaces.WorkflowRepoInterface {
Expand Down
28 changes: 28 additions & 0 deletions flyteadmin/pkg/repositories/gormimpl/workflow_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ var typedInterface = []byte{1, 2, 3}

const remoteSpecIdentifier = "remote spec id"

var archived = int32(admin.WorkflowState_WORKFLOW_ARCHIVED)

func TestCreateWorkflow(t *testing.T) {
workflowRepo := NewWorkflowRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope())
err := workflowRepo.Create(context.Background(), models.Workflow{
Expand Down Expand Up @@ -266,3 +268,29 @@ func TestListWorkflowIds_MissingParameters(t *testing.T) {

assert.Equal(t, err.Error(), "missing and/or invalid parameters: limit")
}

func TestSetWorkflowInactive(t *testing.T) {
workflowRepo := NewWorkflowRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope())
GlobalMock := mocket.Catcher.Reset()
GlobalMock.Logging = true
mockDb := GlobalMock.NewMock()

mockDb.WithQuery(`UPDATE "workflows" SET "domain" = ?, "id" = ?, "name" = ?, "project" = ?, "state" = ?, ` +
`"updated_at" = ?, "version" = ? WHERE "workflows"."deleted_at" IS NULL AND "workflows"."project" = ? AND ` +
`"workflows"."domain" = ? AND "workflows"."name" = ? AND "workflows"."version" = ?`)

err := workflowRepo.Update(context.Background(), models.Workflow{
BaseModel: models.BaseModel{
ID: 1,
},
WorkflowKey: models.WorkflowKey{
Project: project,
Domain: domain,
Name: name,
Version: version,
},
State: &archived,
})
assert.NoError(t, err)
assert.True(t, mockDb.Triggered)
}
2 changes: 2 additions & 0 deletions flyteadmin/pkg/repositories/interfaces/workflow_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ type WorkflowRepoInterface interface {
// Returns workflow revisions matching query parameters. A limit must be provided for the results page size.
List(ctx context.Context, input ListResourceInput) (WorkflowCollectionOutput, error)
ListIdentifiers(ctx context.Context, input ListResourceInput) (WorkflowCollectionOutput, error)
// Updates an existing workflow in the database store.
Update(ctx context.Context, input models.Workflow) error
}

// Response format for a query on workflows.
Expand Down
13 changes: 13 additions & 0 deletions flyteadmin/pkg/repositories/mocks/workflow_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ type CreateWorkflowFunc func(input models.Workflow) error
type GetWorkflowFunc func(input interfaces.GetResourceInput) (models.Workflow, error)
type ListWorkflowFunc func(input interfaces.ListResourceInput) (interfaces.WorkflowCollectionOutput, error)
type ListIdentifiersFunc func(input interfaces.ListResourceInput) (interfaces.WorkflowCollectionOutput, error)
type UpdateWorkflowFunc func(input models.Workflow) error

type MockWorkflowRepo struct {
createFunction CreateWorkflowFunc
getFunction GetWorkflowFunc
listFunction ListWorkflowFunc
listIdentifiersFunc ListIdentifiersFunc
updateFunc UpdateWorkflowFunc
}

func (r *MockWorkflowRepo) Create(ctx context.Context, input models.Workflow) error {
Expand Down Expand Up @@ -75,6 +77,17 @@ func (r *MockWorkflowRepo) ListIdentifiers(ctx context.Context, input interfaces
return interfaces.WorkflowCollectionOutput{}, nil
}

func (r *MockWorkflowRepo) Update(ctx context.Context, workflow models.Workflow) error {
if r.updateFunc != nil {
return r.updateFunc(workflow)
}
return nil
}

func (r *MockWorkflowRepo) SetUpdateCallback(updateFunc UpdateWorkflowFunc) {
r.updateFunc = updateFunc
}

func NewMockWorkflowRepo() interfaces.WorkflowRepoInterface {
return &MockWorkflowRepo{}
}
2 changes: 2 additions & 0 deletions flyteadmin/pkg/repositories/models/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ type Workflow struct {
Executions []Execution
// Hash of the compiled workflow closure
Digest []byte
// GORM doesn't save the zero value for ints, so we use a pointer for the State field
State *int32 `gorm:"default:0"`
}
2 changes: 2 additions & 0 deletions flyteadmin/pkg/rpc/adminservice/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ type workflowEndpointMetrics struct {
get util.RequestMetrics
list util.RequestMetrics
listIds util.RequestMetrics
update util.RequestMetrics
}

type AdminMetrics struct {
Expand Down Expand Up @@ -197,6 +198,7 @@ func InitMetrics(adminScope promutils.Scope) AdminMetrics {
get: util.NewRequestMetrics(adminScope, "get_workflow"),
list: util.NewRequestMetrics(adminScope, "list_workflow"),
listIds: util.NewRequestMetrics(adminScope, "list_workflow_ids"),
update: util.NewRequestMetrics(adminScope, "update_workflow"),
},
}
}
35 changes: 35 additions & 0 deletions flyteadmin/pkg/rpc/adminservice/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"google.golang.org/grpc/status"
)

const workflowState = "workflow_state"

func (m *AdminService) CreateWorkflow(
ctx context.Context,
request *admin.WorkflowCreateRequest) (*admin.WorkflowCreateResponse, error) {
Expand Down Expand Up @@ -125,3 +127,36 @@ func (m *AdminService) ListWorkflows(ctx context.Context, request *admin.Resourc
m.Metrics.workflowEndpointMetrics.list.Success()
return response, nil
}

func (m *AdminService) UpdateWorkflow(ctx context.Context, request *admin.WorkflowUpdateRequest) (
*admin.WorkflowUpdateResponse, error) {
defer m.interceptPanic(ctx, request)
requestedAt := time.Now()
if request == nil {
return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed")
}
// NOTE: When the Get HTTP endpoint is called the resource type is implicit (from the URL) so we must add it
// to the request.
if request.Id != nil && request.Id.ResourceType == core.ResourceType_UNSPECIFIED {
logger.Info(ctx, "Adding resource type for unspecified value in request: [%+v]", request)
request.Id.ResourceType = core.ResourceType_WORKFLOW
}
var response *admin.WorkflowUpdateResponse
var err error
m.Metrics.workflowEndpointMetrics.update.Time(func() {
response, err = m.WorkflowManager.UpdateWorkflow(ctx, *request)
})
requestParameters := audit.ParametersFromIdentifier(request.Id)
requestParameters[workflowState] = request.State.String()
audit.NewLogBuilder().WithAuthenticatedCtx(ctx).WithRequest(
"UpdateWorkflow",
requestParameters,
audit.ReadWrite,
requestedAt,
).WithResponse(time.Now(), err).Log(ctx)
if err != nil {
return nil, util.TransformAndRecordError(err, &m.Metrics.workflowEndpointMetrics.update)
}
m.Metrics.workflowEndpointMetrics.update.Success()
return response, nil
}
58 changes: 58 additions & 0 deletions flyteadmin/tests/workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ import (
"net/http"
"testing"

"github.com/lyft/flyteadmin/pkg/repositories"
"github.com/lyft/flyteadmin/pkg/repositories/interfaces"

"github.com/golang/protobuf/proto"

"github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
"github.com/lyft/flytestdlib/promutils"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -308,3 +312,57 @@ func testListWorkflow_FiltersHTTP(t *testing.T) {
Version: "123",
}, workflow.Id))
}

func TestUpdateWorkflow(t *testing.T) {
ctx := context.Background()
client, conn := GetTestAdminServiceClient()
defer conn.Close()
truncateAllTablesForTestingOnly()

identifier := core.Identifier{
ResourceType: core.ResourceType_WORKFLOW,
Project: "admintests",
Domain: "development",
Name: "name",
Version: "version",
}
createReq := admin.WorkflowCreateRequest{
Id: &identifier,
Spec: &admin.WorkflowSpec{
Template: &core.WorkflowTemplate{
Id: &identifier,
Interface: &core.TypedInterface{},
},
},
}

_, err := client.CreateWorkflow(ctx, &createReq)
assert.Nil(t, err)

testScope := promutils.NewScope("UpdateWorkflow")
db := repositories.GetRepository(
repositories.POSTGRES, getDbConfig(), testScope.NewSubScope("database"))
workflow, err := db.WorkflowRepo().Get(ctx, interfaces.GetResourceInput{
Project: "admintests",
Domain: "development",
Name: "name",
Version: "version",
})
assert.Nil(t, err)
assert.Equal(t, admin.WorkflowState_WORKFLOW_ACTIVE, admin.WorkflowState(*workflow.State))

updateReq := admin.WorkflowUpdateRequest{
Id: &identifier,
State: admin.WorkflowState_WORKFLOW_ARCHIVED,
}
_, err = client.UpdateWorkflow(ctx, &updateReq)
assert.Nil(t, err)
workflow, err = db.WorkflowRepo().Get(ctx, interfaces.GetResourceInput{
Project: "admintests",
Domain: "development",
Name: "name",
Version: "version",
})
assert.Nil(t, err)
assert.Equal(t, admin.WorkflowState_WORKFLOW_ARCHIVED, admin.WorkflowState(*workflow.State))
}

0 comments on commit 21e9602

Please sign in to comment.