diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index 2fc0b5a3bc..a6a682e27e 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -8,6 +8,8 @@ import ( "runtime/pprof" "time" + "google.golang.org/grpc" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flytestdlib/promutils/labeled" @@ -285,21 +287,21 @@ func newControllerMetrics(scope promutils.Scope) *metrics { } } -func getAdminClient(ctx context.Context) (client service.AdminServiceClient, err error) { +func getAdminClient(ctx context.Context) (client service.AdminServiceClient, opt grpc.DialOption, err error) { cfg := admin.GetConfig(ctx) clients, err := admin.NewClientsetBuilder().WithConfig(cfg).Build(ctx) if err != nil { - return nil, fmt.Errorf("failed to initialize clientset. Error: %w", err) + return nil, nil, fmt.Errorf("failed to initialize clientset. Error: %w", err) } - return clients.AdminClient(), nil + return clients.AdminClient(), clients.AuthOpt(), nil } // NewController returns a new FlyteWorkflow controller func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Interface, flytepropellerClientset clientset.Interface, flyteworkflowInformerFactory informers.SharedInformerFactory, kubeClient executors.Client, scope promutils.Scope) (*Controller, error) { - adminClient, err := getAdminClient(ctx) + adminClient, authOpts, err := getAdminClient(ctx) if err != nil { logger.Errorf(ctx, "failed to initialize Admin client, err :%s", err.Error()) return nil, err @@ -382,7 +384,7 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter } logger.Info(ctx, "Setting up Catalog client.") - catalogClient, err := catalog.NewCatalogClient(ctx) + catalogClient, err := catalog.NewCatalogClient(ctx, authOpts) if err != nil { return nil, errors.Wrapf(err, "Failed to create datacatalog client") } diff --git a/pkg/controller/nodes/task/catalog/config.go b/pkg/controller/nodes/task/catalog/config.go index 958772267f..f4dfb31fff 100644 --- a/pkg/controller/nodes/task/catalog/config.go +++ b/pkg/controller/nodes/task/catalog/config.go @@ -6,6 +6,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" "github.com/flyteorg/flytestdlib/config" + "google.golang.org/grpc" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog/datacatalog" ) @@ -30,10 +31,11 @@ const ( ) type Config struct { - Type DiscoveryType `json:"type" pflag:"\"noop\", Catalog Implementation to use"` - Endpoint string `json:"endpoint" pflag:"\"\", Endpoint for catalog service"` - Insecure bool `json:"insecure" pflag:"false, Use insecure grpc connection"` - MaxCacheAge config.Duration `json:"max-cache-age" pflag:", Cache entries past this age will incur cache miss. 0 means cache never expires"` + Type DiscoveryType `json:"type" pflag:"\"noop\", Catalog Implementation to use"` + Endpoint string `json:"endpoint" pflag:"\"\", Endpoint for catalog service"` + Insecure bool `json:"insecure" pflag:"false, Use insecure grpc connection"` + MaxCacheAge config.Duration `json:"max-cache-age" pflag:", Cache entries past this age will incur cache miss. 0 means cache never expires"` + UseAdminAuth bool `json:"use-admin-auth" pflag:"false, Use the same gRPC credentials option as the flyteadmin client"` } // Gets loaded config for Discovery @@ -41,12 +43,12 @@ func GetConfig() *Config { return configSection.GetConfig().(*Config) } -func NewCatalogClient(ctx context.Context) (catalog.Client, error) { +func NewCatalogClient(ctx context.Context, authOpt grpc.DialOption) (catalog.Client, error) { catalogConfig := GetConfig() switch catalogConfig.Type { case DataCatalogType: - return datacatalog.NewDataCatalog(ctx, catalogConfig.Endpoint, catalogConfig.Insecure, catalogConfig.MaxCacheAge.Duration) + return datacatalog.NewDataCatalog(ctx, catalogConfig.Endpoint, catalogConfig.Insecure, catalogConfig.MaxCacheAge.Duration, catalogConfig.UseAdminAuth, authOpt) case NoOpDiscoveryType, "": return NOOPCatalog{}, nil } diff --git a/pkg/controller/nodes/task/catalog/config_flags.go b/pkg/controller/nodes/task/catalog/config_flags.go index 97e6b9a5d7..004af8fe30 100755 --- a/pkg/controller/nodes/task/catalog/config_flags.go +++ b/pkg/controller/nodes/task/catalog/config_flags.go @@ -54,5 +54,6 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "endpoint"), defaultConfig.Endpoint, " Endpoint for catalog service") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "insecure"), defaultConfig.Insecure, " Use insecure grpc connection") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "max-cache-age"), defaultConfig.MaxCacheAge.String(), " Cache entries past this age will incur cache miss. 0 means cache never expires") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "use-admin-auth"), defaultConfig.UseAdminAuth, " Use the same gRPC credentials option as the flyteadmin client") return cmdFlags } diff --git a/pkg/controller/nodes/task/catalog/config_flags_test.go b/pkg/controller/nodes/task/catalog/config_flags_test.go index e6dc041166..21bf7d253b 100755 --- a/pkg/controller/nodes/task/catalog/config_flags_test.go +++ b/pkg/controller/nodes/task/catalog/config_flags_test.go @@ -155,4 +155,18 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_use-admin-auth", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("use-admin-auth", testValue) + if vBool, err := cmdFlags.GetBool("use-admin-auth"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.UseAdminAuth) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go b/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go index f65dd33c99..7fea673991 100644 --- a/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go +++ b/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go @@ -351,8 +351,11 @@ func (m *CatalogClient) ReleaseReservation(ctx context.Context, key catalog.Key, } // Create a new Datacatalog client for task execution caching -func NewDataCatalog(ctx context.Context, endpoint string, insecureConnection bool, maxCacheAge time.Duration) (*CatalogClient, error) { +func NewDataCatalog(ctx context.Context, endpoint string, insecureConnection bool, maxCacheAge time.Duration, useAdminAuth bool, authOpt grpc.DialOption) (*CatalogClient, error) { var opts []grpc.DialOption + if useAdminAuth && authOpt != nil { + opts = append(opts, authOpt) + } grpcOptions := []grpcRetry.CallOption{ grpcRetry.WithBackoff(grpcRetry.BackoffLinear(100 * time.Millisecond)), diff --git a/pkg/controller/workflow/executor_test.go b/pkg/controller/workflow/executor_test.go index 64dd5afe09..cde9493beb 100644 --- a/pkg/controller/workflow/executor_test.go +++ b/pkg/controller/workflow/executor_test.go @@ -238,7 +238,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} eventSink := eventMocks.NewMockEventSink() - catalogClient, err := catalog.NewCatalogClient(ctx) + catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) recoveryClient := &recoveryMocks.RecoveryClient{} @@ -318,7 +318,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} eventSink := eventMocks.NewMockEventSink() - catalogClient, err := catalog.NewCatalogClient(ctx) + catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) recoveryClient := &recoveryMocks.RecoveryClient{} @@ -382,7 +382,7 @@ func BenchmarkWorkflowExecutor(b *testing.B) { enqueueWorkflow := func(workflowId v1alpha1.WorkflowID) {} eventSink := eventMocks.NewMockEventSink() - catalogClient, err := catalog.NewCatalogClient(ctx) + catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(b, err) recoveryClient := &recoveryMocks.RecoveryClient{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() @@ -492,7 +492,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { } return nil } - catalogClient, err := catalog.NewCatalogClient(ctx) + catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) recoveryClient := &recoveryMocks.RecoveryClient{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() @@ -588,7 +588,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { } return nil } - catalogClient, err := catalog.NewCatalogClient(ctx) + catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) adminClient := launchplan.NewFailFastLaunchPlanExecutor() recoveryClient := &recoveryMocks.RecoveryClient{} @@ -645,7 +645,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { assert.NoError(t, err) nodeEventSink := eventMocks.NewMockEventSink() - catalogClient, err := catalog.NewCatalogClient(ctx) + catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) recoveryClient := &recoveryMocks.RecoveryClient{}