diff --git a/flyteadmin/pkg/manager/impl/project_manager.go b/flyteadmin/pkg/manager/impl/project_manager.go index f1da4608fd..336760e70a 100644 --- a/flyteadmin/pkg/manager/impl/project_manager.go +++ b/flyteadmin/pkg/manager/impl/project_manager.go @@ -70,9 +70,8 @@ func (m *ProjectManager) UpdateProject(ctx context.Context, projectUpdate admin. return nil, err } - // Run validation on the request, specifically checking for labels, and return err if validation does not succeed. - err = validation.ValidateProjectLabels(projectUpdate) - if err != nil { + // Run validation on the request and return err if validation does not succeed. + if err := validation.ValidateProject(projectUpdate); err != nil { return nil, err } diff --git a/flyteadmin/pkg/manager/impl/project_manager_test.go b/flyteadmin/pkg/manager/impl/project_manager_test.go index 95bb8b7bb8..db52400b3f 100644 --- a/flyteadmin/pkg/manager/impl/project_manager_test.go +++ b/flyteadmin/pkg/manager/impl/project_manager_test.go @@ -5,6 +5,8 @@ import ( "errors" "testing" + "github.com/lyft/flyteadmin/pkg/manager/impl/shared" + "github.com/lyft/flyteadmin/pkg/common" "github.com/lyft/flyteadmin/pkg/manager/impl/testutils" repositoryMocks "github.com/lyft/flyteadmin/pkg/repositories/mocks" @@ -160,3 +162,73 @@ func TestProjectManager_CreateProjectErrorDueToBadLabels(t *testing.T) { }) assert.EqualError(t, err, "invalid label value [#badlabel]: [a DNS-1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]") } + +func TestProjectManager_UpdateProject(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + var updateFuncCalled bool + mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( + ctx context.Context, projectID string) (models.Project, error) { + return models.Project{Identifier: "project-id", Name: "old-project-name", Description: "old-project-description"}, nil + } + mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).UpdateProjectFunction = func( + ctx context.Context, projectUpdate models.Project) error { + updateFuncCalled = true + assert.Equal(t, "project-id", projectUpdate.Identifier) + assert.Equal(t, "new-project-name", projectUpdate.Name) + assert.Equal(t, "new-project-description", projectUpdate.Description) + return nil + } + projectManager := NewProjectManager(mockRepository, + runtimeMocks.NewMockConfigurationProvider( + getMockApplicationConfigForProjectManagerTest(), nil, nil, nil, nil, nil)) + _, err := projectManager.UpdateProject(context.Background(), admin.Project{ + Id: "project-id", + Name: "new-project-name", + Description: "new-project-description", + }) + assert.Nil(t, err) + assert.True(t, updateFuncCalled) +} + +func TestProjectManager_UpdateProject_ErrorDueToProjectNotFound(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( + ctx context.Context, projectID string) (models.Project, error) { + return models.Project{}, errors.New(projectID + " not found") + } + mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).UpdateProjectFunction = func( + ctx context.Context, projectUpdate models.Project) error { + assert.Fail(t, "No calls to UpdateProject were expected") + return nil + } + projectManager := NewProjectManager(mockRepository, + runtimeMocks.NewMockConfigurationProvider( + getMockApplicationConfigForProjectManagerTest(), nil, nil, nil, nil, nil)) + _, err := projectManager.UpdateProject(context.Background(), admin.Project{ + Id: "not-found-project-id", + Name: "not-found-project-name", + Description: "not-found-project-description", + }) + assert.Equal(t, errors.New("not-found-project-id not found"), err) +} + +func TestProjectManager_UpdateProject_ErrorDueToInvalidProjectName(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( + ctx context.Context, projectID string) (models.Project, error) { + return models.Project{Identifier: "project-id", Name: "old-project-name", Description: "old-project-description"}, nil + } + mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).UpdateProjectFunction = func( + ctx context.Context, projectUpdate models.Project) error { + assert.Fail(t, "No calls to UpdateProject were expected") + return nil + } + projectManager := NewProjectManager(mockRepository, + runtimeMocks.NewMockConfigurationProvider( + getMockApplicationConfigForProjectManagerTest(), nil, nil, nil, nil, nil)) + _, err := projectManager.UpdateProject(context.Background(), admin.Project{ + Id: "project-id", + // No project name + }) + assert.Equal(t, shared.GetMissingArgumentError("project_name"), err) +} diff --git a/flyteadmin/pkg/manager/impl/validation/project_validator.go b/flyteadmin/pkg/manager/impl/validation/project_validator.go index 00d7c3c6b6..212113240c 100644 --- a/flyteadmin/pkg/manager/impl/validation/project_validator.go +++ b/flyteadmin/pkg/manager/impl/validation/project_validator.go @@ -3,8 +3,9 @@ package validation import ( "context" - "github.com/lyft/flyteadmin/pkg/errors" "github.com/lyft/flyteadmin/pkg/manager/impl/shared" + + "github.com/lyft/flyteadmin/pkg/errors" "github.com/lyft/flyteadmin/pkg/repositories" runtimeInterfaces "github.com/lyft/flyteadmin/pkg/runtime/interfaces" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" @@ -15,36 +16,52 @@ import ( const projectID = "project_id" const projectName = "project_name" const projectDescription = "project_description" +const labels = "labels" +const maxNameLength = 64 const maxDescriptionLength = 300 +const maxLabelArrayLength = 16 func ValidateProjectRegisterRequest(request admin.ProjectRegisterRequest) error { if request.Project == nil { return shared.GetMissingArgumentError(shared.Project) } - if err := ValidateEmptyStringField(request.Project.Id, projectID); err != nil { + return ValidateProject(*request.Project) +} + +func ValidateProject(project admin.Project) error { + if err := ValidateEmptyStringField(project.Id, projectID); err != nil { return err } - if err := ValidateProjectLabels(*request.Project); err != nil { + if err := validateProjectLabels(project); err != nil { return err } - if errs := validation.IsDNS1123Label(request.Project.Id); len(errs) > 0 { - return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid project id [%s]: %v", request.Project.Id, errs) + if errs := validation.IsDNS1123Label(project.Id); len(errs) > 0 { + return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid project id [%s]: %v", project.Id, errs) + } + if err := ValidateEmptyStringField(project.Name, projectName); err != nil { + return err } - if err := ValidateEmptyStringField(request.Project.Name, projectName); err != nil { + if err := ValidateMaxLengthStringField(project.Name, projectName, maxNameLength); err != nil { return err } - if err := ValidateMaxLengthStringField(request.Project.Description, projectDescription, maxDescriptionLength); err != nil { + if err := ValidateMaxLengthStringField(project.Description, projectDescription, maxDescriptionLength); err != nil { return err } - if request.Project.Domains != nil { + if project.Domains != nil { return errors.NewFlyteAdminError(codes.InvalidArgument, "Domains are currently only set system wide. Please retry without domains included in your request.") } return nil } -func ValidateProjectLabels(request admin.Project) error { - if err := ValidateProjectLabelsAlphanumeric(request); err != nil { +func validateProjectLabels(project admin.Project) error { + if project.Labels == nil || len(project.Labels.Values) == 0 { + return nil + } + if err := ValidateMaxMapLengthField(project.Labels.Values, labels, maxLabelArrayLength); err != nil { + return err + } + if err := validateProjectLabelsAlphanumeric(project.Labels); err != nil { return err } return nil @@ -79,11 +96,8 @@ func ValidateProjectAndDomain( // Given an admin.Project, checks if the project has labels and if it does, checks if the labels are K8s compliant, // i.e. alphanumeric + - and _ -func ValidateProjectLabelsAlphanumeric(request admin.Project) error { - if request.Labels == nil || len(request.Labels.Values) == 0 { - return nil - } - for key, value := range request.Labels.Values { +func validateProjectLabelsAlphanumeric(labels *admin.Labels) error { + for key, value := range labels.Values { if errs := validation.IsDNS1123Label(key); len(errs) > 0 { return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid label key [%s]: %v", key, errs) } diff --git a/flyteadmin/pkg/manager/impl/validation/project_validator_test.go b/flyteadmin/pkg/manager/impl/validation/project_validator_test.go index 8b98ed3c4c..3be5933a7b 100644 --- a/flyteadmin/pkg/manager/impl/validation/project_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/project_validator_test.go @@ -3,6 +3,7 @@ package validation import ( "context" "errors" + "strconv" "testing" "github.com/lyft/flyteadmin/pkg/manager/impl/testutils" @@ -64,6 +65,15 @@ func TestValidateProjectRegisterRequest(t *testing.T) { }, expectedError: "missing project_name", }, + { + request: admin.ProjectRegisterRequest{ + Project: &admin.Project{ + Id: "proj", + Name: "longnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamel", + }, + }, + expectedError: "project_name cannot exceed 64 characters", + }, { request: admin.ProjectRegisterRequest{ Project: &admin.Project{ @@ -93,6 +103,36 @@ func TestValidateProjectRegisterRequest(t *testing.T) { }, expectedError: "project_description cannot exceed 300 characters", }, + { + request: admin.ProjectRegisterRequest{ + Project: &admin.Project{ + Id: "proj", + Name: "name", + Labels: &admin.Labels{ + Values: map[string]string{ + "#badkey": "foo", + "bar": "baz", + }, + }, + }, + }, + expectedError: "invalid label key [#badkey]: [a DNS-1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]", + }, + { + request: admin.ProjectRegisterRequest{ + Project: &admin.Project{ + Id: "proj", + Name: "name", + Labels: &admin.Labels{ + Values: map[string]string{ + "foo": ".bad-label-value", + "bar": "baz", + }, + }, + }, + }, + expectedError: "invalid label value [.bad-label-value]: [a DNS-1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]", + }, } for _, val := range testValues { @@ -102,6 +142,124 @@ func TestValidateProjectRegisterRequest(t *testing.T) { } } +func TestValidateProject_ValidProject(t *testing.T) { + assert.Nil(t, ValidateProject(admin.Project{ + Id: "proj", + Name: "proj", + Description: "An amazing description for this project", + Labels: &admin.Labels{ + Values: map[string]string{ + "foo": "bar", + }, + }, + })) +} + +func TestValidateProject(t *testing.T) { + type testValue struct { + project admin.Project + expectedError string + } + testValues := []testValue{ + { + project: admin.Project{ + Name: "proj", + Domains: []*admin.Domain{ + { + Id: "foo", + Name: "foo", + }, + }, + }, + expectedError: "missing project_id", + }, + { + project: admin.Project{ + Id: "%)(*&", + Name: "proj", + }, + expectedError: "invalid project id [%)(*&]: [a DNS-1123 label must consist of lower case alphanumeric " + + "characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or " + + "'123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]", + }, + { + project: admin.Project{ + Id: "proj", + }, + expectedError: "missing project_name", + }, + { + project: admin.Project{ + Id: "proj", + Name: "longnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamel", + }, + expectedError: "project_name cannot exceed 64 characters", + }, + { + project: admin.Project{ + Id: "proj", + Name: "proj", + Domains: []*admin.Domain{ + { + Id: "foo", + Name: "foo", + }, + { + Id: "foo", + }, + }, + }, + expectedError: "Domains are currently only set system wide. Please retry without domains included in your request.", + }, + { + project: admin.Project{ + Id: "proj", + Name: "name", + // 301 character string + Description: "longnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongnamelongn", + }, + expectedError: "project_description cannot exceed 300 characters", + }, + { + project: admin.Project{ + Id: "proj", + Name: "name", + Labels: &admin.Labels{ + Values: createLabelsMap(17), + }, + }, + expectedError: "labels map cannot exceed 16 entries", + }, + { + project: admin.Project{ + Id: "proj", + Name: "name", + Labels: &admin.Labels{ + Values: map[string]string{ + "#badkey": "foo", + "bar": "baz", + }, + }, + }, + expectedError: "invalid label key [#badkey]: [a DNS-1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]", + }, + } + + for _, val := range testValues { + t.Run(val.expectedError, func(t *testing.T) { + assert.EqualError(t, ValidateProject(val.project), val.expectedError) + }) + } +} + +func createLabelsMap(size int) map[string]string { + result := make(map[string]string, size) + for i := 0; i < size; i++ { + result["key-"+strconv.Itoa(i)] = "value" + } + return result +} + func TestValidateProjectAndDomain(t *testing.T) { mockRepo := repositoryMocks.NewMockRepository() mockRepo.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( diff --git a/flyteadmin/pkg/manager/impl/validation/validation.go b/flyteadmin/pkg/manager/impl/validation/validation.go index 9a08daf02c..3d0a2b94d1 100644 --- a/flyteadmin/pkg/manager/impl/validation/validation.go +++ b/flyteadmin/pkg/manager/impl/validation/validation.go @@ -34,6 +34,14 @@ func ValidateMaxLengthStringField(field string, fieldName string, limit int) err return nil } +// Validates that a map field does not exceed a certain amount of entries +func ValidateMaxMapLengthField(m map[string]string, fieldName string, limit int) error { + if len(m) > limit { + return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "%s map cannot exceed %d entries", fieldName, limit) + } + return nil +} + func ValidateIdentifierFieldsSet(id *core.Identifier) error { if id == nil { return shared.GetMissingArgumentError(shared.ID) diff --git a/flyteadmin/pkg/manager/impl/validation/validation_test.go b/flyteadmin/pkg/manager/impl/validation/validation_test.go index 94df9ef7a7..fe7bda3dd5 100644 --- a/flyteadmin/pkg/manager/impl/validation/validation_test.go +++ b/flyteadmin/pkg/manager/impl/validation/validation_test.go @@ -25,6 +25,17 @@ func TestValidateMaxLengthStringField(t *testing.T) { assert.Equal(t, codes.InvalidArgument, err.(errors.FlyteAdminError).Code()) } +func TestValidateMaxMapLengthField(t *testing.T) { + labels := map[string]string{ + "k1": "v1", + "k2": "v2", + "k3": "v3", + } + err := ValidateMaxMapLengthField(labels, "foo", 2) + assert.EqualError(t, err, "foo map cannot exceed 2 entries") + assert.Equal(t, codes.InvalidArgument, err.(errors.FlyteAdminError).Code()) +} + func TestValidateIdentifier(t *testing.T) { err := ValidateIdentifier(&core.Identifier{ ResourceType: core.ResourceType_TASK, diff --git a/flyteadmin/pkg/repositories/config/migrations.go b/flyteadmin/pkg/repositories/config/migrations.go index 433512a909..9ba6640d21 100644 --- a/flyteadmin/pkg/repositories/config/migrations.go +++ b/flyteadmin/pkg/repositories/config/migrations.go @@ -199,7 +199,7 @@ var Migrations = []*gormigrate.Migration{ return tx.Exec("ALTER TABLE workflows ADD COLUMN IF NOT EXISTS state integer;").Error }, }, - // Modify the executions & node_executison table, if necessary + // Modify the executions & node_execution table, if necessary { ID: "2020-04-29-executions", Migrate: func(tx *gorm.DB) error {