diff --git a/clierrors/errors.go b/clierrors/errors.go index a8ff701cb3..3ced52fee7 100644 --- a/clierrors/errors.go +++ b/clierrors/errors.go @@ -4,6 +4,7 @@ var ( ErrInvalidStateUpdate = "invalid state passed. Specify either activate or archive\n" ErrProjectNotPassed = "project id wasn't passed\n" // #nosec + ErrProjectIDBothPassed = "both project and id are passed\n" ErrProjectNameNotPassed = "project name is a required flag" ErrFailedProjectUpdate = "Project %v failed to update due to %v\n" diff --git a/cmd/config/subcommand/project/project_config.go b/cmd/config/subcommand/project/project_config.go index a6e8fcebd6..96b9c643dd 100644 --- a/cmd/config/subcommand/project/project_config.go +++ b/cmd/config/subcommand/project/project_config.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "github.com/flyteorg/flytectl/clierrors" + "github.com/flyteorg/flytectl/cmd/config" "github.com/flyteorg/flytectl/pkg/filters" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" @@ -45,8 +46,9 @@ var DefaultProjectConfig = &ConfigProject{ } // GetProjectSpec return project spec from a file/flags -func (c *ConfigProject) GetProjectSpec(id string) (*admin.Project, error) { +func (c *ConfigProject) GetProjectSpec(cf *config.Config) (*admin.Project, error) { projectSpec := admin.Project{} + if len(c.File) > 0 { yamlFile, err := ioutil.ReadFile(c.File) if err != nil { @@ -56,23 +58,35 @@ func (c *ConfigProject) GetProjectSpec(id string) (*admin.Project, error) { if err != nil { return nil, err } - if len(id) > 0 { - projectSpec.Id = id + } else { + projectSpec.Id = c.ID + projectSpec.Name = c.Name + projectSpec.Description = c.Description + projectSpec.Labels = &admin.Labels{ + Values: c.Labels, + } + projectState, err := c.MapToAdminState() + if err != nil { + return nil, err } - return &projectSpec, nil + projectSpec.State = projectState } - projectSpec.Id = id - projectSpec.Name = c.Name - projectSpec.Description = c.Description - projectSpec.Labels = &admin.Labels{ - Values: c.Labels, + project := cf.Project + if len(projectSpec.Id) == 0 && len(project) == 0 { + err := fmt.Errorf(clierrors.ErrProjectNotPassed) + return nil, err } - projectState, err := c.MapToAdminState() - if err != nil { + + if len(projectSpec.Id) > 0 && len(project) > 0 { + err := fmt.Errorf(clierrors.ErrProjectIDBothPassed) return nil, err } - projectSpec.State = projectState + + // Get projectId from file, if not provided, fall back to project + if len(projectSpec.Id) == 0 { + projectSpec.Id = project + } return &projectSpec, nil } diff --git a/cmd/config/subcommand/project/project_config_test.go b/cmd/config/subcommand/project/project_config_test.go index a44cd0b423..b02daa4972 100644 --- a/cmd/config/subcommand/project/project_config_test.go +++ b/cmd/config/subcommand/project/project_config_test.go @@ -5,25 +5,39 @@ import ( "testing" "github.com/flyteorg/flytectl/clierrors" + "github.com/flyteorg/flytectl/cmd/config" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/stretchr/testify/assert" ) func TestGetProjectSpec(t *testing.T) { + cf := &config.Config{ + Project: "flytesnacks1", + } t.Run("Successful get project spec", func(t *testing.T) { c := &ConfigProject{ Name: "flytesnacks", } - response, err := c.GetProjectSpec("flytesnacks") + response, err := c.GetProjectSpec(cf) assert.Nil(t, err) - assert.NotNil(t, response) + assert.Equal(t, "flytesnacks1", response.Id) }) + + t.Run("Error if project and ID both exist", func(t *testing.T) { + c := &ConfigProject{ + ID: "flytesnacks", + Name: "flytesnacks", + } + _, err := c.GetProjectSpec(cf) + assert.NotNil(t, err) + }) + t.Run("Successful get request spec from file", func(t *testing.T) { c := &ConfigProject{ File: "testdata/project.yaml", } - response, err := c.GetProjectSpec("flytesnacks") + response, err := c.GetProjectSpec(&config.Config{}) assert.Nil(t, err) assert.Equal(t, "flytesnacks", response.Name) assert.Equal(t, "flytesnacks test", response.Description) diff --git a/cmd/create/project.go b/cmd/create/project.go index b0bb4eba73..c6db389d05 100644 --- a/cmd/create/project.go +++ b/cmd/create/project.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/flyteorg/flytectl/clierrors" + "github.com/flyteorg/flytectl/cmd/config" "github.com/flyteorg/flytectl/cmd/config/subcommand/project" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" @@ -45,7 +46,7 @@ Create a project by definition file. ) func createProjectsCommand(ctx context.Context, args []string, cmdCtx cmdCore.CommandContext) error { - projectSpec, err := project.DefaultProjectConfig.GetProjectSpec(project.DefaultProjectConfig.ID) + projectSpec, err := project.DefaultProjectConfig.GetProjectSpec(config.GetConfig()) if err != nil { return err } diff --git a/cmd/create/project_test.go b/cmd/create/project_test.go index 2de7ee24e1..e0c166db5c 100644 --- a/cmd/create/project_test.go +++ b/cmd/create/project_test.go @@ -7,6 +7,7 @@ import ( "github.com/flyteorg/flytectl/clierrors" + "github.com/flyteorg/flytectl/cmd/config" "github.com/flyteorg/flytectl/cmd/config/subcommand/project" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" @@ -36,6 +37,7 @@ func createProjectSetup() { project.DefaultProjectConfig.Name = "" project.DefaultProjectConfig.Labels = map[string]string{} project.DefaultProjectConfig.Description = "" + config.GetConfig().Project = "" } func TestCreateProjectFunc(t *testing.T) { s := setup() diff --git a/cmd/update/project.go b/cmd/update/project.go index 7f8b2589e9..bec3d36b96 100644 --- a/cmd/update/project.go +++ b/cmd/update/project.go @@ -4,9 +4,8 @@ import ( "context" "fmt" - "github.com/flyteorg/flytectl/cmd/config" - "github.com/flyteorg/flytectl/clierrors" + "github.com/flyteorg/flytectl/cmd/config" "github.com/flyteorg/flytectl/cmd/config/subcommand/project" cmdCore "github.com/flyteorg/flytectl/cmd/core" "github.com/flyteorg/flytestdlib/logger" @@ -54,7 +53,7 @@ Update projects.(project/projects can be used interchangeably in these commands) Update a project by definition file. Note: The name shouldn't contain any whitespace characters. :: - flytectl update project --file project.yaml + flytectl update project --file project.yaml .. code-block:: yaml @@ -84,10 +83,11 @@ Usage ) func updateProjectsFunc(ctx context.Context, args []string, cmdCtx cmdCore.CommandContext) error { - projectSpec, err := project.DefaultProjectConfig.GetProjectSpec(config.GetConfig().Project) + projectSpec, err := project.DefaultProjectConfig.GetProjectSpec(config.GetConfig()) if err != nil { return err } + if projectSpec.Id == "" { return fmt.Errorf(clierrors.ErrProjectNotPassed) }