diff --git a/pkg/manager/impl/launch_plan_manager_test.go b/pkg/manager/impl/launch_plan_manager_test.go index 942cf710e0..e72810e343 100644 --- a/pkg/manager/impl/launch_plan_manager_test.go +++ b/pkg/manager/impl/launch_plan_manager_test.go @@ -241,6 +241,20 @@ func TestLaunchPlan_ValidationError(t *testing.T) { assert.Nil(t, response) } +func TestLaunchPlanManager_CreateLaunchPlanErrorDueToBadLabels(t *testing.T) { + repository := getMockRepositoryForLpTest() + lpManager := NewLaunchPlanManager(repository, getMockConfigForLpTest(), mockScheduler, mockScope.NewTestScope()) + request := testutils.GetLaunchPlanRequest() + request.Spec.Labels = &admin.Labels{ + Values: map[string]string{ + "foo": "#badlabel", + "bar": "baz", + }} + response, err := lpManager.CreateLaunchPlan(context.Background(), request) + assert.EqualError(t, err, "invalid label value [#badlabel]: [a lowercase RFC 1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]") + assert.Nil(t, response) +} + func TestLaunchPlan_DatabaseError(t *testing.T) { repository := getMockRepositoryForLpTest() repository.LaunchPlanRepo().(*repositoryMocks.MockLaunchPlanRepo).SetGetCallback( diff --git a/pkg/manager/impl/validation/launch_plan_validator.go b/pkg/manager/impl/validation/launch_plan_validator.go index 254d981066..4d6d62502d 100644 --- a/pkg/manager/impl/validation/launch_plan_validator.go +++ b/pkg/manager/impl/validation/launch_plan_validator.go @@ -30,6 +30,9 @@ func ValidateLaunchPlan(ctx context.Context, if err := ValidateIdentifier(request.Spec.WorkflowId, common.Workflow); err != nil { return err } + if err := validateLabels(request.Spec.Labels); err != nil { + return err + } if err := validateLiteralMap(request.Spec.FixedInputs, shared.FixedInputs); err != nil { return err diff --git a/pkg/manager/impl/validation/launch_plan_validator_test.go b/pkg/manager/impl/validation/launch_plan_validator_test.go index dc091600c1..1f9c38c2db 100644 --- a/pkg/manager/impl/validation/launch_plan_validator_test.go +++ b/pkg/manager/impl/validation/launch_plan_validator_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyteadmin/pkg/manager/impl/testutils" @@ -38,6 +40,17 @@ func TestValidateLpEmptyName(t *testing.T) { assert.EqualError(t, err, "missing name") } +func TestValidateLpLabels(t *testing.T) { + request := testutils.GetLaunchPlanRequest() + request.Spec.Labels = &admin.Labels{ + Values: map[string]string{ + "foo": "#badlabel", + "bar": "baz", + }} + err := ValidateLaunchPlan(context.Background(), request, testutils.GetRepoWithDefaultProject(), lpApplicationConfig, getWorkflowInterface()) + assert.EqualError(t, err, "invalid label value [#badlabel]: [a lowercase RFC 1123 label must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?')]") +} + func TestValidateLpEmptyVersion(t *testing.T) { request := testutils.GetLaunchPlanRequest() request.Id.Version = "" diff --git a/pkg/manager/impl/validation/project_validator.go b/pkg/manager/impl/validation/project_validator.go index 2420a527cd..0c1a2e6f73 100644 --- a/pkg/manager/impl/validation/project_validator.go +++ b/pkg/manager/impl/validation/project_validator.go @@ -16,7 +16,6 @@ import ( const projectID = "project_id" const projectName = "project_name" const projectDescription = "project_description" -const labels = "labels" const maxNameLength = 64 const maxDescriptionLength = 300 const maxLabelArrayLength = 16 @@ -36,7 +35,7 @@ func ValidateProject(project admin.Project) error { if err := ValidateEmptyStringField(project.Id, projectID); err != nil { return err } - if err := validateProjectLabels(project); err != nil { + if err := validateLabels(project.Labels); err != nil { return err } if errs := validation.IsDNS1123Label(project.Id); len(errs) > 0 { @@ -55,19 +54,6 @@ func ValidateProject(project admin.Project) error { return nil } -func validateProjectLabels(project admin.Project) error { - if project.Labels == nil || len(project.Labels.Values) == 0 { - return nil - } - if err := ValidateMaxMapLengthField(project.Labels.Values, labels, maxLabelArrayLength); err != nil { - return err - } - if err := validateProjectLabelsAlphanumeric(project.Labels); err != nil { - return err - } - return nil -} - // Validates that a specified project and domain combination has been registered and exists in the db. func ValidateProjectAndDomain( ctx context.Context, db repositories.RepositoryInterface, config runtimeInterfaces.ApplicationConfiguration, projectID, domainID string) error { @@ -94,17 +80,3 @@ func ValidateProjectAndDomain( } return nil } - -// Given an admin.Project, checks if the project has labels and if it does, checks if the labels are K8s compliant, -// i.e. alphanumeric + - and _ -func validateProjectLabelsAlphanumeric(labels *admin.Labels) error { - for key, value := range labels.Values { - if errs := validation.IsDNS1123Label(key); len(errs) > 0 { - return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid label key [%s]: %v", key, errs) - } - if errs := validation.IsDNS1123Label(value); len(errs) > 0 { - return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid label value [%s]: %v", value, errs) - } - } - return nil -} diff --git a/pkg/manager/impl/validation/validation.go b/pkg/manager/impl/validation/validation.go index 0c491b3575..9fb427eb3c 100644 --- a/pkg/manager/impl/validation/validation.go +++ b/pkg/manager/impl/validation/validation.go @@ -4,6 +4,8 @@ import ( "strconv" "strings" + "k8s.io/apimachinery/pkg/util/validation" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/errors" "github.com/flyteorg/flyteadmin/pkg/manager/impl/shared" @@ -42,6 +44,33 @@ func ValidateMaxMapLengthField(m map[string]string, fieldName string, limit int) return nil } +func validateLabels(labels *admin.Labels) error { + if labels == nil || len(labels.Values) == 0 { + return nil + } + if err := ValidateMaxMapLengthField(labels.Values, "labels", maxLabelArrayLength); err != nil { + return err + } + if err := validateLabelsAlphanumeric(labels); err != nil { + return err + } + return nil +} + +// Given an admin.Labels, checks if the labels exist or not and if it does, checks if the labels are K8s compliant, +// i.e. alphanumeric + - and _ +func validateLabelsAlphanumeric(labels *admin.Labels) error { + for key, value := range labels.Values { + if errs := validation.IsDNS1123Label(key); len(errs) > 0 { + return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid label key [%s]: %v", key, errs) + } + if errs := validation.IsDNS1123Label(value); len(errs) > 0 { + return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid label value [%s]: %v", value, errs) + } + } + return nil +} + func ValidateIdentifierFieldsSet(id *core.Identifier) error { if id == nil { return shared.GetMissingArgumentError(shared.ID)