diff --git a/flyteadmin/pkg/manager/impl/project_manager_test.go b/flyteadmin/pkg/manager/impl/project_manager_test.go index c5674a0b12..76ba71f3a5 100644 --- a/flyteadmin/pkg/manager/impl/project_manager_test.go +++ b/flyteadmin/pkg/manager/impl/project_manager_test.go @@ -5,6 +5,8 @@ import ( "errors" "testing" + "github.com/golang/protobuf/proto" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" @@ -198,9 +200,19 @@ func TestProjectManager_CreateProjectErrorDueToBadLabels(t *testing.T) { func TestProjectManager_UpdateProject(t *testing.T) { mockRepository := repositoryMocks.NewMockRepository() var updateFuncCalled bool + labels := admin.Labels{ + Values: map[string]string{ + "foo": "#badlabel", + "bar": "baz", + }, + } + labelsBytes, _ := proto.Marshal(&labels) mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( ctx context.Context, projectID string) (models.Project, error) { - return models.Project{Identifier: "project-id", Name: "old-project-name", Description: "old-project-description"}, nil + + return models.Project{Identifier: "project-id", + Name: "old-project-name", + Description: "old-project-description", Labels: labelsBytes}, nil } mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).UpdateProjectFunction = func( ctx context.Context, projectUpdate models.Project) error { @@ -208,6 +220,7 @@ func TestProjectManager_UpdateProject(t *testing.T) { assert.Equal(t, "project-id", projectUpdate.Identifier) assert.Equal(t, "new-project-name", projectUpdate.Name) assert.Equal(t, "new-project-description", projectUpdate.Description) + assert.Nil(t, projectUpdate.Labels) assert.Equal(t, int32(admin.Project_ACTIVE), *projectUpdate.State) return nil } diff --git a/flyteadmin/pkg/manager/impl/validation/project_validator_test.go b/flyteadmin/pkg/manager/impl/validation/project_validator_test.go index f5fc87eb83..9862f3d176 100644 --- a/flyteadmin/pkg/manager/impl/validation/project_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/project_validator_test.go @@ -6,6 +6,9 @@ import ( "strconv" "testing" + flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" + "google.golang.org/grpc/codes" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks" "github.com/flyteorg/flyteadmin/pkg/repositories/models" @@ -289,3 +292,14 @@ func TestValidateProjectAndDomainError(t *testing.T) { assert.EqualError(t, err, "failed to validate that project [flyte-project-id] and domain [domain] are registered, err: [foo]") } + +func TestValidateProjectAndDomainNotFound(t *testing.T) { + mockRepo := repositoryMocks.NewMockRepository() + mockRepo.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( + ctx context.Context, projectID string) (models.Project, error) { + return models.Project{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, "project [%s] not found", projectID) + } + err := ValidateProjectAndDomain(context.Background(), mockRepo, testutils.GetApplicationConfigWithDefaultDomains(), + "flyte-project", "domain") + assert.EqualError(t, err, "failed to validate that project [flyte-project] and domain [domain] are registered, err: [project [flyte-project] not found]") +} diff --git a/flyteadmin/pkg/repositories/gormimpl/project_repo.go b/flyteadmin/pkg/repositories/gormimpl/project_repo.go index fbcb4bc4d4..6e3917f4f4 100644 --- a/flyteadmin/pkg/repositories/gormimpl/project_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/project_repo.go @@ -3,6 +3,7 @@ package gormimpl import ( "context" + flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" "google.golang.org/grpc/codes" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" @@ -10,7 +11,6 @@ import ( "github.com/jinzhu/gorm" - flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyteadmin/pkg/repositories/models" @@ -39,12 +39,14 @@ func (r *ProjectRepo) Get(ctx context.Context, projectID string) (models.Project Identifier: projectID, }).Take(&project) timer.Stop() - if tx.Error != nil { - return models.Project{}, r.errorTransformer.ToFlyteAdminError(tx.Error) - } if tx.RecordNotFound() { return models.Project{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, "project [%s] not found", projectID) } + + if tx.Error != nil { + return models.Project{}, r.errorTransformer.ToFlyteAdminError(tx.Error) + } + return project, nil } diff --git a/flyteadmin/pkg/repositories/gormimpl/project_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/project_repo_test.go index dd0b49207c..1568e4514c 100644 --- a/flyteadmin/pkg/repositories/gormimpl/project_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/project_repo_test.go @@ -48,6 +48,10 @@ func TestGetProject(t *testing.T) { response["description"] = "project_description" response["state"] = admin.Project_ACTIVE + output, err := projectRepo.Get(context.Background(), "project_id") + assert.Empty(t, output) + assert.EqualError(t, err, "project [project_id] not found") + query := GlobalMock.NewMock() query.WithQuery(`SELECT * FROM "projects" WHERE "projects"."deleted_at" IS NULL AND ` + `(("projects"."identifier" = project_id)) LIMIT 1`).WithReply( @@ -55,7 +59,7 @@ func TestGetProject(t *testing.T) { response, }) - output, err := projectRepo.Get(context.Background(), "project_id") + output, err = projectRepo.Get(context.Background(), "project_id") assert.Nil(t, err) assert.Equal(t, "project_id", output.Identifier) assert.Equal(t, "project_name", output.Name) diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index 36c3e46376..f564c857ef 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -42,9 +42,6 @@ func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models }, }).Take(&task) timer.Stop() - if tx.Error != nil { - return models.Task{}, r.errorTransformer.ToFlyteAdminError(tx.Error) - } if tx.RecordNotFound() { return models.Task{}, errors.GetMissingEntityError(core.ResourceType_TASK.String(), &core.Identifier{ Project: input.Project, @@ -53,6 +50,9 @@ func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models Version: input.Version, }) } + if tx.Error != nil { + return models.Task{}, r.errorTransformer.ToFlyteAdminError(tx.Error) + } return task, nil } diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go index b3cd88ae0a..bf8dc4414d 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go @@ -51,6 +51,15 @@ func TestGetTask(t *testing.T) { task := getMockTaskResponseFromDb(version, []byte{1, 2}) tasks = append(tasks, task) + output, err := taskRepo.Get(context.Background(), interfaces.Identifier{ + Project: project, + Domain: domain, + Name: name, + Version: version, + }) + assert.Empty(t, output) + assert.EqualError(t, err, "missing entity of type TASK with identifier project:\"project\" domain:\"domain\" name:\"name\" version:\"XYZ\" ") + GlobalMock := mocket.Catcher.Reset() GlobalMock.Logging = true // Only match on queries that append expected filters @@ -58,7 +67,7 @@ func TestGetTask(t *testing.T) { `SELECT * FROM "tasks" WHERE "tasks"."deleted_at" IS NULL AND (("tasks"."project" = project) ` + `AND ("tasks"."domain" = domain) AND ("tasks"."name" = name) AND ("tasks"."version" = XYZ)) LIMIT 1`). WithReply(tasks) - output, err := taskRepo.Get(context.Background(), interfaces.Identifier{ + output, err = taskRepo.Get(context.Background(), interfaces.Identifier{ Project: project, Domain: domain, Name: name, diff --git a/flyteadmin/pkg/repositories/transformers/project.go b/flyteadmin/pkg/repositories/transformers/project.go index 3c7d83425f..23e9145179 100644 --- a/flyteadmin/pkg/repositories/transformers/project.go +++ b/flyteadmin/pkg/repositories/transformers/project.go @@ -14,6 +14,14 @@ type CreateProjectModelInput struct { func CreateProjectModel(project *admin.Project) models.Project { stateInt := int32(project.State) + if project.Labels == nil { + return models.Project{ + Identifier: project.Id, + Name: project.Name, + Description: project.Description, + State: &stateInt, + } + } projectBytes, err := proto.Marshal(project) if err != nil { return models.Project{} diff --git a/flyteadmin/pkg/repositories/transformers/project_test.go b/flyteadmin/pkg/repositories/transformers/project_test.go index 6dca5de9ac..3853b71557 100644 --- a/flyteadmin/pkg/repositories/transformers/project_test.go +++ b/flyteadmin/pkg/repositories/transformers/project_test.go @@ -11,25 +11,30 @@ import ( ) func TestCreateProjectModel(t *testing.T) { - - projectModel := CreateProjectModel(&admin.Project{ + labels := admin.Labels{ + Values: map[string]string{ + "foo": "#badlabel", + "bar": "baz", + }, + } + project := admin.Project{ Id: "project_id", Name: "project_name", Description: "project_description", + Labels: &labels, State: admin.Project_ACTIVE, - }) + } + + projectBytes, _ := proto.Marshal(&project) + projectModel := CreateProjectModel(&project) activeState := int32(admin.Project_ACTIVE) assert.Equal(t, models.Project{ Identifier: "project_id", Name: "project_name", Description: "project_description", - Labels: []uint8{ - 0xa, 0xa, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x12, 0xc, 0x70, 0x72, 0x6f, - 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x13, 0x70, 0x72, 0x6f, 0x6a, 0x65, - 0x63, 0x74, 0x5f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, - }, - State: &activeState, + Labels: projectBytes, + State: &activeState, }, projectModel) } diff --git a/flyteadmin/tests/project.go b/flyteadmin/tests/project_test.go similarity index 74% rename from flyteadmin/tests/project.go rename to flyteadmin/tests/project_test.go index 849418bbad..d46849d8be 100644 --- a/flyteadmin/tests/project.go +++ b/flyteadmin/tests/project_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" ) func TestCreateProject(t *testing.T) { @@ -17,14 +18,28 @@ func TestCreateProject(t *testing.T) { client, conn := GetTestAdminServiceClient() defer conn.Close() + projectId := "potato" + task, err := client.GetTask(ctx, &admin.ObjectGetRequest{ + Id: &core.Identifier{ + Project: projectId, + Domain: "development", + Name: "task", + Version: "1234", + ResourceType: core.ResourceType_TASK, + }, + }) + assert.EqualError(t, err, "rpc error: code = NotFound desc = missing entity of type TASK" + + " with identifier project:\"potato\" domain:\"development\" name:\"task\" version:\"1234\" ") + assert.Empty(t, task) + req := admin.ProjectRegisterRequest{ Project: &admin.Project{ - Id: "potato", + Id: projectId, Name: "spud", }, } - _, err := client.RegisterProject(ctx, &req) + _, err = client.RegisterProject(ctx, &req) assert.Nil(t, err) projects, err := client.ListProjects(ctx, &admin.ProjectListRequest{}) @@ -32,7 +47,7 @@ func TestCreateProject(t *testing.T) { assert.NotEmpty(t, projects.Projects) var sawNewProject bool for _, project := range projects.Projects { - if project.Id == "potato" { + if project.Id == projectId { sawNewProject = true assert.Equal(t, "spud", project.Name) } @@ -47,14 +62,15 @@ func TestCreateProject(t *testing.T) { func TestUpdateProjectDescription(t *testing.T) { truncateAllTablesForTestingOnly() + ctx := context.Background() client, conn := GetTestAdminServiceClient() defer conn.Close() - // Create a new project. + projectId := "potato1" req := admin.ProjectRegisterRequest{ Project: &admin.Project{ - Id: "potato", + Id: projectId, Name: "spud", Labels: &admin.Labels{ Values: map[string]string{ @@ -75,7 +91,7 @@ func TestUpdateProjectDescription(t *testing.T) { // Attempt to modify the name of the Project. Labels should be a no-op. // Name and Description should modify just fine. _, err = client.UpdateProject(ctx, &admin.Project{ - Id: "potato", + Id: projectId, Name: "foobar", Description: "a-new-description", }) @@ -88,14 +104,19 @@ func TestUpdateProjectDescription(t *testing.T) { assert.Nil(t, err) assert.NotEmpty(t, projectsUpdated.Projects) - // Verify that the project's Name has not been modified but the Description has. - updatedProject := projectsUpdated.Projects[0] - assert.Equal(t, updatedProject.Id, "potato") // unchanged + // Verify that the project's ID has not been modified but the Description has. + var updatedProject *admin.Project + for _, project := range projectsUpdated.Projects { + if project.Id == projectId { + updatedProject = project + } + } assert.Equal(t, updatedProject.Name, "foobar") // changed assert.Equal(t, updatedProject.Description, "a-new-description") // changed // Verify that project labels are not removed. labelsMap := updatedProject.Labels + assert.NotNil(t, labelsMap) fooVal, fooExists := labelsMap.Values["foo"] barVal, barExists := labelsMap.Values["bar"] assert.Equal(t, fooExists, true) @@ -110,9 +131,10 @@ func TestUpdateProjectLabels(t *testing.T) { defer conn.Close() // Create a new project. + projectId := "potato2" req := admin.ProjectRegisterRequest{ Project: &admin.Project{ - Id: "potato", + Id: projectId, Name: "spud", }, } @@ -122,12 +144,13 @@ func TestUpdateProjectLabels(t *testing.T) { // Verify the project has been registered. projects, err := client.ListProjects(ctx, &admin.ProjectListRequest{}) assert.Nil(t, err) + assert.NotNil(t, projects) assert.NotEmpty(t, projects.Projects) // Attempt to modify the name of the Project. Labels and name should be // modified. _, err = client.UpdateProject(ctx, &admin.Project{ - Id: "potato", + Id: projectId, Name: "foobar", Labels: &admin.Labels{ Values: map[string]string{ @@ -146,9 +169,13 @@ func TestUpdateProjectLabels(t *testing.T) { assert.NotEmpty(t, projectsUpdated.Projects) // Check the name has been modified. - // Verify that the project's Name has not been modified but the Description has. - updatedProject := projectsUpdated.Projects[0] - assert.Equal(t, updatedProject.Id, "potato") // unchanged + // Verify that the project's ID has not been modified but the Description has. + var updatedProject *admin.Project + for _, project := range projectsUpdated.Projects { + if project.Id == projectId { + updatedProject = project + } + } assert.Equal(t, updatedProject.Name, "foobar") // changed // Verify that the expected labels have been added to the project. @@ -167,9 +194,10 @@ func TestUpdateProjectLabels_BadLabels(t *testing.T) { defer conn.Close() // Create a new project. + projectId := "potato4" req := admin.ProjectRegisterRequest{ Project: &admin.Project{ - Id: "potato", + Id: projectId, Name: "spud", }, } @@ -184,7 +212,7 @@ func TestUpdateProjectLabels_BadLabels(t *testing.T) { // Attempt to modify the name of the Project. Labels and name should be // modified. _, err = client.UpdateProject(ctx, &admin.Project{ - Id: "potato", + Id: projectId, Name: "foobar", Labels: &admin.Labels{ Values: map[string]string{ @@ -195,5 +223,5 @@ func TestUpdateProjectLabels_BadLabels(t *testing.T) { }) // Assert that update went through without an error. - assert.EqualError(t, err, "invalid label value [#bar]: [a DNS-1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]") + assert.EqualError(t, err, "rpc error: code = InvalidArgument desc = invalid label value [#bar]: [a lowercase RFC 1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]") }