From 11abcc0a40e0999b0b783073f226fed0940ff090 Mon Sep 17 00:00:00 2001 From: Dan Rammer Date: Fri, 3 Dec 2021 17:33:03 -0600 Subject: [PATCH 1/7] Scale out with propeller manager and workflow sharding (#351) * added 'manager' command Signed-off-by: Daniel Rammer * using go routine and timer for manager loop Signed-off-by: Daniel Rammer * moved manager loop out of cmd and into pkg directory Signed-off-by: Daniel Rammer * detecting missing replicas Signed-off-by: Daniel Rammer * moved extracting replica from pod name to new function Signed-off-by: Daniel Rammer * creating managed flytepropeller pods Signed-off-by: Daniel Rammer * refactored configuration Signed-off-by: Daniel Rammer * removed regex parsing for replica - checking for existance with fully qualified pod name Signed-off-by: Daniel Rammer * mocked out shard strategy abstraction Signed-off-by: Daniel Rammer * adding arguments to podspec for ConsistentHashingShardStrategy Signed-off-by: Daniel Rammer * updated import naming Signed-off-by: Daniel Rammer * moved manager to a top-level package Signed-off-by: Daniel Rammer * added shard strategy to manager configuration Signed-off-by: Daniel Rammer * setting shard key label selector on managed propeller instances Signed-off-by: Daniel Rammer * fixed random lint issues Signed-off-by: Daniel Rammer * split pod name generate to separate function to ease future auto-scaler implementation Signed-off-by: Daniel Rammer * cleaned up pod label selector Signed-off-by: Daniel Rammer * delete pods on shutdown Signed-off-by: Daniel Rammer * added prometheus metric reporting Signed-off-by: Daniel Rammer * updated manager run loop to use k8s wait.UntilWithContext Signed-off-by: Daniel Rammer * moved getKubeConfig into a shared package Signed-off-by: Daniel Rammer * assigning shard and namespace labels on FlyteWorkflow Signed-off-by: Daniel Rammer * implement NamespaceShardStrategy Signed-off-by: Daniel Rammer * implemented NamespaceShardStrategy Signed-off-by: Daniel Rammer * fixed shard label Signed-off-by: Daniel Rammer * added comments Signed-off-by: Daniel Rammer * checking for existing pods on startup Signed-off-by: Daniel Rammer * handling delete of non-existent pod Signed-off-by: Daniel Rammer * changes ConsistentHashing name to Random - because that's what it really is Signed-off-by: Daniel Rammer * implemented EnableUncoveredReplica configuration option Signed-off-by: Daniel Rammer * added leader election to manager using existing propeller config Signed-off-by: Daniel Rammer * fixed disable leader election in managed propeller pods Signed-off-by: Daniel Rammer * removed listPods function Signed-off-by: Daniel Rammer * added leader election to mitigate concurrent modification issues Signed-off-by: Daniel Rammer * enabled pprof to profile resource metrics Signed-off-by: Daniel Rammer * added 'manager' target to Makefile to start manager in development mode (similar to existing server) Signed-off-by: Daniel Rammer * added shard strategy test for computing key ranges Signed-off-by: Daniel Rammer * fixed key range computation Signed-off-by: Daniel Rammer * implemented project and domain shard types Signed-off-by: Daniel Rammer * returning error on out of range podIndex during UpdatePodSpec call on shard strategy Signed-off-by: Daniel Rammer * fixed random lint issues Signed-off-by: Daniel Rammer * added manager tests Signed-off-by: Daniel Rammer * fixed lint issues Signed-off-by: Daniel Rammer * added doc comments on exported types and functions Signed-off-by: Daniel Rammer * exporting 
ComputeKeyRange function and changed adding addLabelSelector function name to addLabelSelectorIfExists to better reflect functionality Signed-off-by: Daniel Rammer * adding pod template resource version and shard config hash annotations to fuel automatic pod management on updates Signed-off-by: Daniel Rammer * removed pod deletion on manager shutdown Signed-off-by: Daniel Rammer * cleaned up unit tests and lint Signed-off-by: Daniel Rammer * updated getContainer function to retrive flytepropeller container from pod spec using container name instead of command Signed-off-by: Daniel Rammer * removed addLabelSelectorIfExists function call Signed-off-by: Daniel Rammer * changed bytes.Buffer from a var to declaring with new Signed-off-by: Daniel Rammer * created a new shardstrategy package Signed-off-by: Daniel Rammer * generating mocks for ShardStrategy to decouple manager package tests from shardstrategy package tests Signed-off-by: Daniel Rammer * fixed lint issues Signed-off-by: Daniel Rammer * changed shard configuration defintions and added support for wildcard id in EnvironmentShardStrategy Signed-off-by: Daniel Rammer * updated documentation Signed-off-by: Daniel Rammer * fixed lint issues Signed-off-by: Daniel Rammer * setting managed pod owner references Signed-off-by: Daniel Rammer * updated documentation Signed-off-by: Daniel Rammer * fixed a few nits Signed-off-by: Daniel Rammer * delete pods with failed state Signed-off-by: Daniel Rammer * changed ShardType type to int instead of string Signed-off-by: Daniel Rammer * removed default values in manager config Signed-off-by: Daniel Rammer * updated config_flags with pflags generation Signed-off-by: Daniel Rammer Signed-off-by: Haytham Abuelfutuh --- Makefile | 8 + cmd/controller/cmd/init_certs.go | 3 +- cmd/controller/cmd/root.go | 58 ++-- cmd/controller/cmd/webhook.go | 3 +- cmd/manager/cmd/root.go | 202 ++++++++++++ cmd/manager/main.go | 9 + manager/config/config.go | 62 ++++ manager/config/config_flags.go | 61 ++++ manager/config/config_flags_test.go | 200 ++++++++++++ manager/config/doc.go | 4 + manager/config/shardtype_enumer.go | 86 +++++ manager/doc.go | 63 ++++ manager/manager.go | 295 ++++++++++++++++++ manager/manager_test.go | 184 +++++++++++ manager/shardstrategy/doc.go | 4 + manager/shardstrategy/environment.go | 62 ++++ manager/shardstrategy/hash.go | 72 +++++ manager/shardstrategy/hash_test.go | 25 ++ manager/shardstrategy/mocks/shard_strategy.go | 117 +++++++ manager/shardstrategy/shard_strategy.go | 98 ++++++ manager/shardstrategy/shard_strategy_test.go | 161 ++++++++++ pkg/controller/config/config.go | 6 + pkg/controller/config/config_flags.go | 6 + pkg/controller/config/config_flags_test.go | 84 +++++ pkg/controller/controller.go | 29 +- .../leader_election.go} | 4 +- pkg/utils/k8s.go | 67 ++++ 27 files changed, 1914 insertions(+), 59 deletions(-) create mode 100644 cmd/manager/cmd/root.go create mode 100644 cmd/manager/main.go create mode 100644 manager/config/config.go create mode 100755 manager/config/config_flags.go create mode 100755 manager/config/config_flags_test.go create mode 100644 manager/config/doc.go create mode 100644 manager/config/shardtype_enumer.go create mode 100644 manager/doc.go create mode 100644 manager/manager.go create mode 100644 manager/manager_test.go create mode 100644 manager/shardstrategy/doc.go create mode 100644 manager/shardstrategy/environment.go create mode 100644 manager/shardstrategy/hash.go create mode 100644 manager/shardstrategy/hash_test.go create mode 100644 
manager/shardstrategy/mocks/shard_strategy.go create mode 100644 manager/shardstrategy/shard_strategy.go create mode 100644 manager/shardstrategy/shard_strategy_test.go rename pkg/{controller/leaderelection.go => leaderelection/leader_election.go} (94%) diff --git a/Makefile b/Makefile index 84938afdd..b2f078723 100644 --- a/Makefile +++ b/Makefile @@ -12,18 +12,21 @@ update_boilerplate: .PHONY: linux_compile linux_compile: GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/flytepropeller ./cmd/controller/main.go + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/flytepropeller-manager ./cmd/manager/main.go GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/kubectl-flyte ./cmd/kubectl-flyte/main.go .PHONY: compile compile: mkdir -p ./bin go build -o bin/flytepropeller ./cmd/controller/main.go + go build -o bin/flytepropeller-manager ./cmd/manager/main.go go build -o bin/kubectl-flyte ./cmd/kubectl-flyte/main.go && cp bin/kubectl-flyte ${GOPATH}/bin cross_compile: @glide install @mkdir -p ./bin/cross GOOS=linux GOARCH=amd64 go build -o bin/cross/flytepropeller ./cmd/controller/main.go + GOOS=linux GOARCH=amd64 go build -o bin/cross/flytepropeller-manager ./cmd/manager/main.go GOOS=linux GOARCH=amd64 go build -o bin/cross/kubectl-flyte ./cmd/kubectl-flyte/main.go op_code_generate: @@ -38,6 +41,11 @@ benchmark: server: @go run ./cmd/controller/main.go --alsologtostderr --propeller.kube-config=$(HOME)/.kube/config +# manager starts the manager service in development mode +.PHONY: manager +manager: + @go run ./cmd/manager/main.go --alsologtostderr --propeller.kube-config=$(HOME)/.kube/config + clean: rm -rf bin diff --git a/cmd/controller/cmd/init_certs.go b/cmd/controller/cmd/init_certs.go index 101181580..9e2167729 100644 --- a/cmd/controller/cmd/init_certs.go +++ b/cmd/controller/cmd/init_certs.go @@ -11,6 +11,7 @@ import ( kubeErrors "k8s.io/apimachinery/pkg/api/errors" "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/utils" corev1 "k8s.io/api/core/v1" v12 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -84,7 +85,7 @@ func runCertsCmd(ctx context.Context, propellerCfg *config.Config, cfg *webhookC return err } - kubeClient, _, err := getKubeConfig(ctx, propellerCfg) + kubeClient, _, err := utils.GetKubeConfig(ctx, propellerCfg) if err != nil { return err } diff --git a/cmd/controller/cmd/root.go b/cmd/controller/cmd/root.go index 1236d68dc..497d4680e 100644 --- a/cmd/controller/cmd/root.go +++ b/cmd/controller/cmd/root.go @@ -11,6 +11,7 @@ import ( "github.com/flyteorg/flytestdlib/contextutils" + transformers "github.com/flyteorg/flytepropeller/pkg/compiler/transformers/k8s" "github.com/flyteorg/flytepropeller/pkg/controller/executors" "k8s.io/klog" @@ -27,20 +28,15 @@ import ( "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/profutils" "github.com/flyteorg/flytestdlib/promutils" - "github.com/pkg/errors" "github.com/spf13/pflag" "github.com/spf13/cobra" - "k8s.io/client-go/kubernetes" - "k8s.io/client-go/tools/clientcmd" - - restclient "k8s.io/client-go/rest" - clientset "github.com/flyteorg/flytepropeller/pkg/client/clientset/versioned" informers "github.com/flyteorg/flytepropeller/pkg/client/informers/externalversions" "github.com/flyteorg/flytepropeller/pkg/controller" "github.com/flyteorg/flytepropeller/pkg/signals" + "github.com/flyteorg/flytepropeller/pkg/utils" ) const ( @@ -116,39 +112,39 @@ func logAndExit(err error) { os.Exit(-1) } -func getKubeConfig(_ context.Context, cfg 
*config2.Config) (*kubernetes.Clientset, *restclient.Config, error) { - var kubecfg *restclient.Config - var err error - if cfg.KubeConfigPath != "" { - kubeConfigPath := os.ExpandEnv(cfg.KubeConfigPath) - kubecfg, err = clientcmd.BuildConfigFromFlags(cfg.MasterURL, kubeConfigPath) - if err != nil { - return nil, nil, errors.Wrapf(err, "Error building kubeconfig") - } - } else { - kubecfg, err = restclient.InClusterConfig() - if err != nil { - return nil, nil, errors.Wrapf(err, "Cannot get InCluster kubeconfig") - } +func sharedInformerOptions(cfg *config2.Config) []informers.SharedInformerOption { + selectors := []struct { + label string + operation v1.LabelSelectorOperator + values []string + }{ + {transformers.ShardKeyLabel, v1.LabelSelectorOpIn, cfg.IncludeShardKeyLabel}, + {transformers.ShardKeyLabel, v1.LabelSelectorOpNotIn, cfg.ExcludeShardKeyLabel}, + {transformers.ProjectLabel, v1.LabelSelectorOpIn, cfg.IncludeProjectLabel}, + {transformers.ProjectLabel, v1.LabelSelectorOpNotIn, cfg.ExcludeProjectLabel}, + {transformers.DomainLabel, v1.LabelSelectorOpIn, cfg.IncludeDomainLabel}, + {transformers.DomainLabel, v1.LabelSelectorOpNotIn, cfg.ExcludeDomainLabel}, } - kubecfg.QPS = cfg.KubeConfig.QPS - kubecfg.Burst = cfg.KubeConfig.Burst - kubecfg.Timeout = cfg.KubeConfig.Timeout.Duration + labelSelector := controller.IgnoreCompletedWorkflowsLabelSelector() + for _, selector := range selectors { + if len(selector.values) > 0 { + labelSelectorRequirement := v1.LabelSelectorRequirement{ + Key: selector.label, + Operator: selector.operation, + Values: selector.values, + } - kubeClient, err := kubernetes.NewForConfig(kubecfg) - if err != nil { - return nil, nil, errors.Wrapf(err, "Error building kubernetes clientset") + labelSelector.MatchExpressions = append(labelSelector.MatchExpressions, labelSelectorRequirement) + } } - return kubeClient, kubecfg, err -} -func sharedInformerOptions(cfg *config2.Config) []informers.SharedInformerOption { opts := []informers.SharedInformerOption{ informers.WithTweakListOptions(func(options *v1.ListOptions) { - options.LabelSelector = v1.FormatLabelSelector(controller.IgnoreCompletedWorkflowsLabelSelector()) + options.LabelSelector = v1.FormatLabelSelector(labelSelector) }), } + if cfg.LimitNamespace != defaultNamespace { opts = append(opts, informers.WithNamespace(cfg.LimitNamespace)) } @@ -166,7 +162,7 @@ func executeRootCmd(cfg *config2.Config) { // set up signals so we handle the first shutdown signal gracefully ctx := signals.SetupSignalHandler(baseCtx) - kubeClient, kubecfg, err := getKubeConfig(ctx, cfg) + kubeClient, kubecfg, err := utils.GetKubeConfig(ctx, cfg) if err != nil { logger.Fatalf(ctx, "Error building kubernetes clientset: %s", err.Error()) } diff --git a/cmd/controller/cmd/webhook.go b/cmd/controller/cmd/webhook.go index 3af087e7b..40c039e05 100644 --- a/cmd/controller/cmd/webhook.go +++ b/cmd/controller/cmd/webhook.go @@ -17,6 +17,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/signals" + "github.com/flyteorg/flytepropeller/pkg/utils" "github.com/flyteorg/flytepropeller/pkg/webhook" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/profutils" @@ -105,7 +106,7 @@ func runWebhook(origContext context.Context, propellerCfg *config.Config, cfg *w fmt.Println(string(raw)) - kubeClient, kubecfg, err := getKubeConfig(ctx, propellerCfg) + kubeClient, kubecfg, err := utils.GetKubeConfig(ctx, propellerCfg) if err != nil { return err } diff --git 
a/cmd/manager/cmd/root.go b/cmd/manager/cmd/root.go new file mode 100644 index 000000000..fc3da4af6 --- /dev/null +++ b/cmd/manager/cmd/root.go @@ -0,0 +1,202 @@ +// Commands for FlytePropeller manager. +package cmd + +import ( + "context" + "flag" + "os" + "runtime" + + "github.com/flyteorg/flytestdlib/config" + "github.com/flyteorg/flytestdlib/config/viper" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/profutils" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/version" + + "github.com/flyteorg/flytepropeller/manager" + managerConfig "github.com/flyteorg/flytepropeller/manager/config" + propellerConfig "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/signals" + "github.com/flyteorg/flytepropeller/pkg/utils" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/klog" +) + +const ( + appName = "flytepropeller-manager" + podDefaultNamespace = "flyte" + podNameEnvVar = "POD_NAME" + podNamespaceEnvVar = "POD_NAMESPACE" +) + +var ( + cfgFile string + configAccessor = viper.NewAccessor(config.Options{StrictMode: true}) +) + +// rootCmd represents the base command when called without any subcommands +var rootCmd = &cobra.Command{ + Use: appName, + Short: "Runs FlytePropeller Manager to scale out FlytePropeller by executing multiple instances configured according to the defined sharding scheme.", + Long: ` +FlytePropeller Manager is used to effectively scale out FlyteWorkflow processing among a collection of FlytePropeller instances. Users configure a sharding mechanism (ex. 'hash', 'project', or 'domain') to define the sharding environment. + +The FlytePropeller Manager uses a kubernetes PodTemplate to construct the base FlytePropeller PodSpec. This means, apart from the configured sharding scheme, all managed FlytePropeller instances will be identical. + +The Manager ensures liveness and correctness by periodically scanning kubernetes pods and recovering state (i.e. starting missing pods, etc.). Live configuration updates are currently unsupported, meaning configuration changes require an application restart. + +Sample configuration, illustrating 3 separate sharding techniques, is provided below: + + manager: + pod-application: "flytepropeller" + pod-template-container-name: "flytepropeller" + pod-template-name: "flytepropeller-template" + pod-template-namespace: "flyte" + scan-interval: 10s + shard: + # distribute FlyteWorkflow processing over 3 instances evenly + type: hash + shard-count: 3 + + # process the specified projects on dedicated instances and all remaining projects on a wildcard instance + type: project + per-shard-mapping: + - ids: + - flytesnacks + - ids: + - flyteexamples + - flytelab + - ids: + - "*" + + # process the 'production' domain on a single instance and all other domains on another + type: domain + per-shard-mapping: + - ids: + - production + - ids: + - "*" + `, + PersistentPreRunE: initConfig, + Run: func(cmd *cobra.Command, args []string) { + executeRootCmd(propellerConfig.GetConfig(), managerConfig.GetConfig()) + }, +} + +// Execute adds all child commands to the root command and sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd.
+func Execute() { + version.LogBuildInformation(appName) + logger.Infof(context.TODO(), "detected %d CPU's\n", runtime.NumCPU()) + if err := rootCmd.Execute(); err != nil { + logger.Error(context.TODO(), err) + os.Exit(1) + } +} + +func init() { + // allows `$ flytepropeller-manager --logtostderr` to work + klog.InitFlags(flag.CommandLine) + pflag.CommandLine.AddGoFlagSet(flag.CommandLine) + err := flag.CommandLine.Parse([]string{}) + if err != nil { + logAndExit(err) + } + + // Here you will define your flags and configuration settings. Cobra supports persistent flags, which, if defined + // here, will be global for your application. + rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", + "config file (default is $HOME/config.yaml)") + + configAccessor.InitializePflags(rootCmd.PersistentFlags()) + + rootCmd.AddCommand(viper.GetConfigCommand()) +} + +func initConfig(cmd *cobra.Command, _ []string) error { + configAccessor = viper.NewAccessor(config.Options{ + StrictMode: false, + SearchPaths: []string{cfgFile}, + }) + + configAccessor.InitializePflags(cmd.PersistentFlags()) + + err := configAccessor.UpdateConfig(context.TODO()) + if err != nil { + return err + } + + return nil +} + +func logAndExit(err error) { + logger.Error(context.Background(), err) + os.Exit(-1) +} + +func executeRootCmd(propellerCfg *propellerConfig.Config, cfg *managerConfig.Config) { + baseCtx := context.Background() + + // set up signals so we handle the first shutdown signal gracefully + ctx := signals.SetupSignalHandler(baseCtx) + + // lookup owner reference + kubeClient, _, err := utils.GetKubeConfig(ctx, propellerCfg) + if err != nil { + logger.Fatalf(ctx, "error building kubernetes clientset [%v]", err) + } + + ownerReferences := make([]metav1.OwnerReference, 0) + lookupOwnerReferences := true + podName, found := os.LookupEnv(podNameEnvVar) + if !found { + lookupOwnerReferences = false + } + + podNamespace, found := os.LookupEnv(podNamespaceEnvVar) + if !found { + lookupOwnerReferences = false + podNamespace = podDefaultNamespace + } + + if lookupOwnerReferences { + p, err := kubeClient.CoreV1().Pods(podNamespace).Get(ctx, podName, metav1.GetOptions{}) + if err != nil { + logger.Fatalf(ctx, "failed to get pod '%v' in namespace '%v' [%v]", podName, podNamespace, err) + } + + for _, ownerReference := range p.OwnerReferences { + // must set owner reference controller to false because k8s does not allow setting pod + // owner references to a controller that does not acknowledge ownership. in this case + // the owner is technically the FlytePropeller Manager pod and not that pods owner. + *ownerReference.BlockOwnerDeletion = false + *ownerReference.Controller = false + + ownerReferences = append(ownerReferences, ownerReference) + } + } + + // Add the propeller_manager subscope because the MetricsPrefix only has "flyte:" to get uniform collection of metrics. 
+ scope := promutils.NewScope(propellerCfg.MetricsPrefix).NewSubScope("propeller_manager") + + go func() { + err := profutils.StartProfilingServerWithDefaultHandlers(ctx, propellerCfg.ProfilerPort.Port, nil) + if err != nil { + logger.Panicf(ctx, "failed to start profiling and metrics server [%v]", err) + } + }() + + m, err := manager.New(ctx, propellerCfg, cfg, podNamespace, ownerReferences, kubeClient, scope) + if err != nil { + logger.Fatalf(ctx, "failed to start manager [%v]", err) + } else if m == nil { + logger.Fatalf(ctx, "failed to start manager, nil manager received") + } + + if err = m.Run(ctx); err != nil { + logger.Fatalf(ctx, "error running manager [%v]", err) + } +} diff --git a/cmd/manager/main.go b/cmd/manager/main.go new file mode 100644 index 000000000..9ced29741 --- /dev/null +++ b/cmd/manager/main.go @@ -0,0 +1,9 @@ +package main + +import ( + "github.com/flyteorg/flytepropeller/cmd/manager/cmd" +) + +func main() { + cmd.Execute() +} diff --git a/manager/config/config.go b/manager/config/config.go new file mode 100644 index 000000000..d6bc21ac1 --- /dev/null +++ b/manager/config/config.go @@ -0,0 +1,62 @@ +package config + +import ( + "time" + + "github.com/flyteorg/flytestdlib/config" +) + +//go:generate pflags Config --default-var=DefaultConfig +//go:generate enumer --type=ShardType --trimprefix=ShardType -json -yaml + +var ( + DefaultConfig = &Config{ + PodApplication: "flytepropeller", + PodTemplateContainerName: "flytepropeller", + PodTemplateName: "flytepropeller-template", + PodTemplateNamespace: "flyte", + ScanInterval: config.Duration{ + Duration: 10 * time.Second, + }, + ShardConfig: ShardConfig{ + Type: ShardTypeHash, + ShardCount: 3, + }, + } + + configSection = config.MustRegisterSection("manager", DefaultConfig) +) + +type ShardType int + +const ( + ShardTypeDomain ShardType = iota + ShardTypeProject + ShardTypeHash +) + +// Configuration for defining shard replicas when using project or domain shard types +type PerShardMappingsConfig struct { + IDs []string `json:"ids" pflag:",The list of ids to be managed"` +} + +// Configuration for the FlytePropeller sharding strategy +type ShardConfig struct { + Type ShardType `json:"type" pflag:",Shard implementation to use"` + PerShardMappings []PerShardMappingsConfig `json:"per-shard-mapping" pflag:"-"` + ShardCount int `json:"shard-count" pflag:",The number of shards to manage for a 'hash' shard type"` +} + +// Configuration for the FlytePropeller Manager instance +type Config struct { + PodApplication string `json:"pod-application" pflag:",Application name for managed pods"` + PodTemplateContainerName string `json:"pod-template-container-name" pflag:",The container name within the K8s PodTemplate name used to set FlyteWorkflow CRD labels selectors"` + PodTemplateName string `json:"pod-template-name" pflag:",K8s PodTemplate name to use for starting FlytePropeller pods"` + PodTemplateNamespace string `json:"pod-template-namespace" pflag:",Namespace where the k8s PodTemplate is located"` + ScanInterval config.Duration `json:"scan-interval" pflag:",Frequency to scan FlytePropeller pods and start / restart if necessary"` + ShardConfig ShardConfig `json:"shard" pflag:",Configure the shard strategy for this manager"` +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} diff --git a/manager/config/config_flags.go b/manager/config/config_flags.go new file mode 100755 index 000000000..0e143f881 --- /dev/null +++ b/manager/config/config_flags.go @@ -0,0 +1,61 @@ +// Code generated by go generate; 
DO NOT EDIT. +// This file was generated by robots. + +package config + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "pod-application"), DefaultConfig.PodApplication, "Application name for managed pods") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "pod-template-container-name"), DefaultConfig.PodTemplateContainerName, "The container name within the K8s PodTemplate name used to set FlyteWorkflow CRD labels selectors") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "pod-template-name"), DefaultConfig.PodTemplateName, "K8s PodTemplate name to use for starting FlytePropeller pods") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "pod-template-namespace"), DefaultConfig.PodTemplateNamespace, "Namespace where the k8s PodTemplate is located") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "scan-interval"), DefaultConfig.ScanInterval.String(), "Frequency to scan FlytePropeller pods and start / restart if necessary") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "shard.type"), DefaultConfig.ShardConfig.Type.String(), "Shard implementation to use") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "shard.shard-count"), DefaultConfig.ShardConfig.ShardCount, "The number of shards to manage for a 'hash' shard type") + return cmdFlags +} diff --git a/manager/config/config_flags_test.go b/manager/config/config_flags_test.go new file mode 100755 index 000000000..887452276 --- /dev/null +++ b/manager/config/config_flags_test.go @@ -0,0 +1,200 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. 
+func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_pod-application", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("pod-application", testValue) + if vString, err := cmdFlags.GetString("pod-application"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.PodApplication) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_pod-template-container-name", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("pod-template-container-name", testValue) + if vString, err := cmdFlags.GetString("pod-template-container-name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.PodTemplateContainerName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_pod-template-name", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("pod-template-name", testValue) + if vString, err := cmdFlags.GetString("pod-template-name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.PodTemplateName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_pod-template-namespace", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("pod-template-namespace", testValue) + if vString, err := cmdFlags.GetString("pod-template-namespace"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), 
&actual.PodTemplateNamespace) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_scan-interval", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := DefaultConfig.ScanInterval.String() + + cmdFlags.Set("scan-interval", testValue) + if vString, err := cmdFlags.GetString("scan-interval"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.ScanInterval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_shard.type", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("shard.type", testValue) + if vString, err := cmdFlags.GetString("shard.type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.ShardConfig.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_shard.shard-count", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("shard.shard-count", testValue) + if vInt, err := cmdFlags.GetInt("shard.shard-count"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ShardConfig.ShardCount) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/manager/config/doc.go b/manager/config/doc.go new file mode 100644 index 000000000..2930e72a9 --- /dev/null +++ b/manager/config/doc.go @@ -0,0 +1,4 @@ +/* +Package config details configuration data structures for the FlytePropeller Manager implementation. +*/ +package config diff --git a/manager/config/shardtype_enumer.go b/manager/config/shardtype_enumer.go new file mode 100644 index 000000000..78ae145d8 --- /dev/null +++ b/manager/config/shardtype_enumer.go @@ -0,0 +1,86 @@ +// Code generated by "enumer --type=ShardType --trimprefix=ShardType -json -yaml"; DO NOT EDIT. + +// +package config + +import ( + "encoding/json" + "fmt" +) + +const _ShardTypeName = "DomainProjectHash" + +var _ShardTypeIndex = [...]uint8{0, 6, 13, 17} + +func (i ShardType) String() string { + if i < 0 || i >= ShardType(len(_ShardTypeIndex)-1) { + return fmt.Sprintf("ShardType(%d)", i) + } + return _ShardTypeName[_ShardTypeIndex[i]:_ShardTypeIndex[i+1]] +} + +var _ShardTypeValues = []ShardType{0, 1, 2} + +var _ShardTypeNameToValueMap = map[string]ShardType{ + _ShardTypeName[0:6]: 0, + _ShardTypeName[6:13]: 1, + _ShardTypeName[13:17]: 2, +} + +// ShardTypeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func ShardTypeString(s string) (ShardType, error) { + if val, ok := _ShardTypeNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to ShardType values", s) +} + +// ShardTypeValues returns all values of the enum +func ShardTypeValues() []ShardType { + return _ShardTypeValues +} + +// IsAShardType returns "true" if the value is listed in the enum definition. 
"false" otherwise +func (i ShardType) IsAShardType() bool { + for _, v := range _ShardTypeValues { + if i == v { + return true + } + } + return false +} + +// MarshalJSON implements the json.Marshaler interface for ShardType +func (i ShardType) MarshalJSON() ([]byte, error) { + return json.Marshal(i.String()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for ShardType +func (i *ShardType) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("ShardType should be a string, got %s", data) + } + + var err error + *i, err = ShardTypeString(s) + return err +} + +// MarshalYAML implements a YAML Marshaler for ShardType +func (i ShardType) MarshalYAML() (interface{}, error) { + return i.String(), nil +} + +// UnmarshalYAML implements a YAML Unmarshaler for ShardType +func (i *ShardType) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + + var err error + *i, err = ShardTypeString(s) + return err +} diff --git a/manager/doc.go b/manager/doc.go new file mode 100644 index 000000000..2afdcb148 --- /dev/null +++ b/manager/doc.go @@ -0,0 +1,63 @@ +/* +Package manager introduces a FlytePropeller Manager implementation that enables horizontal scaling of FlytePropeller by sharding FlyteWorkflows. + +The FlytePropeller Manager manages a collection of FlytePropeller instances to effectively distribute load. Each managed FlytePropller instance is created as a k8s pod using a configurable k8s PodTemplate resource. The FlytePropeller Manager use a control loop to periodically check the status of managed FlytePropeller instances and creates, updates, or deletes pods as required. It is important to note that if the FlytePropeller Manager fails, managed instances are left running. This is in effort to ensure progress continues in evaluating FlyteWorkflow CRDs. + +FlytePropeller Manager is configured at the root of the FlytePropeller configurtion. Below is an example of the variety of configuration options along with succinct associated descriptions for each field: + + manager: + pod-application: "flytepropeller" # application name for managed pods + pod-template-container-name: "flytepropeller" # the container name within the K8s PodTemplate name used to set FlyteWorkflow CRD labels selectors + pod-template-name: "flytepropeller-template" # k8s PodTemplate name to use for starting FlytePropeller pods + pod-template-namespace: "flyte" # namespace where the k8s PodTemplate is located + scan-interval: 10s # frequency to scan FlytePropeller pods and start / restart if necessary + shard: # configure sharding strategy + # shard configuration redacted + +FlytePropeller Manager handles dynamic updates to both the k8s PodTemplate and shard configuration. The k8s PodTemplate resource has an associated resource version which uniquely identifies changes. Additionally, shard configuration modifications may be tracked using a simple hash. Flyte stores these values as annotations on managed FlytePropeller instances. Therefore, if either of there values change the FlytePropeller Manager instance will detect it and perform the necessary deployment updates. + +Shard Strategies + +Flyte defines a variety of Shard Strategies for configuring how FlyteWorkflows are sharded. These options may include the shard type (ex. hash, project, or domain) along with the number of shards or the distribution of project / domain IDs over shards. 
+ +Internally, FlyteWorkflow CRDs are initialized with k8s labels for project, domain, and a shard-key. The project and domain label values are associated with the environment of the registered workflow. The shard-key value is a range-bounded hash over various components of the FlyteWorkflow metadata; currently the keyspace range is defined as [0,32). A sharded Flyte deployment ensures deterministic FlyteWorkflow evaluations by setting disjoint k8s label selectors, based on the aforementioned labels, on each managed FlytePropeller instance. This ensures that only a single FlytePropeller instance is responsible for processing each FlyteWorkflow. + +The Hash Shard Strategy, denoted by "type: hash" in the configuration below, uses consistent hashing to evenly distribute FlyteWorkflows over managed FlytePropeller instances. This is achieved by partitioning the keyspace (i.e. [0,32)) into a collection of disjoint ranges and using label selectors to assign those ranges to managed FlytePropeller instances. For example, with "shard-count: 4" the first instance is responsible for FlyteWorkflows with "shard-keys" in the range [0,8), the second [8,16), the third [16,24), and the fourth [24,32). It may be useful to note that the default shard type is "hash", so it will be implicitly defined if otherwise left out of the configuration. An example configuration for the Hash Shard Strategy is provided below: + + # a configuration example using the "hash" shard type + manager: + # pod and scanning configuration redacted + shard: + type: hash # use the "hash" shard strategy + shard-count: 4 # the total number of shards + +The Project and Domain Shard Strategies, denoted by "type: project" and "type: domain" respectively, use the FlyteWorkflow project and domain metadata to distribute FlyteWorkflows over managed FlytePropeller instances. These Shard Strategies are configured using a "per-shard-mapping" option, which is a list of ID lists. Each element in the "per-shard-mapping" list defines a new shard and the ID list assigns responsibility for the specified IDs to that shard. The assignment is performed using k8s label selectors, where each managed FlytePropeller instance includes FlyteWorkflows with the specified project or domain labels. + +A shard configured as a single wildcard ID (i.e. "*") is responsible for all IDs that are not covered by other shards. Only a single shard may be configured with a wildcard ID and on that shard there must be only one ID, namely the wildcard. In this case, the managed FlytePropeller instance uses k8s label selectors to exclude FlyteWorkflows with project or domain IDs from other shards.
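To make the wildcard behaviour concrete, here is a minimal sketch using the EnvironmentShardStrategy introduced in this patch; it assumes a PodSpec whose container is named "flytepropeller" (the default pod-template-container-name) and mirrors the project mapping used in the configuration examples that follow. The wildcard shard is configured by excluding every ID owned by the other shards.

    package main

    import (
        "fmt"

        "github.com/flyteorg/flytepropeller/manager/shardstrategy"

        v1 "k8s.io/api/core/v1"
    )

    func main() {
        // three project shards; the last is the wildcard shard covering all remaining projects
        strategy := &shardstrategy.EnvironmentShardStrategy{
            EnvType: shardstrategy.Project,
            PerShardIDs: [][]string{
                {"flytesnacks"},
                {"flyteexamples", "flytelab"},
                {"*"},
            },
        }

        // a bare-bones PodSpec containing the container the strategy amends
        podSpec := v1.PodSpec{
            Containers: []v1.Container{
                {Name: "flytepropeller"},
            },
        }

        // the wildcard shard (index 2) receives exclude selectors for the other shards' projects
        if err := strategy.UpdatePodSpec(&podSpec, "flytepropeller", 2); err != nil {
            panic(err)
        }

        // prints the appended args, e.g. "--propeller.exclude-project-label flytesnacks ..."
        fmt.Println(podSpec.Containers[0].Args)
    }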
+ + # a configuration example using the "project" shard type + manager: + # pod and scanning configuration redacted + shard: + type: project # use the "project" shard strategy + per-shard-mapping: # a list of per shard mappings - one shard is created for each element + - ids: # the list of ids to be managed by the first shard + - flytesnacks + - ids: # the list of ids to be managed by the second shard + - flyteexamples + - flytelabs + - ids: # the list of ids to be managed by the third shard + - "*" # use the wildcard to manage all ids not managed by other shards + + # a configuration example using the "domain" shard type + manager: + # pod and scanning configuration redacted + shard: + type: domain # use the "domain" shard strategy + per-shard-mapping: # a list of per shard mappings - one shard is created for each element + - ids: # the list of ids to be managed by the first shard + - production + - ids: # the list of ids to be managed by the second shard + - "*" # use the wildcard to manage all ids not managed by other shards +*/ +package manager diff --git a/manager/manager.go b/manager/manager.go new file mode 100644 index 000000000..fd3cc6e9b --- /dev/null +++ b/manager/manager.go @@ -0,0 +1,295 @@ +package manager + +import ( + "context" + "fmt" + "time" + + managerConfig "github.com/flyteorg/flytepropeller/manager/config" + "github.com/flyteorg/flytepropeller/manager/shardstrategy" + propellerConfig "github.com/flyteorg/flytepropeller/pkg/controller/config" + leader "github.com/flyteorg/flytepropeller/pkg/leaderelection" + "github.com/flyteorg/flytepropeller/pkg/utils" + + stderrors "github.com/flyteorg/flytestdlib/errors" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/prometheus/client_golang/prometheus" + + v1 "k8s.io/api/core/v1" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/util/wait" + + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/leaderelection" +) + +const ( + podTemplateResourceVersion = "podTemplateResourceVersion" + shardConfigHash = "shardConfigHash" +) + +type metrics struct { + Scope promutils.Scope + RoundTime promutils.StopWatch + PodsCreated prometheus.Counter + PodsDeleted prometheus.Counter + PodsRunning prometheus.Gauge +} + +func newManagerMetrics(scope promutils.Scope) *metrics { + return &metrics{ + Scope: scope, + RoundTime: scope.MustNewStopWatch("round_time", "Time to perform one round of validating managed pod status'", time.Millisecond), + PodsCreated: scope.MustNewCounter("pods_created_count", "Total number of pods created"), + PodsDeleted: scope.MustNewCounter("pods_deleted_count", "Total number of pods deleted"), + PodsRunning: scope.MustNewGauge("pods_running_count", "Number of managed pods currently running"), + } +} + +// Manager periodically scans k8s to ensure liveness of multiple FlytePropeller controller instances +// and rectifies state based on the configured sharding strategy. 
+type Manager struct { + kubeClient kubernetes.Interface + leaderElector *leaderelection.LeaderElector + metrics *metrics + ownerReferences []metav1.OwnerReference + podApplication string + podNamespace string + podTemplateContainerName string + podTemplateName string + podTemplateNamespace string + scanInterval time.Duration + shardStrategy shardstrategy.ShardStrategy +} + +func (m *Manager) createPods(ctx context.Context) error { + t := m.metrics.RoundTime.Start() + defer t.Stop() + + // retrieve pod metadata + podTemplate, err := m.kubeClient.CoreV1().PodTemplates(m.podTemplateNamespace).Get(ctx, m.podTemplateName, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("failed to retrieve pod template '%s' from namespace '%s' [%v]", m.podTemplateName, m.podTemplateNamespace, err) + } + + shardConfigHash, err := m.shardStrategy.HashCode() + if err != nil { + return err + } + + podAnnotations := map[string]string{ + "podTemplateResourceVersion": podTemplate.ObjectMeta.ResourceVersion, + "shardConfigHash": fmt.Sprintf("%d", shardConfigHash), + } + podNames := m.getPodNames() + podLabels := map[string]string{ + "app": m.podApplication, + } + + // disable leader election on all managed pods + container, err := utils.GetContainer(&podTemplate.Template.Spec, m.podTemplateContainerName) + if err != nil { + return fmt.Errorf("failed to retrieve flytepropeller container from pod template [%v]", err) + } + + container.Args = append(container.Args, "--propeller.leader-election.enabled=false") + + // retrieve existing pods + listOptions := metav1.ListOptions{ + LabelSelector: labels.SelectorFromSet(podLabels).String(), + } + + pods, err := m.kubeClient.CoreV1().Pods(m.podNamespace).List(ctx, listOptions) + if err != nil { + return err + } + + // note: we are unable to short-circuit if 'len(pods) == len(m.podNames)' because there may be + // unmanaged flytepropeller pods - which is invalid configuration but will be detected later + + // determine missing managed pods + podExists := make(map[string]bool) + for _, podName := range podNames { + podExists[podName] = false + } + + podsRunning := 0 + for _, pod := range pods.Items { + podName := pod.ObjectMeta.Name + + // validate existing pod annotations + deletePod := false + for key, value := range podAnnotations { + if pod.ObjectMeta.Annotations[key] != value { + logger.Infof(ctx, "detected pod '%s' with stale configuration", podName) + deletePod = true + break + } + } + + if pod.Status.Phase == v1.PodFailed { + logger.Warnf(ctx, "detected pod '%s' in 'failed' state", podName) + deletePod = true + } + + if deletePod { + err := m.kubeClient.CoreV1().Pods(m.podNamespace).Delete(ctx, podName, metav1.DeleteOptions{}) + if err != nil { + return err + } + + m.metrics.PodsDeleted.Inc() + logger.Infof(ctx, "deleted pod '%s'", podName) + continue + } + + // update podExists to track existing pods + if _, ok := podExists[podName]; ok { + podExists[podName] = true + + if pod.Status.Phase == v1.PodRunning { + podsRunning++ + } + } + } + + m.metrics.PodsRunning.Set(float64(podsRunning)) + + // create non-existent pods + errs := stderrors.ErrorCollection{} + for i, podName := range podNames { + if exists := podExists[podName]; !exists { + pod := &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: podAnnotations, + Name: podName, + Namespace: m.podNamespace, + Labels: podLabels, + OwnerReferences: m.ownerReferences, + }, + Spec: *podTemplate.Template.Spec.DeepCopy(), + } + + err := m.shardStrategy.UpdatePodSpec(&pod.Spec, m.podTemplateContainerName, i) + if 
err != nil { + errs.Append(fmt.Errorf("failed to update pod spec for '%s' [%v]", podName, err)) + continue + } + + _, err = m.kubeClient.CoreV1().Pods(m.podNamespace).Create(ctx, pod, metav1.CreateOptions{}) + if err != nil { + errs.Append(fmt.Errorf("failed to create pod '%s' [%v]", podName, err)) + continue + } + + m.metrics.PodsCreated.Inc() + logger.Infof(ctx, "created pod '%s'", podName) + } + } + + return errs.ErrorOrDefault() +} + +func (m *Manager) getPodNames() []string { + podCount := m.shardStrategy.GetPodCount() + var podNames []string + for i := 0; i < podCount; i++ { + podNames = append(podNames, fmt.Sprintf("%s-%d", m.podApplication, i)) + } + + return podNames +} + +// Run starts the manager instance as either a k8s leader, if configured, or as a standalone process. +func (m *Manager) Run(ctx context.Context) error { + if m.leaderElector != nil { + logger.Infof(ctx, "running with leader election") + m.leaderElector.Run(ctx) + } else { + logger.Infof(ctx, "running without leader election") + if err := m.run(ctx); err != nil { + return err + } + } + + return nil +} + +func (m *Manager) run(ctx context.Context) error { + logger.Infof(ctx, "started manager") + wait.UntilWithContext(ctx, + func(ctx context.Context) { + logger.Debugf(ctx, "validating managed pod(s) state") + err := m.createPods(ctx) + if err != nil { + logger.Errorf(ctx, "failed to create pod(s) [%v]", err) + } + }, + m.scanInterval, + ) + + logger.Infof(ctx, "shutting down manager") + return nil +} + +// New creates a new FlytePropeller Manager instance. +func New(ctx context.Context, propellerCfg *propellerConfig.Config, cfg *managerConfig.Config, podNamespace string, ownerReferences []metav1.OwnerReference, kubeClient kubernetes.Interface, scope promutils.Scope) (*Manager, error) { + shardStrategy, err := shardstrategy.NewShardStrategy(ctx, cfg.ShardConfig) + if err != nil { + return nil, fmt.Errorf("failed to initialize shard strategy [%v]", err) + } + + manager := &Manager{ + kubeClient: kubeClient, + metrics: newManagerMetrics(scope), + ownerReferences: ownerReferences, + podApplication: cfg.PodApplication, + podNamespace: podNamespace, + podTemplateContainerName: cfg.PodTemplateContainerName, + podTemplateName: cfg.PodTemplateName, + podTemplateNamespace: cfg.PodTemplateNamespace, + scanInterval: cfg.ScanInterval.Duration, + shardStrategy: shardStrategy, + } + + // configure leader elector + eventRecorder, err := utils.NewK8sEventRecorder(ctx, kubeClient, "flytepropeller-manager", propellerCfg.PublishK8sEvents) + if err != nil { + return nil, fmt.Errorf("failed to initialize k8s event recorder [%v]", err) + } + + lock, err := leader.NewResourceLock(kubeClient.CoreV1(), kubeClient.CoordinationV1(), eventRecorder, propellerCfg.LeaderElection) + if err != nil { + return nil, fmt.Errorf("failed to initialize resource lock [%v]", err) + } + + if lock != nil { + logger.Infof(ctx, "creating leader elector for the controller") + manager.leaderElector, err = leader.NewLeaderElector( + lock, + propellerCfg.LeaderElection, + func(ctx context.Context) { + logger.Infof(ctx, "started leading") + if err := manager.run(ctx); err != nil { + logger.Error(ctx, err) + } + }, + func() { + // need to check if this elector obtained leadership until k8s client-go api is fixed. currently the + // OnStoppingLeader func is called as a defer on every elector run, regardless of election status. 
+ if manager.leaderElector.IsLeader() { + logger.Info(ctx, "stopped leading") + } + }) + + if err != nil { + return nil, fmt.Errorf("failed to initialize leader elector [%v]", err) + } + } + + return manager, nil +} diff --git a/manager/manager_test.go b/manager/manager_test.go new file mode 100644 index 000000000..9eb831d8b --- /dev/null +++ b/manager/manager_test.go @@ -0,0 +1,184 @@ +package manager + +import ( + "context" + "fmt" + "testing" + + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/flyteorg/flytepropeller/manager/shardstrategy" + "github.com/flyteorg/flytepropeller/manager/shardstrategy/mocks" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/kubernetes/fake" +) + +var ( + podTemplate = &v1.PodTemplate{ + ObjectMeta: metav1.ObjectMeta{ + ResourceVersion: "0", + }, + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Command: []string{"flytepropeller"}, + Args: []string{"--config", "/etc/flyte/config/*.yaml"}, + }, + }, + }, + }, + } +) + +func createShardStrategy(podCount int) shardstrategy.ShardStrategy { + shardStrategy := mocks.ShardStrategy{} + shardStrategy.OnGetPodCount().Return(podCount) + shardStrategy.OnHashCode().Return(0, nil) + shardStrategy.OnUpdatePodSpecMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + + return &shardStrategy +} + +func TestCreatePods(t *testing.T) { + t.Parallel() + tests := []struct { + name string + shardStrategy shardstrategy.ShardStrategy + }{ + {"2", createShardStrategy(2)}, + {"3", createShardStrategy(3)}, + {"4", createShardStrategy(4)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.TODO() + scope := promutils.NewScope(fmt.Sprintf("create_%s", tt.name)) + kubeClient := fake.NewSimpleClientset(podTemplate) + + manager := Manager{ + kubeClient: kubeClient, + metrics: newManagerMetrics(scope), + podApplication: "flytepropeller", + shardStrategy: tt.shardStrategy, + } + + // ensure no pods are "running" + kubePodsClient := kubeClient.CoreV1().Pods("") + pods, err := kubePodsClient.List(ctx, metav1.ListOptions{}) + assert.NoError(t, err) + assert.Equal(t, 0, len(pods.Items)) + + // create all pods and validate state + err = manager.createPods(ctx) + assert.NoError(t, err) + + pods, err = kubePodsClient.List(ctx, metav1.ListOptions{}) + assert.NoError(t, err) + assert.Equal(t, tt.shardStrategy.GetPodCount(), len(pods.Items)) + + // execute again to ensure no new pods are created + err = manager.createPods(ctx) + assert.NoError(t, err) + + pods, err = kubePodsClient.List(ctx, metav1.ListOptions{}) + assert.NoError(t, err) + assert.Equal(t, tt.shardStrategy.GetPodCount(), len(pods.Items)) + }) + } +} + +func TestUpdatePods(t *testing.T) { + t.Parallel() + tests := []struct { + name string + shardStrategy shardstrategy.ShardStrategy + }{ + {"2", createShardStrategy(2)}, + {"3", createShardStrategy(3)}, + {"4", createShardStrategy(4)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.TODO() + scope := promutils.NewScope(fmt.Sprintf("update_%s", tt.name)) + + initObjects := []runtime.Object{podTemplate} + for i := 0; i < tt.shardStrategy.GetPodCount(); i++ { + initObjects = append(initObjects, &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + podTemplateResourceVersion: "1", + shardConfigHash: "1", + }, + 
Labels: map[string]string{ + "app": "flytepropeller", + }, + Name: fmt.Sprintf("flytepropeller-%d", i), + }, + }) + } + + kubeClient := fake.NewSimpleClientset(initObjects...) + + manager := Manager{ + kubeClient: kubeClient, + metrics: newManagerMetrics(scope), + podApplication: "flytepropeller", + shardStrategy: tt.shardStrategy, + } + + // ensure all pods are "running" + kubePodsClient := kubeClient.CoreV1().Pods("") + pods, err := kubePodsClient.List(ctx, metav1.ListOptions{}) + assert.NoError(t, err) + assert.Equal(t, tt.shardStrategy.GetPodCount(), len(pods.Items)) + for _, pod := range pods.Items { + assert.Equal(t, "1", pod.ObjectMeta.Annotations[podTemplateResourceVersion]) + } + + // create all pods and validate state + err = manager.createPods(ctx) + assert.NoError(t, err) + + pods, err = kubePodsClient.List(ctx, metav1.ListOptions{}) + assert.NoError(t, err) + assert.Equal(t, tt.shardStrategy.GetPodCount(), len(pods.Items)) + for _, pod := range pods.Items { + assert.Equal(t, podTemplate.ObjectMeta.ResourceVersion, pod.ObjectMeta.Annotations[podTemplateResourceVersion]) + } + }) + } +} + +func TestGetPodNames(t *testing.T) { + t.Parallel() + tests := []struct { + name string + shardStrategy shardstrategy.ShardStrategy + podCount int + }{ + {"2", createShardStrategy(2), 2}, + {"3", createShardStrategy(3), 3}, + {"4", createShardStrategy(4), 4}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := Manager{ + podApplication: "flytepropeller", + shardStrategy: tt.shardStrategy, + } + + assert.Equal(t, tt.podCount, len(manager.getPodNames())) + }) + } +} diff --git a/manager/shardstrategy/doc.go b/manager/shardstrategy/doc.go new file mode 100644 index 000000000..096315e0b --- /dev/null +++ b/manager/shardstrategy/doc.go @@ -0,0 +1,4 @@ +/* +Package shardstrategy defines a variety of sharding strategies to distribute FlyteWorkflows over managed FlytePropeller instances. +*/ +package shardstrategy diff --git a/manager/shardstrategy/environment.go b/manager/shardstrategy/environment.go new file mode 100644 index 000000000..e6b819cbd --- /dev/null +++ b/manager/shardstrategy/environment.go @@ -0,0 +1,62 @@ +package shardstrategy + +import ( + "fmt" + + "github.com/flyteorg/flytepropeller/pkg/utils" + + v1 "k8s.io/api/core/v1" +) + +// EnvironmentShardStrategy assigns either project or domain identifiers to individual +// FlytePropeller instances to determine FlyteWorkflow processing responsibility.
+type EnvironmentShardStrategy struct { + EnvType environmentType + PerShardIDs [][]string +} + +type environmentType int + +const ( + Project environmentType = iota + Domain +) + +func (e environmentType) String() string { + return [...]string{"project", "domain"}[e] +} + +func (e *EnvironmentShardStrategy) GetPodCount() int { + return len(e.PerShardIDs) +} + +func (e *EnvironmentShardStrategy) HashCode() (uint32, error) { + return computeHashCode(e) +} + +func (e *EnvironmentShardStrategy) UpdatePodSpec(pod *v1.PodSpec, containerName string, podIndex int) error { + container, err := utils.GetContainer(pod, containerName) + if err != nil { + return err + } + + if podIndex < 0 || podIndex >= e.GetPodCount() { + return fmt.Errorf("invalid podIndex '%d' out of range [0,%d)", podIndex, e.GetPodCount()) + } + + if len(e.PerShardIDs[podIndex]) == 1 && e.PerShardIDs[podIndex][0] == "*" { + for i, shardIDs := range e.PerShardIDs { + if i != podIndex { + for _, id := range shardIDs { + container.Args = append(container.Args, fmt.Sprintf("--propeller.exclude-%s-label", e.EnvType), id) + } + } + } + } else { + for _, id := range e.PerShardIDs[podIndex] { + container.Args = append(container.Args, fmt.Sprintf("--propeller.include-%s-label", e.EnvType), id) + } + } + + return nil +} diff --git a/manager/shardstrategy/hash.go b/manager/shardstrategy/hash.go new file mode 100644 index 000000000..7de7e69f3 --- /dev/null +++ b/manager/shardstrategy/hash.go @@ -0,0 +1,72 @@ +package shardstrategy + +import ( + "fmt" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/utils" + + v1 "k8s.io/api/core/v1" +) + +// HashShardStrategy evenly assigns disjoint keyspace responsibilities over a collection of pods. +// All FlyteWorkflows are assigned a shard-key using a hash of their executionID and are then +// processed by the FlytePropeller instance responsible for that keyspace range. +type HashShardStrategy struct { + ShardCount int +} + +func (h *HashShardStrategy) GetPodCount() int { + return h.ShardCount +} + +func (h *HashShardStrategy) HashCode() (uint32, error) { + return computeHashCode(h) +} + +func (h *HashShardStrategy) UpdatePodSpec(pod *v1.PodSpec, containerName string, podIndex int) error { + container, err := utils.GetContainer(pod, containerName) + if err != nil { + return err + } + + if podIndex < 0 || podIndex >= h.GetPodCount() { + return fmt.Errorf("invalid podIndex '%d' out of range [0,%d)", podIndex, h.GetPodCount()) + } + + startKey, endKey := ComputeKeyRange(v1alpha1.ShardKeyspaceSize, h.GetPodCount(), podIndex) + for i := startKey; i < endKey; i++ { + container.Args = append(container.Args, "--propeller.include-shard-key-label", fmt.Sprintf("%d", i)) + } + + return nil +} + +// ComputeKeyRange computes a [startKey, endKey) pair denoting the key responsibilities for the +// provided pod index given the keyspaceSize and podCount parameters. 
+func ComputeKeyRange(keyspaceSize, podCount, podIndex int) (int, int) { + keysPerPod := keyspaceSize / podCount + keyRemainder := keyspaceSize - (podCount * keysPerPod) + + return computeStartKey(keysPerPod, keyRemainder, podIndex), computeStartKey(keysPerPod, keyRemainder, podIndex+1) +} + +func computeStartKey(keysPerPod, keysRemainder, podIndex int) int { + return (intMin(podIndex, keysRemainder) * (keysPerPod + 1)) + (intMax(0, podIndex-keysRemainder) * keysPerPod) +} + +func intMin(a, b int) int { + if a < b { + return a + } + + return b +} + +func intMax(a, b int) int { + if a > b { + return a + } + + return b +} diff --git a/manager/shardstrategy/hash_test.go b/manager/shardstrategy/hash_test.go new file mode 100644 index 000000000..6685fbd45 --- /dev/null +++ b/manager/shardstrategy/hash_test.go @@ -0,0 +1,25 @@ +package shardstrategy + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestComputeKeyRange(t *testing.T) { + keyspaceSize := 32 + for podCount := 1; podCount < keyspaceSize; podCount++ { + keysCovered := 0 + minKeyRangeSize := keyspaceSize / podCount + for podIndex := 0; podIndex < podCount; podIndex++ { + startIndex, endIndex := ComputeKeyRange(keyspaceSize, podCount, podIndex) + + rangeSize := endIndex - startIndex + keysCovered += rangeSize + assert.True(t, rangeSize-minKeyRangeSize >= 0) + assert.True(t, rangeSize-minKeyRangeSize <= 1) + } + + assert.Equal(t, keyspaceSize, keysCovered) + } +} diff --git a/manager/shardstrategy/mocks/shard_strategy.go b/manager/shardstrategy/mocks/shard_strategy.go new file mode 100644 index 000000000..5f5925974 --- /dev/null +++ b/manager/shardstrategy/mocks/shard_strategy.go @@ -0,0 +1,117 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + + v1 "k8s.io/api/core/v1" +) + +// ShardStrategy is an autogenerated mock type for the ShardStrategy type +type ShardStrategy struct { + mock.Mock +} + +type ShardStrategy_GetPodCount struct { + *mock.Call +} + +func (_m ShardStrategy_GetPodCount) Return(_a0 int) *ShardStrategy_GetPodCount { + return &ShardStrategy_GetPodCount{Call: _m.Call.Return(_a0)} +} + +func (_m *ShardStrategy) OnGetPodCount() *ShardStrategy_GetPodCount { + c := _m.On("GetPodCount") + return &ShardStrategy_GetPodCount{Call: c} +} + +func (_m *ShardStrategy) OnGetPodCountMatch(matchers ...interface{}) *ShardStrategy_GetPodCount { + c := _m.On("GetPodCount", matchers...) + return &ShardStrategy_GetPodCount{Call: c} +} + +// GetPodCount provides a mock function with given fields: +func (_m *ShardStrategy) GetPodCount() int { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +type ShardStrategy_HashCode struct { + *mock.Call +} + +func (_m ShardStrategy_HashCode) Return(_a0 uint32, _a1 error) *ShardStrategy_HashCode { + return &ShardStrategy_HashCode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ShardStrategy) OnHashCode() *ShardStrategy_HashCode { + c := _m.On("HashCode") + return &ShardStrategy_HashCode{Call: c} +} + +func (_m *ShardStrategy) OnHashCodeMatch(matchers ...interface{}) *ShardStrategy_HashCode { + c := _m.On("HashCode", matchers...) 
+ return &ShardStrategy_HashCode{Call: c} +} + +// HashCode provides a mock function with given fields: +func (_m *ShardStrategy) HashCode() (uint32, error) { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ShardStrategy_UpdatePodSpec struct { + *mock.Call +} + +func (_m ShardStrategy_UpdatePodSpec) Return(_a0 error) *ShardStrategy_UpdatePodSpec { + return &ShardStrategy_UpdatePodSpec{Call: _m.Call.Return(_a0)} +} + +func (_m *ShardStrategy) OnUpdatePodSpec(pod *v1.PodSpec, containerName string, podIndex int) *ShardStrategy_UpdatePodSpec { + c := _m.On("UpdatePodSpec", pod, containerName, podIndex) + return &ShardStrategy_UpdatePodSpec{Call: c} +} + +func (_m *ShardStrategy) OnUpdatePodSpecMatch(matchers ...interface{}) *ShardStrategy_UpdatePodSpec { + c := _m.On("UpdatePodSpec", matchers...) + return &ShardStrategy_UpdatePodSpec{Call: c} +} + +// UpdatePodSpec provides a mock function with given fields: pod, containerName, podIndex +func (_m *ShardStrategy) UpdatePodSpec(pod *v1.PodSpec, containerName string, podIndex int) error { + ret := _m.Called(pod, containerName, podIndex) + + var r0 error + if rf, ok := ret.Get(0).(func(*v1.PodSpec, string, int) error); ok { + r0 = rf(pod, containerName, podIndex) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/manager/shardstrategy/shard_strategy.go b/manager/shardstrategy/shard_strategy.go new file mode 100644 index 000000000..217539295 --- /dev/null +++ b/manager/shardstrategy/shard_strategy.go @@ -0,0 +1,98 @@ +package shardstrategy + +import ( + "bytes" + "context" + "encoding/gob" + "fmt" + "hash/fnv" + + "github.com/flyteorg/flytepropeller/manager/config" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + + v1 "k8s.io/api/core/v1" +) + +//go:generate mockery -name ShardStrategy -case=underscore + +// ShardStrategy defines necessary functionality for a sharding strategy. +type ShardStrategy interface { + // GetPodCount returns the total number of pods for the sharding strategy. + GetPodCount() int + // HashCode generates a unique hash code to identify shard strategy updates. + HashCode() (uint32, error) + // UpdatePodSpec amends the PodSpec for the specified index to include label selectors. + UpdatePodSpec(pod *v1.PodSpec, containerName string, podIndex int) error +} + +// NewShardStrategy creates and validates a new ShardStrategy defined by the configuration. 
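A minimal usage sketch of the factory, assuming only identifiers that appear in this file (NewShardStrategy, config.ShardConfig, config.ShardTypeHash, ShardCount, GetPodCount, HashCode); project and domain strategies are built the same way from PerShardMappings. This is illustrative, not part of the patch.

package main

import (
	"context"
	"fmt"

	"github.com/flyteorg/flytepropeller/manager/config"
	"github.com/flyteorg/flytepropeller/manager/shardstrategy"
)

func main() {
	// Three FlytePropeller pods, each owning roughly a third of the shard keyspace.
	strategy, err := shardstrategy.NewShardStrategy(context.Background(), config.ShardConfig{
		Type:       config.ShardTypeHash,
		ShardCount: 3,
	})
	if err != nil {
		panic(err)
	}

	fmt.Println(strategy.GetPodCount()) // 3

	// HashCode gives a stable fingerprint of the strategy configuration,
	// used to detect shard strategy updates.
	hash, err := strategy.HashCode()
	if err != nil {
		panic(err)
	}
	fmt.Println(hash)
}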
+func NewShardStrategy(ctx context.Context, shardConfig config.ShardConfig) (ShardStrategy, error) { + switch shardConfig.Type { + case config.ShardTypeHash: + if shardConfig.ShardCount <= 0 { + return nil, fmt.Errorf("configured ShardCount (%d) must be greater than zero", shardConfig.ShardCount) + } else if shardConfig.ShardCount > v1alpha1.ShardKeyspaceSize { + return nil, fmt.Errorf("configured ShardCount (%d) is larger than available keyspace size (%d)", shardConfig.ShardCount, v1alpha1.ShardKeyspaceSize) + } + + return &HashShardStrategy{ + ShardCount: shardConfig.ShardCount, + }, nil + case config.ShardTypeProject, config.ShardTypeDomain: + perShardIDs := make([][]string, 0) + wildcardIDFound := false + for _, perShardMapping := range shardConfig.PerShardMappings { + if len(perShardMapping.IDs) == 0 { + return nil, fmt.Errorf("unable to create shard with 0 configured ids") + } + + // validate wildcard ID + for _, id := range perShardMapping.IDs { + if id == "*" { + if len(perShardMapping.IDs) != 1 { + return nil, fmt.Errorf("shards responsible for the wildcard id (ie. '*') may only contain one id") + } + + if wildcardIDFound { + return nil, fmt.Errorf("may only define one shard responsible for the wildcard id (ie. '*')") + } + + wildcardIDFound = true + } + } + + perShardIDs = append(perShardIDs, perShardMapping.IDs) + } + + var envType environmentType + switch shardConfig.Type { + case config.ShardTypeProject: + envType = Project + case config.ShardTypeDomain: + envType = Domain + } + + return &EnvironmentShardStrategy{ + EnvType: envType, + PerShardIDs: perShardIDs, + }, nil + } + + return nil, fmt.Errorf("shard strategy '%s' does not exist", shardConfig.Type) +} + +func computeHashCode(data interface{}) (uint32, error) { + hash := fnv.New32a() + + buffer := new(bytes.Buffer) + encoder := gob.NewEncoder(buffer) + if err := encoder.Encode(data); err != nil { + return 0, err + } + + if _, err := hash.Write(buffer.Bytes()); err != nil { + return 0, err + } + + return hash.Sum32(), nil +} diff --git a/manager/shardstrategy/shard_strategy_test.go b/manager/shardstrategy/shard_strategy_test.go new file mode 100644 index 000000000..5be9cfddc --- /dev/null +++ b/manager/shardstrategy/shard_strategy_test.go @@ -0,0 +1,161 @@ +package shardstrategy + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + v1 "k8s.io/api/core/v1" +) + +var ( + hashShardStrategy = &HashShardStrategy{ + ShardCount: 3, + } + + projectShardStrategy = &EnvironmentShardStrategy{ + EnvType: Project, + PerShardIDs: [][]string{ + []string{"flytesnacks"}, + []string{"flytefoo", "flytebar"}, + }, + } + + projectShardStrategyWildcard = &EnvironmentShardStrategy{ + EnvType: Project, + PerShardIDs: [][]string{ + []string{"flytesnacks"}, + []string{"flytefoo", "flytebar"}, + []string{"*"}, + }, + } + + domainShardStrategy = &EnvironmentShardStrategy{ + EnvType: Domain, + PerShardIDs: [][]string{ + []string{"production"}, + []string{"foo", "bar"}, + }, + } + + domainShardStrategyWildcard = &EnvironmentShardStrategy{ + EnvType: Domain, + PerShardIDs: [][]string{ + []string{"production"}, + []string{"foo", "bar"}, + []string{"*"}, + }, + } +) + +func TestGetPodCount(t *testing.T) { + tests := []struct { + name string + shardStrategy ShardStrategy + podCount int + }{ + {"hash", hashShardStrategy, 3}, + {"project", projectShardStrategy, 2}, + {"project_wildcard", projectShardStrategyWildcard, 3}, + {"domain", domainShardStrategy, 2}, + {"domain_wildcard", domainShardStrategyWildcard, 3}, + } + + for _, tt := range 
tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.podCount, tt.shardStrategy.GetPodCount()) + }) + } +} + +func TestUpdatePodSpec(t *testing.T) { + t.Parallel() + tests := []struct { + name string + shardStrategy ShardStrategy + }{ + {"hash", hashShardStrategy}, + {"project", projectShardStrategy}, + {"project_wildcard", projectShardStrategyWildcard}, + {"domain", domainShardStrategy}, + {"domain_wildcard", domainShardStrategyWildcard}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for podIndex := 0; podIndex < tt.shardStrategy.GetPodCount(); podIndex++ { + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "flytepropeller", + }, + }, + } + + err := tt.shardStrategy.UpdatePodSpec(&podSpec, "flytepropeller", podIndex) + assert.NoError(t, err) + } + }) + } +} + +func TestUpdatePodSpecInvalidPodIndex(t *testing.T) { + t.Parallel() + tests := []struct { + name string + shardStrategy ShardStrategy + }{ + {"hash", hashShardStrategy}, + {"project", projectShardStrategy}, + {"project_wildcard", projectShardStrategyWildcard}, + {"domain", domainShardStrategy}, + {"domain_wildcard", domainShardStrategyWildcard}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "flytepropeller", + }, + }, + } + + lowerErr := tt.shardStrategy.UpdatePodSpec(&podSpec, "flytepropeller", -1) + assert.Error(t, lowerErr) + + upperErr := tt.shardStrategy.UpdatePodSpec(&podSpec, "flytepropeller", tt.shardStrategy.GetPodCount()) + assert.Error(t, upperErr) + }) + } +} + +func TestUpdatePodSpecInvalidPodSpec(t *testing.T) { + t.Parallel() + tests := []struct { + name string + shardStrategy ShardStrategy + }{ + {"hash", hashShardStrategy}, + {"project", projectShardStrategy}, + {"project_wildcard", projectShardStrategyWildcard}, + {"domain", domainShardStrategy}, + {"domain_wildcard", domainShardStrategyWildcard}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "flytefoo", + }, + }, + } + + err := tt.shardStrategy.UpdatePodSpec(&podSpec, "flytepropeller", 0) + assert.Error(t, err) + }) + } +} diff --git a/pkg/controller/config/config.go b/pkg/controller/config/config.go index 156a74786..dce67eecb 100644 --- a/pkg/controller/config/config.go +++ b/pkg/controller/config/config.go @@ -137,6 +137,12 @@ type Config struct { NodeConfig NodeConfig `json:"node-config,omitempty" pflag:",config for a workflow node"` MaxStreakLength int `json:"max-streak-length" pflag:",Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled."` EventConfig EventConfig `json:"event-config,omitempty" pflag:",Configures execution event behavior."` + IncludeShardKeyLabel []string `json:"include-shard-key-label" pflag:",Include the specified shard key label in the k8s FlyteWorkflow CRD label selector"` + ExcludeShardKeyLabel []string `json:"exclude-shard-key-label" pflag:",Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector"` + IncludeProjectLabel []string `json:"include-project-label" pflag:",Include the specified project label in the k8s FlyteWorkflow CRD label selector"` + ExcludeProjectLabel []string `json:"exclude-project-label" pflag:",Exclude the specified project label from the k8s FlyteWorkflow CRD label selector"` + IncludeDomainLabel []string `json:"include-domain-label" 
pflag:",Include the specified domain label in the k8s FlyteWorkflow CRD label selector"` + ExcludeDomainLabel []string `json:"exclude-domain-label" pflag:",Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector"` } // KubeClientConfig contains the configuration used by flytepropeller to configure its internal Kubernetes Client. diff --git a/pkg/controller/config/config_flags.go b/pkg/controller/config/config_flags.go index d0590612b..039161f9a 100755 --- a/pkg/controller/config/config_flags.go +++ b/pkg/controller/config/config_flags.go @@ -95,5 +95,11 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "max-streak-length"), defaultConfig.MaxStreakLength, "Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "event-config.raw-output-policy"), defaultConfig.EventConfig.RawOutputPolicy, "How output data should be passed along in execution events.") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "event-config.fallback-to-output-reference"), defaultConfig.EventConfig.FallbackToOutputReference, "Whether output data should be sent by reference when it is too large to be sent inline in execution events.") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "include-shard-key-label"), []string{}, "Include the specified shard key label in the k8s FlyteWorkflow CRD label selector") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "exclude-shard-key-label"), []string{}, "Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "include-project-label"), []string{}, "Include the specified project label in the k8s FlyteWorkflow CRD label selector") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "exclude-project-label"), []string{}, "Exclude the specified project label from the k8s FlyteWorkflow CRD label selector") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "include-domain-label"), []string{}, "Include the specified domain label in the k8s FlyteWorkflow CRD label selector") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "exclude-domain-label"), []string{}, "Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector") return cmdFlags } diff --git a/pkg/controller/config/config_flags_test.go b/pkg/controller/config/config_flags_test.go index 7b1ca36d9..4b9ed3afe 100755 --- a/pkg/controller/config/config_flags_test.go +++ b/pkg/controller/config/config_flags_test.go @@ -729,4 +729,88 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_include-shard-key-label", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config("1,1", ",") + + cmdFlags.Set("include-shard-key-label", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("include-shard-key-label"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.IncludeShardKeyLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_exclude-shard-key-label", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config("1,1", ",") + + cmdFlags.Set("exclude-shard-key-label", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("exclude-shard-key-label"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.ExcludeShardKeyLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + 
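The new selector options surface as ordinary pflag string-slice flags, so a comma-separated value expands into multiple label values. A minimal sketch of that behaviour follows; the flag name is taken from the generated GetPFlagSet above, while the surrounding program is hypothetical.

package main

import (
	"fmt"

	"github.com/spf13/pflag"
)

func main() {
	flags := pflag.NewFlagSet("propeller", pflag.ContinueOnError)
	flags.StringSlice("propeller.include-project-label", []string{},
		"Include the specified project label in the k8s FlyteWorkflow CRD label selector")

	// Comma-separated values are split into a slice, mirroring the config_flags_test cases.
	if err := flags.Parse([]string{"--propeller.include-project-label=flytesnacks,flytefoo"}); err != nil {
		panic(err)
	}

	include, err := flags.GetStringSlice("propeller.include-project-label")
	if err != nil {
		panic(err)
	}
	fmt.Println(include) // [flytesnacks flytefoo]
}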
t.Run("Test_include-project-label", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config("1,1", ",") + + cmdFlags.Set("include-project-label", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("include-project-label"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.IncludeProjectLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_exclude-project-label", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config("1,1", ",") + + cmdFlags.Set("exclude-project-label", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("exclude-project-label"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.ExcludeProjectLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_include-domain-label", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config("1,1", ",") + + cmdFlags.Set("include-domain-label", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("include-domain-label"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.IncludeDomainLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_exclude-domain-label", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config("1,1", ",") + + cmdFlags.Set("exclude-domain-label", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("exclude-domain-label"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.ExcludeDomainLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index 771e2dbf1..ef99e0c78 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -31,24 +31,22 @@ import ( "github.com/flyteorg/flytestdlib/storage" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" - corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/clock" "k8s.io/client-go/kubernetes" - "k8s.io/client-go/kubernetes/scheme" - typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" "k8s.io/client-go/tools/cache" "k8s.io/client-go/tools/leaderelection" "k8s.io/client-go/tools/record" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" clientset "github.com/flyteorg/flytepropeller/pkg/client/clientset/versioned" - flyteScheme "github.com/flyteorg/flytepropeller/pkg/client/clientset/versioned/scheme" informers "github.com/flyteorg/flytepropeller/pkg/client/informers/externalversions" lister "github.com/flyteorg/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/nodes" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/workflow" + leader "github.com/flyteorg/flytepropeller/pkg/leaderelection" + "github.com/flyteorg/flytepropeller/pkg/utils" ) const resourceLevelMonitorCycleDuration = 5 * time.Second @@ -287,23 +285,6 @@ func newControllerMetrics(scope promutils.Scope) *metrics { } } -func newK8sEventRecorder(ctx context.Context, kubeclientset kubernetes.Interface, publishK8sEvents bool) (record.EventRecorder, error) { - // Create event broadcaster - // Add FlyteWorkflow controller types to the default Kubernetes Scheme so Events can be - // logged for FlyteWorkflow 
Controller types. - err := flyteScheme.AddToScheme(scheme.Scheme) - if err != nil { - return nil, err - } - logger.Info(ctx, "Creating event broadcaster") - eventBroadcaster := record.NewBroadcaster() - eventBroadcaster.StartLogging(logger.InfofNoCtx) - if publishK8sEvents { - eventBroadcaster.StartRecordingToSink(&typedcorev1.EventSinkImpl{Interface: kubeclientset.CoreV1().Events("")}) - } - return eventBroadcaster.NewRecorder(scheme.Scheme, corev1.EventSource{Component: controllerAgentName}), nil -} - func getAdminClient(ctx context.Context) (client service.AdminServiceClient, err error) { cfg := admin.GetConfig(ctx) clients, err := admin.NewClientsetBuilder().WithConfig(cfg).Build(ctx) @@ -351,7 +332,7 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter return nil, errors.Wrapf(err, "failed to initialize WF GC") } - eventRecorder, err := newK8sEventRecorder(ctx, kubeclientset, cfg.PublishK8sEvents) + eventRecorder, err := utils.NewK8sEventRecorder(ctx, kubeclientset, controllerAgentName, cfg.PublishK8sEvents) if err != nil { logger.Errorf(ctx, "failed to event recorder %v", err) return nil, errors.Wrapf(err, "failed to initialize resource lock.") @@ -363,7 +344,7 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter numWorkers: cfg.Workers, } - lock, err := newResourceLock(kubeclientset.CoreV1(), kubeclientset.CoordinationV1(), eventRecorder, cfg.LeaderElection) + lock, err := leader.NewResourceLock(kubeclientset.CoreV1(), kubeclientset.CoordinationV1(), eventRecorder, cfg.LeaderElection) if err != nil { logger.Errorf(ctx, "failed to initialize resource lock.") return nil, errors.Wrapf(err, "failed to initialize resource lock.") @@ -371,7 +352,7 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter if lock != nil { logger.Infof(ctx, "Creating leader elector for the controller.") - controller.leaderElector, err = newLeaderElector(lock, cfg.LeaderElection, controller.onStartedLeading, func() { + controller.leaderElector, err = leader.NewLeaderElector(lock, cfg.LeaderElection, controller.onStartedLeading, func() { logger.Fatal(ctx, "Lost leader state. 
Shutting down.") }) diff --git a/pkg/controller/leaderelection.go b/pkg/leaderelection/leader_election.go similarity index 94% rename from pkg/controller/leaderelection.go rename to pkg/leaderelection/leader_election.go index 251409e26..acbe5dc80 100644 --- a/pkg/controller/leaderelection.go +++ b/pkg/leaderelection/leader_election.go @@ -28,7 +28,7 @@ const ( ) // NewResourceLock creates a new config map resource lock for use in a leader election loop -func newResourceLock(corev1 v1.CoreV1Interface, coordinationV1 v12.CoordinationV1Interface, eventRecorder record.EventRecorder, options config.LeaderElectionConfig) ( +func NewResourceLock(corev1 v1.CoreV1Interface, coordinationV1 v12.CoordinationV1Interface, eventRecorder record.EventRecorder, options config.LeaderElectionConfig) ( resourcelock.Interface, error) { if !options.Enabled { @@ -66,7 +66,7 @@ func getUniqueLeaderID() string { return fmt.Sprintf("%v_%v", id, rand.String(10)) } -func newLeaderElector(lock resourcelock.Interface, cfg config.LeaderElectionConfig, +func NewLeaderElector(lock resourcelock.Interface, cfg config.LeaderElectionConfig, leaderFn func(ctx context.Context), leaderStoppedFn func()) (*leaderelection.LeaderElector, error) { return leaderelection.NewLeaderElector(leaderelection.LeaderElectionConfig{ Lock: lock, diff --git a/pkg/utils/k8s.go b/pkg/utils/k8s.go index 3d1da706f..b1ce78c2f 100644 --- a/pkg/utils/k8s.go +++ b/pkg/utils/k8s.go @@ -1,11 +1,17 @@ package utils import ( + "context" + "fmt" + "os" "regexp" "strings" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + flyteScheme "github.com/flyteorg/flytepropeller/pkg/client/clientset/versioned/scheme" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytestdlib/logger" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/timestamp" "github.com/pkg/errors" @@ -13,6 +19,12 @@ import ( "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/scheme" + typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" + restclient "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + "k8s.io/client-go/tools/record" ) var NotTheOwnerError = errors.Errorf("FlytePropeller is not the owner") @@ -84,6 +96,44 @@ func ToK8sResourceRequirements(resources *core.Resources) (*v1.ResourceRequireme return res, nil } +// GetContainer searches the provided pod spec for a container with the specified name +func GetContainer(pod *v1.PodSpec, containerName string) (*v1.Container, error) { + for i := 0; i < len(pod.Containers); i++ { + if pod.Containers[i].Name == containerName { + return &pod.Containers[i], nil + } + } + + return nil, fmt.Errorf("container '%s' not found in podtemplate, ", containerName) +} + +func GetKubeConfig(_ context.Context, cfg *config.Config) (*kubernetes.Clientset, *restclient.Config, error) { + var kubecfg *restclient.Config + var err error + if cfg.KubeConfigPath != "" { + kubeConfigPath := os.ExpandEnv(cfg.KubeConfigPath) + kubecfg, err = clientcmd.BuildConfigFromFlags(cfg.MasterURL, kubeConfigPath) + if err != nil { + return nil, nil, errors.Wrapf(err, "Error building kubeconfig") + } + } else { + kubecfg, err = restclient.InClusterConfig() + if err != nil { + return nil, nil, errors.Wrapf(err, "Cannot get InCluster kubeconfig") + } + } + + kubecfg.QPS = cfg.KubeConfig.QPS + kubecfg.Burst = 
cfg.KubeConfig.Burst + kubecfg.Timeout = cfg.KubeConfig.Timeout.Duration + + kubeClient, err := kubernetes.NewForConfig(kubecfg) + if err != nil { + return nil, nil, errors.Wrapf(err, "Error building kubernetes clientset") + } + return kubeClient, kubecfg, err +} + func GetWorkflowIDFromOwner(reference *metav1.OwnerReference, namespace string) (v1alpha1.WorkflowID, error) { if reference == nil { return "", NotTheOwnerError @@ -113,3 +163,20 @@ func SanitizeLabelValue(name string) string { } return strings.Trim(name, "-") } + +func NewK8sEventRecorder(ctx context.Context, kubeclientset kubernetes.Interface, controllerAgentName string, publishK8sEvents bool) (record.EventRecorder, error) { + // Create event broadcaster + // Add FlyteWorkflow controller types to the default Kubernetes Scheme so Events can be + // logged for FlyteWorkflow Controller types. + err := flyteScheme.AddToScheme(scheme.Scheme) + if err != nil { + return nil, err + } + logger.Info(ctx, "Creating event broadcaster") + eventBroadcaster := record.NewBroadcaster() + eventBroadcaster.StartLogging(logger.InfofNoCtx) + if publishK8sEvents { + eventBroadcaster.StartRecordingToSink(&typedcorev1.EventSinkImpl{Interface: kubeclientset.CoreV1().Events("")}) + } + return eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: controllerAgentName}), nil +} From c8e39a00d8eff98e9936cc3d68f83abd3a2cd394 Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Mon, 6 Dec 2021 11:01:31 -0800 Subject: [PATCH 2/7] Create codeql-analysis.yml Signed-off-by: Haytham Abuelfutuh --- .github/workflows/codeql-analysis.yml | 70 +++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 .github/workflows/codeql-analysis.yml diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 000000000..fd6b11af7 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,70 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ master ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ master ] + schedule: + - cron: '23 3 * * 6' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'go' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] + # Learn more about CodeQL language support at https://git.io/codeql-language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. 
+ # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v1 + + # ℹī¸ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # ✏ī¸ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 From d08b1d608d49970a0a1188abfa60e879dfb3cfe7 Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Tue, 4 Jan 2022 15:43:08 -0800 Subject: [PATCH 3/7] Handle code quality issue Signed-off-by: Haytham Abuelfutuh --- .../nodes/task/catalog/datacatalog/datacatalog.go | 9 ++++++++- .../nodes/task/catalog/datacatalog/transformer.go | 14 ++++++++++---- .../task/catalog/datacatalog/transformer_test.go | 3 ++- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go b/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go index d91936553..f65dd33c9 100644 --- a/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go +++ b/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go @@ -3,6 +3,7 @@ package datacatalog import ( "context" "crypto/x509" + "fmt" "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -124,7 +125,13 @@ func (m *CatalogClient) Get(ctx context.Context, key catalog.Key) (catalog.Entry // TODO should we look through all the tags to find the relevant one? relevantTag = artifact.GetTags()[0] } - md := EventCatalogMetadata(dataset.GetId(), relevantTag, GetSourceFromMetadata(dataset.GetMetadata(), artifact.GetMetadata(), key.Identifier)) + + source, err := GetSourceFromMetadata(dataset.GetMetadata(), artifact.GetMetadata(), key.Identifier) + if err != nil { + return catalog.Entry{}, fmt.Errorf("failed to get source from metadata. Error: %w", err) + } + + md := EventCatalogMetadata(dataset.GetId(), relevantTag, source) outputs, err := GenerateTaskOutputsFromArtifact(key.Identifier, key.TypedInterface, artifact) if err != nil { diff --git a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go b/pkg/controller/nodes/task/catalog/datacatalog/transformer.go index 655ffe4b0..d3abc24ee 100644 --- a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go +++ b/pkg/controller/nodes/task/catalog/datacatalog/transformer.go @@ -199,18 +199,24 @@ func GetArtifactMetadataForSource(taskExecutionID *core.TaskExecutionIdentifier) } } -// Returns the Source TaskExecutionIdentifier from the catalog metadata +// GetSourceFromMetadata returns the Source TaskExecutionIdentifier from the catalog metadata // For all the information not available it returns Unknown. This is because as of July-2020 Catalog does not have all // the information. 
After the first deployment of this code, it will have this and the "unknown's" can be phased out -func GetSourceFromMetadata(datasetMd, artifactMd *datacatalog.Metadata, currentID core.Identifier) *core.TaskExecutionIdentifier { +func GetSourceFromMetadata(datasetMd, artifactMd *datacatalog.Metadata, currentID core.Identifier) (*core.TaskExecutionIdentifier, error) { if datasetMd == nil || datasetMd.KeyMap == nil { datasetMd = &datacatalog.Metadata{KeyMap: map[string]string{}} } if artifactMd == nil || artifactMd.KeyMap == nil { artifactMd = &datacatalog.Metadata{KeyMap: map[string]string{}} } + // Jul-06-2020 DataCatalog stores only wfExecutionKey & taskVersionKey So we will default the project / domain to the current dataset's project domain - attempt, _ := strconv.Atoi(GetOrDefault(artifactMd.KeyMap, execTaskAttemptKey, "0")) + val := GetOrDefault(artifactMd.KeyMap, execTaskAttemptKey, "0") + attempt, err := strconv.Atoi(val) + if err != nil { + return nil, fmt.Errorf("failed to parse [%v] to integer. Error: %w", val, err) + } + return &core.TaskExecutionIdentifier{ TaskId: &core.Identifier{ ResourceType: currentID.ResourceType, @@ -228,7 +234,7 @@ func GetSourceFromMetadata(datasetMd, artifactMd *datacatalog.Metadata, currentI Name: GetOrDefault(artifactMd.KeyMap, execNameKey, "unknown"), }, }, - } + }, nil } // Given the Catalog Information (returned from a Catalog call), returns the CatalogMetadata that is populated in the event. diff --git a/pkg/controller/nodes/task/catalog/datacatalog/transformer_test.go b/pkg/controller/nodes/task/catalog/datacatalog/transformer_test.go index d4575874d..4d2485f27 100644 --- a/pkg/controller/nodes/task/catalog/datacatalog/transformer_test.go +++ b/pkg/controller/nodes/task/catalog/datacatalog/transformer_test.go @@ -279,8 +279,9 @@ func TestGetSourceFromMetadata(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := GetSourceFromMetadata(&datacatalog.Metadata{KeyMap: tt.args.datasetMd}, &datacatalog.Metadata{KeyMap: tt.args.artifactMd}, tt.args.currentID); !reflect.DeepEqual(got, tt.want) { + if got, err := GetSourceFromMetadata(&datacatalog.Metadata{KeyMap: tt.args.datasetMd}, &datacatalog.Metadata{KeyMap: tt.args.artifactMd}, tt.args.currentID); !reflect.DeepEqual(got, tt.want) { t.Errorf("GetSourceFromMetadata() = %v, want %v", got, tt.want) + assert.NoError(t, err) } }) } From eecdb481fe7b9911bdf259823d13b83f3900d0f9 Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Tue, 4 Jan 2022 16:34:26 -0800 Subject: [PATCH 4/7] check boundaries Signed-off-by: Haytham Abuelfutuh --- .../nodes/task/catalog/datacatalog/transformer.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go b/pkg/controller/nodes/task/catalog/datacatalog/transformer.go index d3abc24ee..653ecee50 100644 --- a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go +++ b/pkg/controller/nodes/task/catalog/datacatalog/transformer.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "fmt" + "math" "reflect" "strconv" "strings" @@ -217,6 +218,14 @@ func GetSourceFromMetadata(datasetMd, artifactMd *datacatalog.Metadata, currentI return nil, fmt.Errorf("failed to parse [%v] to integer. 
Error: %w", val, err) } + attempt32 := uint32(0) + // GOOD: check for lower and upper bounds + if attempt > 0 && attempt <= math.MaxUint32 { + attempt32 = uint32(attempt) + } else { + return nil, fmt.Errorf("invalid attempts value [%v]", attempt) + } + return &core.TaskExecutionIdentifier{ TaskId: &core.Identifier{ ResourceType: currentID.ResourceType, @@ -225,7 +234,7 @@ func GetSourceFromMetadata(datasetMd, artifactMd *datacatalog.Metadata, currentI Name: currentID.Name, Version: GetOrDefault(datasetMd.KeyMap, taskVersionKey, "unknown"), }, - RetryAttempt: uint32(attempt), + RetryAttempt: attempt32, NodeExecutionId: &core.NodeExecutionIdentifier{ NodeId: GetOrDefault(artifactMd.KeyMap, execNodeIDKey, "unknown"), ExecutionId: &core.WorkflowExecutionIdentifier{ From 6b7237a36bf7fd9464f370708a0406f3162d9f41 Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Tue, 4 Jan 2022 16:54:43 -0800 Subject: [PATCH 5/7] 0 is ok Signed-off-by: Haytham Abuelfutuh --- pkg/controller/nodes/task/catalog/datacatalog/transformer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go b/pkg/controller/nodes/task/catalog/datacatalog/transformer.go index 653ecee50..01e5c5b22 100644 --- a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go +++ b/pkg/controller/nodes/task/catalog/datacatalog/transformer.go @@ -220,7 +220,7 @@ func GetSourceFromMetadata(datasetMd, artifactMd *datacatalog.Metadata, currentI attempt32 := uint32(0) // GOOD: check for lower and upper bounds - if attempt > 0 && attempt <= math.MaxUint32 { + if attempt >= 0 && attempt <= math.MaxUint32 { attempt32 = uint32(attempt) } else { return nil, fmt.Errorf("invalid attempts value [%v]", attempt) From a2046b3057e0ad369b337f15d98c3a9513234298 Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Wed, 5 Jan 2022 09:18:12 -0800 Subject: [PATCH 6/7] Use ParseUint instead Signed-off-by: Haytham Abuelfutuh --- .../nodes/task/catalog/datacatalog/transformer.go | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go b/pkg/controller/nodes/task/catalog/datacatalog/transformer.go index 01e5c5b22..e8dbcb476 100644 --- a/pkg/controller/nodes/task/catalog/datacatalog/transformer.go +++ b/pkg/controller/nodes/task/catalog/datacatalog/transformer.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "fmt" - "math" "reflect" "strconv" "strings" @@ -213,19 +212,11 @@ func GetSourceFromMetadata(datasetMd, artifactMd *datacatalog.Metadata, currentI // Jul-06-2020 DataCatalog stores only wfExecutionKey & taskVersionKey So we will default the project / domain to the current dataset's project domain val := GetOrDefault(artifactMd.KeyMap, execTaskAttemptKey, "0") - attempt, err := strconv.Atoi(val) + attempt, err := strconv.ParseUint(val, 10, 32) if err != nil { return nil, fmt.Errorf("failed to parse [%v] to integer. 
Error: %w", val, err) } - attempt32 := uint32(0) - // GOOD: check for lower and upper bounds - if attempt >= 0 && attempt <= math.MaxUint32 { - attempt32 = uint32(attempt) - } else { - return nil, fmt.Errorf("invalid attempts value [%v]", attempt) - } - return &core.TaskExecutionIdentifier{ TaskId: &core.Identifier{ ResourceType: currentID.ResourceType, @@ -234,7 +225,7 @@ func GetSourceFromMetadata(datasetMd, artifactMd *datacatalog.Metadata, currentI Name: currentID.Name, Version: GetOrDefault(datasetMd.KeyMap, taskVersionKey, "unknown"), }, - RetryAttempt: attempt32, + RetryAttempt: uint32(attempt), NodeExecutionId: &core.NodeExecutionIdentifier{ NodeId: GetOrDefault(artifactMd.KeyMap, execNodeIDKey, "unknown"), ExecutionId: &core.WorkflowExecutionIdentifier{ From 4294805f63434301873609912608617651f2827e Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Wed, 5 Jan 2022 09:22:36 -0800 Subject: [PATCH 7/7] bump for DCO Signed-off-by: Haytham Abuelfutuh
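As a closing reference, a standalone sketch (not part of the patch series) of why strconv.ParseUint with bitSize 32 lets the explicit bounds check from the previous revision be dropped: it already rejects negative values and anything above math.MaxUint32, so the subsequent uint32 conversion cannot truncate.

package main

import (
	"fmt"
	"strconv"
)

func main() {
	for _, val := range []string{"0", "3", "-1", "4294967296"} {
		attempt, err := strconv.ParseUint(val, 10, 32)
		if err != nil {
			// "-1" fails with invalid syntax, "4294967296" with value out of range.
			fmt.Printf("%q rejected: %v\n", val, err)
			continue
		}
		fmt.Printf("%q -> RetryAttempt %d\n", val, uint32(attempt))
	}
}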