Skip to content

Commit

Permalink
Fix integration test errors (flyteorg#256)
Browse files Browse the repository at this point in the history
* Gorm doesn't return not found for no rows mutated

Signed-off-by: Kevin Su <[email protected]>

* Fix test

Signed-off-by: Kevin Su <[email protected]>

* Fix test

Signed-off-by: Kevin Su <[email protected]>

* Fix typo

Signed-off-by: Kevin Su <[email protected]>

* Add tests

Signed-off-by: Kevin Su <[email protected]>

* Fix test

Signed-off-by: Kevin Su <[email protected]>

* Fix tset

Signed-off-by: Kevin Su <[email protected]>

* Fix test

Signed-off-by: Kevin Su <[email protected]>

* Fix test

Signed-off-by: Kevin Su <[email protected]>

* Add test

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Sep 20, 2021
1 parent 1c077a3 commit 5b125e4
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 36 deletions.
15 changes: 14 additions & 1 deletion flyteadmin/pkg/manager/impl/project_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"testing"

"github.com/golang/protobuf/proto"

"github.com/flyteorg/flyteadmin/pkg/common"

"github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils"
Expand Down Expand Up @@ -198,16 +200,27 @@ func TestProjectManager_CreateProjectErrorDueToBadLabels(t *testing.T) {
func TestProjectManager_UpdateProject(t *testing.T) {
mockRepository := repositoryMocks.NewMockRepository()
var updateFuncCalled bool
labels := admin.Labels{
Values: map[string]string{
"foo": "#badlabel",
"bar": "baz",
},
}
labelsBytes, _ := proto.Marshal(&labels)
mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func(
ctx context.Context, projectID string) (models.Project, error) {
return models.Project{Identifier: "project-id", Name: "old-project-name", Description: "old-project-description"}, nil

return models.Project{Identifier: "project-id",
Name: "old-project-name",
Description: "old-project-description", Labels: labelsBytes}, nil
}
mockRepository.ProjectRepo().(*repositoryMocks.MockProjectRepo).UpdateProjectFunction = func(
ctx context.Context, projectUpdate models.Project) error {
updateFuncCalled = true
assert.Equal(t, "project-id", projectUpdate.Identifier)
assert.Equal(t, "new-project-name", projectUpdate.Name)
assert.Equal(t, "new-project-description", projectUpdate.Description)
assert.Nil(t, projectUpdate.Labels)
assert.Equal(t, int32(admin.Project_ACTIVE), *projectUpdate.State)
return nil
}
Expand Down
14 changes: 14 additions & 0 deletions flyteadmin/pkg/manager/impl/validation/project_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
"strconv"
"testing"

flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors"
"google.golang.org/grpc/codes"

"github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils"
repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks"
"github.com/flyteorg/flyteadmin/pkg/repositories/models"
Expand Down Expand Up @@ -289,3 +292,14 @@ func TestValidateProjectAndDomainError(t *testing.T) {
assert.EqualError(t, err,
"failed to validate that project [flyte-project-id] and domain [domain] are registered, err: [foo]")
}

func TestValidateProjectAndDomainNotFound(t *testing.T) {
mockRepo := repositoryMocks.NewMockRepository()
mockRepo.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func(
ctx context.Context, projectID string) (models.Project, error) {
return models.Project{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, "project [%s] not found", projectID)
}
err := ValidateProjectAndDomain(context.Background(), mockRepo, testutils.GetApplicationConfigWithDefaultDomains(),
"flyte-project", "domain")
assert.EqualError(t, err, "failed to validate that project [flyte-project] and domain [domain] are registered, err: [project [flyte-project] not found]")
}
10 changes: 6 additions & 4 deletions flyteadmin/pkg/repositories/gormimpl/project_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ package gormimpl
import (
"context"

flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors"
"google.golang.org/grpc/codes"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flytestdlib/promutils"

"github.com/jinzhu/gorm"

flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors"
"github.com/flyteorg/flyteadmin/pkg/repositories/errors"
"github.com/flyteorg/flyteadmin/pkg/repositories/interfaces"
"github.com/flyteorg/flyteadmin/pkg/repositories/models"
Expand Down Expand Up @@ -39,12 +39,14 @@ func (r *ProjectRepo) Get(ctx context.Context, projectID string) (models.Project
Identifier: projectID,
}).Take(&project)
timer.Stop()
if tx.Error != nil {
return models.Project{}, r.errorTransformer.ToFlyteAdminError(tx.Error)
}
if tx.RecordNotFound() {
return models.Project{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, "project [%s] not found", projectID)
}

if tx.Error != nil {
return models.Project{}, r.errorTransformer.ToFlyteAdminError(tx.Error)
}

return project, nil
}

Expand Down
6 changes: 5 additions & 1 deletion flyteadmin/pkg/repositories/gormimpl/project_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,18 @@ func TestGetProject(t *testing.T) {
response["description"] = "project_description"
response["state"] = admin.Project_ACTIVE

output, err := projectRepo.Get(context.Background(), "project_id")
assert.Empty(t, output)
assert.EqualError(t, err, "project [project_id] not found")

query := GlobalMock.NewMock()
query.WithQuery(`SELECT * FROM "projects" WHERE "projects"."deleted_at" IS NULL AND ` +
`(("projects"."identifier" = project_id)) LIMIT 1`).WithReply(
[]map[string]interface{}{
response,
})

output, err := projectRepo.Get(context.Background(), "project_id")
output, err = projectRepo.Get(context.Background(), "project_id")
assert.Nil(t, err)
assert.Equal(t, "project_id", output.Identifier)
assert.Equal(t, "project_name", output.Name)
Expand Down
6 changes: 3 additions & 3 deletions flyteadmin/pkg/repositories/gormimpl/task_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models
},
}).Take(&task)
timer.Stop()
if tx.Error != nil {
return models.Task{}, r.errorTransformer.ToFlyteAdminError(tx.Error)
}
if tx.RecordNotFound() {
return models.Task{}, errors.GetMissingEntityError(core.ResourceType_TASK.String(), &core.Identifier{
Project: input.Project,
Expand All @@ -53,6 +50,9 @@ func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models
Version: input.Version,
})
}
if tx.Error != nil {
return models.Task{}, r.errorTransformer.ToFlyteAdminError(tx.Error)
}
return task, nil
}

Expand Down
11 changes: 10 additions & 1 deletion flyteadmin/pkg/repositories/gormimpl/task_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,23 @@ func TestGetTask(t *testing.T) {
task := getMockTaskResponseFromDb(version, []byte{1, 2})
tasks = append(tasks, task)

output, err := taskRepo.Get(context.Background(), interfaces.Identifier{
Project: project,
Domain: domain,
Name: name,
Version: version,
})
assert.Empty(t, output)
assert.EqualError(t, err, "missing entity of type TASK with identifier project:\"project\" domain:\"domain\" name:\"name\" version:\"XYZ\" ")

GlobalMock := mocket.Catcher.Reset()
GlobalMock.Logging = true
// Only match on queries that append expected filters
GlobalMock.NewMock().WithQuery(
`SELECT * FROM "tasks" WHERE "tasks"."deleted_at" IS NULL AND (("tasks"."project" = project) ` +
`AND ("tasks"."domain" = domain) AND ("tasks"."name" = name) AND ("tasks"."version" = XYZ)) LIMIT 1`).
WithReply(tasks)
output, err := taskRepo.Get(context.Background(), interfaces.Identifier{
output, err = taskRepo.Get(context.Background(), interfaces.Identifier{
Project: project,
Domain: domain,
Name: name,
Expand Down
8 changes: 8 additions & 0 deletions flyteadmin/pkg/repositories/transformers/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ type CreateProjectModelInput struct {

func CreateProjectModel(project *admin.Project) models.Project {
stateInt := int32(project.State)
if project.Labels == nil {
return models.Project{
Identifier: project.Id,
Name: project.Name,
Description: project.Description,
State: &stateInt,
}
}
projectBytes, err := proto.Marshal(project)
if err != nil {
return models.Project{}
Expand Down
23 changes: 14 additions & 9 deletions flyteadmin/pkg/repositories/transformers/project_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,30 @@ import (
)

func TestCreateProjectModel(t *testing.T) {

projectModel := CreateProjectModel(&admin.Project{
labels := admin.Labels{
Values: map[string]string{
"foo": "#badlabel",
"bar": "baz",
},
}
project := admin.Project{
Id: "project_id",
Name: "project_name",
Description: "project_description",
Labels: &labels,
State: admin.Project_ACTIVE,
})
}

projectBytes, _ := proto.Marshal(&project)
projectModel := CreateProjectModel(&project)

activeState := int32(admin.Project_ACTIVE)
assert.Equal(t, models.Project{
Identifier: "project_id",
Name: "project_name",
Description: "project_description",
Labels: []uint8{
0xa, 0xa, 0x70, 0x72, 0x6f, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x69, 0x64, 0x12, 0xc, 0x70, 0x72, 0x6f,
0x6a, 0x65, 0x63, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x13, 0x70, 0x72, 0x6f, 0x6a, 0x65,
0x63, 0x74, 0x5f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e,
},
State: &activeState,
Labels: projectBytes,
State: &activeState,
}, projectModel)
}

Expand Down
62 changes: 45 additions & 17 deletions flyteadmin/tests/project.go → flyteadmin/tests/project_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
)

func TestCreateProject(t *testing.T) {
Expand All @@ -17,22 +18,36 @@ func TestCreateProject(t *testing.T) {
client, conn := GetTestAdminServiceClient()
defer conn.Close()

projectId := "potato"
task, err := client.GetTask(ctx, &admin.ObjectGetRequest{
Id: &core.Identifier{
Project: projectId,
Domain: "development",
Name: "task",
Version: "1234",
ResourceType: core.ResourceType_TASK,
},
})
assert.EqualError(t, err, "rpc error: code = NotFound desc = missing entity of type TASK" +
" with identifier project:\"potato\" domain:\"development\" name:\"task\" version:\"1234\" ")
assert.Empty(t, task)

req := admin.ProjectRegisterRequest{
Project: &admin.Project{
Id: "potato",
Id: projectId,
Name: "spud",
},
}

_, err := client.RegisterProject(ctx, &req)
_, err = client.RegisterProject(ctx, &req)
assert.Nil(t, err)

projects, err := client.ListProjects(ctx, &admin.ProjectListRequest{})
assert.Nil(t, err)
assert.NotEmpty(t, projects.Projects)
var sawNewProject bool
for _, project := range projects.Projects {
if project.Id == "potato" {
if project.Id == projectId {
sawNewProject = true
assert.Equal(t, "spud", project.Name)
}
Expand All @@ -47,14 +62,15 @@ func TestCreateProject(t *testing.T) {

func TestUpdateProjectDescription(t *testing.T) {
truncateAllTablesForTestingOnly()

ctx := context.Background()
client, conn := GetTestAdminServiceClient()
defer conn.Close()

// Create a new project.
projectId := "potato1"
req := admin.ProjectRegisterRequest{
Project: &admin.Project{
Id: "potato",
Id: projectId,
Name: "spud",
Labels: &admin.Labels{
Values: map[string]string{
Expand All @@ -75,7 +91,7 @@ func TestUpdateProjectDescription(t *testing.T) {
// Attempt to modify the name of the Project. Labels should be a no-op.
// Name and Description should modify just fine.
_, err = client.UpdateProject(ctx, &admin.Project{
Id: "potato",
Id: projectId,
Name: "foobar",
Description: "a-new-description",
})
Expand All @@ -88,14 +104,19 @@ func TestUpdateProjectDescription(t *testing.T) {
assert.Nil(t, err)
assert.NotEmpty(t, projectsUpdated.Projects)

// Verify that the project's Name has not been modified but the Description has.
updatedProject := projectsUpdated.Projects[0]
assert.Equal(t, updatedProject.Id, "potato") // unchanged
// Verify that the project's ID has not been modified but the Description has.
var updatedProject *admin.Project
for _, project := range projectsUpdated.Projects {
if project.Id == projectId {
updatedProject = project
}
}
assert.Equal(t, updatedProject.Name, "foobar") // changed
assert.Equal(t, updatedProject.Description, "a-new-description") // changed

// Verify that project labels are not removed.
labelsMap := updatedProject.Labels
assert.NotNil(t, labelsMap)
fooVal, fooExists := labelsMap.Values["foo"]
barVal, barExists := labelsMap.Values["bar"]
assert.Equal(t, fooExists, true)
Expand All @@ -110,9 +131,10 @@ func TestUpdateProjectLabels(t *testing.T) {
defer conn.Close()

// Create a new project.
projectId := "potato2"
req := admin.ProjectRegisterRequest{
Project: &admin.Project{
Id: "potato",
Id: projectId,
Name: "spud",
},
}
Expand All @@ -122,12 +144,13 @@ func TestUpdateProjectLabels(t *testing.T) {
// Verify the project has been registered.
projects, err := client.ListProjects(ctx, &admin.ProjectListRequest{})
assert.Nil(t, err)
assert.NotNil(t, projects)
assert.NotEmpty(t, projects.Projects)

// Attempt to modify the name of the Project. Labels and name should be
// modified.
_, err = client.UpdateProject(ctx, &admin.Project{
Id: "potato",
Id: projectId,
Name: "foobar",
Labels: &admin.Labels{
Values: map[string]string{
Expand All @@ -146,9 +169,13 @@ func TestUpdateProjectLabels(t *testing.T) {
assert.NotEmpty(t, projectsUpdated.Projects)

// Check the name has been modified.
// Verify that the project's Name has not been modified but the Description has.
updatedProject := projectsUpdated.Projects[0]
assert.Equal(t, updatedProject.Id, "potato") // unchanged
// Verify that the project's ID has not been modified but the Description has.
var updatedProject *admin.Project
for _, project := range projectsUpdated.Projects {
if project.Id == projectId {
updatedProject = project
}
}
assert.Equal(t, updatedProject.Name, "foobar") // changed

// Verify that the expected labels have been added to the project.
Expand All @@ -167,9 +194,10 @@ func TestUpdateProjectLabels_BadLabels(t *testing.T) {
defer conn.Close()

// Create a new project.
projectId := "potato4"
req := admin.ProjectRegisterRequest{
Project: &admin.Project{
Id: "potato",
Id: projectId,
Name: "spud",
},
}
Expand All @@ -184,7 +212,7 @@ func TestUpdateProjectLabels_BadLabels(t *testing.T) {
// Attempt to modify the name of the Project. Labels and name should be
// modified.
_, err = client.UpdateProject(ctx, &admin.Project{
Id: "potato",
Id: projectId,
Name: "foobar",
Labels: &admin.Labels{
Values: map[string]string{
Expand All @@ -195,5 +223,5 @@ func TestUpdateProjectLabels_BadLabels(t *testing.T) {
})

// Assert that update went through without an error.
assert.EqualError(t, err, "invalid label value [#bar]: [a DNS-1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]")
assert.EqualError(t, err, "rpc error: code = InvalidArgument desc = invalid label value [#bar]: [a lowercase RFC 1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]")
}

0 comments on commit 5b125e4

Please sign in to comment.