From 979fabe1d1b22b01645259a03b8096f227681d08 Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Mon, 22 Mar 2021 16:53:22 -0700 Subject: [PATCH] Pod Mutating Webhook & Secret Annotation Injector (#242) * Pod Mutating Webhook & Secret Annotation Injector Signed-off-by: Haytham Abuelfutuh * Marshal the entire secret object instead Signed-off-by: Haytham Abuelfutuh * docs Signed-off-by: Haytham Abuelfutuh * cmd docs Signed-off-by: Haytham Abuelfutuh * refactor Signed-off-by: Haytham Abuelfutuh * Unit tests Signed-off-by: Haytham Abuelfutuh * Update pkg/utils/secrets/marshaler.go Co-authored-by: Ketan Umare <16888709+kumare3@users.noreply.github.com> Signed-off-by: Haytham Abuelfutuh * introduce webhook in README Signed-off-by: Haytham Abuelfutuh Co-authored-by: Ketan Umare <16888709+kumare3@users.noreply.github.com> --- README.md | 71 +++++- cmd/controller/cmd/init_certs.go | 222 +++++++++++++++++ cmd/controller/cmd/root.go | 9 +- cmd/controller/cmd/webhook.go | 206 ++++++++++++++++ go.mod | 3 +- go.sum | 6 +- pkg/compiler/workflow_compiler.go | 2 + pkg/controller/executors/mocks/fake.go | 2 +- .../nodes/task/k8s/plugin_manager.go | 16 +- .../nodes/task/k8s/plugin_manager_test.go | 52 ++-- .../nodes/task/k8s/task_exec_context.go | 66 +++++ .../nodes/task/k8s/task_exec_context_test.go | 75 ++++++ .../nodes/task/secretmanager/secrets.go | 46 +++- pkg/utils/encoder.go | 6 +- pkg/utils/secrets/marshaler.go | 86 +++++++ pkg/utils/secrets/marshaler_test.go | 77 ++++++ pkg/webhook/config.go | 29 +++ pkg/webhook/config_flags.go | 50 ++++ pkg/webhook/config_flags_test.go | 212 ++++++++++++++++ pkg/webhook/global_secrets.go | 76 ++++++ pkg/webhook/global_secrets_test.go | 99 ++++++++ pkg/webhook/k8s_secrets.go | 102 ++++++++ pkg/webhook/k8s_secrets_test.go | 171 +++++++++++++ pkg/webhook/mocks/global_secret_provider.go | 54 ++++ pkg/webhook/mocks/mutator.go | 95 +++++++ pkg/webhook/mocks/secrets_injector.go | 97 ++++++++ pkg/webhook/pod.go | 231 ++++++++++++++++++ pkg/webhook/pod_test.go | 158 ++++++++++++ pkg/webhook/secrets.go | 62 +++++ pkg/webhook/secrets_test.go | 61 +++++ pkg/webhook/testdata/ca.crt | 1 + pkg/webhook/utils.go | 84 +++++++ pkg/webhook/utils_test.go | 178 ++++++++++++++ 33 files changed, 2657 insertions(+), 48 deletions(-) create mode 100644 cmd/controller/cmd/init_certs.go create mode 100644 cmd/controller/cmd/webhook.go create mode 100644 pkg/controller/nodes/task/k8s/task_exec_context.go create mode 100644 pkg/controller/nodes/task/k8s/task_exec_context_test.go create mode 100644 pkg/utils/secrets/marshaler.go create mode 100644 pkg/utils/secrets/marshaler_test.go create mode 100644 pkg/webhook/config.go create mode 100755 pkg/webhook/config_flags.go create mode 100755 pkg/webhook/config_flags_test.go create mode 100644 pkg/webhook/global_secrets.go create mode 100644 pkg/webhook/global_secrets_test.go create mode 100644 pkg/webhook/k8s_secrets.go create mode 100644 pkg/webhook/k8s_secrets_test.go create mode 100644 pkg/webhook/mocks/global_secret_provider.go create mode 100644 pkg/webhook/mocks/mutator.go create mode 100644 pkg/webhook/mocks/secrets_injector.go create mode 100644 pkg/webhook/pod.go create mode 100644 pkg/webhook/pod_test.go create mode 100644 pkg/webhook/secrets.go create mode 100644 pkg/webhook/secrets_test.go create mode 100644 pkg/webhook/testdata/ca.crt create mode 100644 pkg/webhook/utils.go create mode 100644 pkg/webhook/utils_test.go diff --git a/README.md b/README.md index 562f252b1..47896278b 100644 --- a/README.md +++ b/README.md @@ -11,20 +11,41 @@ Flyte Propeller Kubernetes operator to executes Flyte graphs natively on kubernetes +Components +========== + +Propeller +--------- +Propeller is a K8s native operator that executes Flyte workflows. Workflow Spec is written in Protobuf for +cross-compatibility. + +Propeller Webhook +----------------- +A Mutating Webhook that can be optionally deployed to extend Flyte Propeller's functionality. It currently supports +enables injecting secrets into pods launched directly or indirectly through Flyte backend plugins. + +kubectl-flyte +------------- +A Kubectl-plugin to interact with Flyte Workflow CRDs. It enables retrieving and rendering Flyte Workflows in CLI as +well as safely aborting running workflows. + Getting Started =============== kubectl-flyte tool ------------------ -kubectl-flyte is an command line tool that can be used as an extension to kubectl. It is a separate binary that is built from the propeller repo. +kubectl-flyte is an command line tool that can be used as an extension to kubectl. It is a separate binary that is built +from the propeller repo. Install ------- This command will install kubectl-flyte and flytepropeller to `~/go/bin` + ``` $ make compile ``` You can also use [Krew](https://github.com/kubernetes-sigs/krew) to install the kubectl-flyte CLI: + ``` $ kubectl krew install flyte ``` @@ -68,7 +89,8 @@ To retrieve all workflows in a namespace use the --namespace option, --namespace Success: 19, Failed: 0, Running: 0, Waiting: 0 ``` -To retrieve a specific workflow, namespace can either be provided in the format namespace/name or using the --namespace argument +To retrieve a specific workflow, namespace can either be provided in the format namespace/name or using the --namespace +argument ``` $ kubectl-flyte get flytekit-development/flytekit-development-ff806e973581f4508bf1 @@ -89,7 +111,8 @@ To delete a specific workflow $ kubectl-flyte delete --namespace flytekit-development flytekit-development-ff806e973581f4508bf1 ``` -To delete all completed workflows - they have to be either success/failed with a special isCompleted label set on them. The Label is set `here ` +To delete all completed workflows - they have to be either success/failed with a special isCompleted label set on them. +The Label is set `here ` ``` $ kubectl-flyte delete --namespace flytekit-development --all-completed @@ -97,40 +120,68 @@ To delete all completed workflows - they have to be either success/failed with a Running propeller locally ------------------------- -use the config.yaml in root found `here `. Cd into this folder and then run +use the config.yaml in root found `here `. Cd into +this folder and then run ``` - $ flytepropeller --logtostderr + $ flytepropeller ``` Following dependencies need to be met + 1. Blob store (you can forward minio port to localhost) 2. Admin Service endpoint (can be forwarded) OR *Disable* events to admin and launchplans 3. access to kubeconfig and kubeapi +Running webhook +--------------- + +API Server requires the webhook to serve traffic over SSL. To issue self-signed certs to be used for serving traffic, +use: + +``` + $ flytepropeller webhook init-certs +``` + +This will create a ca.crt, tls.crt and key.crt and store them to flyte-pod-webhook secret. If a secret of the same name +already exist, it'll not override it. + +Starting the webhook can be done by running: + +``` + $ flytepropeller webhook +``` + +The secret should be mounted and accessible to this command. It'll then create a MutatingWebhookConfiguration object +with the details of the webhook and that registers the webhook with ApiServer. + Making changes to CRD ===================== *Remember* changes to CRD should be carefully done, they should be backwards compatible or else you should use proper -operator versioning system. Once you do the changes, you have to follow the -following steps. +operator versioning system. Once you do the changes, you have to follow the following steps. - ensure the propeller code is checked out in $GOPATH/github.com/flyteorg/flytepropeller - Uncomment https://github.com/flyteorg/flytepropeller/blob/master/hack/tools.go#L5 - + ```bash go mod vendor ``` + - Now generate the code + ```bash make op_code_generate $make op_code_generate ``` -**Why do we have to do this?** -Flytepropeller uses old way of writing Custom controllers for K8s. The k8s.io/code-generator only works in the GOPATH relative code path (sadly). So you have checkout the code in the right place. -Also, `go mod vendor` is needed to get code-generator in a discoverable path. +**Why do we have to do this?** +Flytepropeller uses old way of writing Custom controllers for K8s. The k8s.io/code-generator only works in the GOPATH +relative code path (sadly). So you have checkout the code in the right place. Also, `go mod vendor` is needed to get +code-generator in a discoverable path. **TODO** + 1. We may be able to avoid needing the old style go-path 2. Migrate to using controller runtime diff --git a/cmd/controller/cmd/init_certs.go b/cmd/controller/cmd/init_certs.go new file mode 100644 index 000000000..65c5badf9 --- /dev/null +++ b/cmd/controller/cmd/init_certs.go @@ -0,0 +1,222 @@ +package cmd + +import ( + "bytes" + "context" + cryptorand "crypto/rand" + + "github.com/flyteorg/flytestdlib/logger" + "k8s.io/apimachinery/pkg/api/errors" + + "github.com/flyteorg/flytepropeller/pkg/controller/config" + + corev1 "k8s.io/api/core/v1" + v12 "k8s.io/apimachinery/pkg/apis/meta/v1" + v1 "k8s.io/client-go/kubernetes/typed/core/v1" + + "github.com/flyteorg/flytepropeller/pkg/webhook" + + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "time" + + "github.com/spf13/cobra" +) + +const ( + CaCertKey = "ca.crt" + ServerCertKey = "tls.crt" + ServerCertPrivateKey = "tls.key" +) + +// initCertsCmd initializes x509 TLS Certificates and saves them to a secret. +var initCertsCmd = &cobra.Command{ + Use: "init-certs", + Aliases: []string{"init-cert"}, + Short: "Generates CA, Cert and cert key and saves them into a secret using the configured --webhook.secretName", + Long: ` +K8s API Server Webhooks' Services are required to serve traffic over SSL. Ideally the SSL certificate is issued by a +known Certificate Authority that's already trusted by API Server (e.g. using Let's Encrypt Cluster Certificate Controller). +Otherwise, a self-issued certificate can be used to serve traffic as long as the CA for that certificate is stored in +the MutatingWebhookConfiguration object that registers the webhook with API Server. + +init-certs generates 4096-bit X509 Certificates for Certificate Authority as well as a derived Server cert and its private +key. It serializes all of them in PEM format (base64 encoding-compatible) and stores them into a new kubernetes secret. +If a secret with the same name already exist, it'll not update it. + +POD_NAMESPACE is an environment variable that can, optionally, be set on the Pod that runs init-certs command. If set, +the secret will be created in that namespace. It's critical that this command creates the secret in the right namespace +for the Webhook command to mount and read correctly. +`, + Example: "flytepropeller webhook init-certs", + RunE: func(cmd *cobra.Command, args []string) error { + return runCertsCmd(context.Background(), config.GetConfig(), webhook.GetConfig()) + }, +} + +type webhookCerts struct { + // base64 Encoded CA Cert + CaPEM *bytes.Buffer + // base64 Encoded Server Cert + ServerPEM *bytes.Buffer + // base64 Encoded Server Cert Key + PrivateKeyPEM *bytes.Buffer +} + +func init() { + webhookCmd.AddCommand(initCertsCmd) +} + +func runCertsCmd(ctx context.Context, propellerCfg *config.Config, cfg *webhook.Config) error { + podNamespace, found := os.LookupEnv(PodNamespaceEnvVar) + if !found { + podNamespace = podDefaultNamespace + } + + logger.Infof(ctx, "Issuing certs") + certs, err := createCerts(podNamespace) + if err != nil { + return err + } + + kubeClient, _, err := getKubeConfig(ctx, propellerCfg) + if err != nil { + return err + } + + logger.Infof(ctx, "Creating secret [%v] in Namespace [%v]", cfg.SecretName, podNamespace) + err = createWebhookSecret(ctx, podNamespace, cfg, certs, kubeClient.CoreV1().Secrets(podNamespace)) + if err != nil { + return err + } + + return nil +} + +func createWebhookSecret(ctx context.Context, namespace string, cfg *webhook.Config, certs webhookCerts, secretsClient v1.SecretInterface) error { + isImmutable := true + _, err := secretsClient.Create(ctx, &corev1.Secret{ + ObjectMeta: v12.ObjectMeta{ + Name: cfg.SecretName, + Namespace: namespace, + }, + Type: corev1.SecretTypeOpaque, + Data: map[string][]byte{ + CaCertKey: certs.CaPEM.Bytes(), + ServerCertKey: certs.ServerPEM.Bytes(), + ServerCertPrivateKey: certs.PrivateKeyPEM.Bytes(), + }, + Immutable: &isImmutable, + }, v12.CreateOptions{}) + + if errors.IsAlreadyExists(err) { + // TODO: Maybe get the secret and validate it has all the required keys? + logger.Infof(ctx, "A secret already exists with the same name. Ignoring creating secret.") + return nil + } + + logger.Infof(ctx, "Created secret [%v]", cfg.SecretName) + + return err +} + +func createCerts(serviceNamespace string) (certs webhookCerts, err error) { + // CA config + caRequest := &x509.Certificate{ + SerialNumber: big.NewInt(2021), + Subject: pkix.Name{ + Organization: []string{"flyte.org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + // CA private key + caPrivateKey, err := rsa.GenerateKey(cryptorand.Reader, 4096) + if err != nil { + return webhookCerts{}, err + } + + // Self signed CA certificate + caCert, err := x509.CreateCertificate(cryptorand.Reader, caRequest, caRequest, &caPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return webhookCerts{}, err + } + + // PEM encode CA cert + caPEM := new(bytes.Buffer) + err = pem.Encode(caPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: caCert, + }) + if err != nil { + return webhookCerts{}, err + } + + dnsNames := []string{"flyte-pod-webhook", + "flyte-pod-webhook." + serviceNamespace, "flyte-pod-webhook." + serviceNamespace + ".svc"} + commonName := "flyte-pod-webhook." + serviceNamespace + ".svc" + + // server cert config + certRequest := &x509.Certificate{ + DNSNames: dnsNames, + SerialNumber: big.NewInt(1658), + Subject: pkix.Name{ + CommonName: commonName, + Organization: []string{"flyte.org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + // server private key + serverPrivateKey, err := rsa.GenerateKey(cryptorand.Reader, 4096) + if err != nil { + return webhookCerts{}, err + } + + // sign the server cert + cert, err := x509.CreateCertificate(cryptorand.Reader, certRequest, caRequest, &serverPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return webhookCerts{}, err + } + + // PEM encode the server cert and key + serverCertPEM := new(bytes.Buffer) + err = pem.Encode(serverCertPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert, + }) + + if err != nil { + return webhookCerts{}, fmt.Errorf("failed to Encode CertPEM. Error: %w", err) + } + + serverPrivKeyPEM := new(bytes.Buffer) + err = pem.Encode(serverPrivKeyPEM, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(serverPrivateKey), + }) + + if err != nil { + return webhookCerts{}, fmt.Errorf("failed to Encode Cert Private Key. Error: %w", err) + } + + return webhookCerts{ + CaPEM: caPEM, + ServerPEM: serverCertPEM, + PrivateKeyPEM: serverPrivKeyPEM, + }, nil +} diff --git a/cmd/controller/cmd/root.go b/cmd/controller/cmd/root.go index 3691f3230..3b0ec4808 100644 --- a/cmd/controller/cmd/root.go +++ b/cmd/controller/cmd/root.go @@ -1,3 +1,4 @@ +// Commands for FlytePropeller controller. package cmd import ( @@ -58,7 +59,7 @@ var rootCmd = &cobra.Command{ Short: "Operator for running Flyte Workflows", Long: `Flyte Propeller runs a workflow to completion by recursing through the nodes, handling their tasks to completion and propagating their status upstream.`, - PreRunE: initConfig, + PersistentPreRunE: initConfig, Run: func(cmd *cobra.Command, args []string) { executeRootCmd(config2.GetConfig()) }, @@ -93,12 +94,14 @@ func init() { rootCmd.AddCommand(viper.GetConfigCommand()) } -func initConfig(_ *cobra.Command, _ []string) error { +func initConfig(cmd *cobra.Command, _ []string) error { configAccessor = viper.NewAccessor(config.Options{ - StrictMode: true, + StrictMode: false, SearchPaths: []string{cfgFile}, }) + configAccessor.InitializePflags(cmd.PersistentFlags()) + err := configAccessor.UpdateConfig(context.TODO()) if err != nil { return err diff --git a/cmd/controller/cmd/webhook.go b/cmd/controller/cmd/webhook.go new file mode 100644 index 000000000..78d305cd8 --- /dev/null +++ b/cmd/controller/cmd/webhook.go @@ -0,0 +1,206 @@ +package cmd + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + + apiErrors "k8s.io/apimachinery/pkg/api/errors" + + "k8s.io/client-go/kubernetes" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flytepropeller/pkg/signals" + "github.com/flyteorg/flytepropeller/pkg/webhook" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/profutils" + "github.com/flyteorg/flytestdlib/promutils" + "sigs.k8s.io/controller-runtime/pkg/manager" + + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/spf13/cobra" +) + +const ( + PodNameEnvVar = "POD_NAME" + PodNamespaceEnvVar = "POD_NAMESPACE" + podDefaultNamespace = "default" +) + +var webhookCmd = &cobra.Command{ + Use: "webhook", + Aliases: []string{"webhooks"}, + Short: "Runs Propeller Pod Webhook that listens for certain labels and modify the pod accordingly.", + Long: ` +This command initializes propeller's Pod webhook that enables it to mutate pods whether they are created directly from +plugins or indirectly through the creation of other CRDs (e.g. Spark/Pytorch). +In order to use this Webhook: +1) Keys need to be mounted to the POD that runs this command; tls.crt should be a CA-issued cert (not a self-signed + cert), tls.key as the private key for that cert and, optionally, ca.crt in case tls.crt's CA is not a known + Certificate Authority (e.g. in case ca.crt is self-issued). +2) POD_NAME and POD_NAMESPACE environment variables need to be populated because the webhook initialization will lookup + this pod to copy OwnerReferences into the new MutatingWebhookConfiguration object it'll create to ensure proper + cleanup. + +A sample Container for this webhook might look like this: + + volumes: + - name: config-volume + configMap: + name: flyte-propeller-config-492gkfhbgk + # Certs secret created by running 'flytepropeller webhook init-certs' + - name: webhook-certs + secret: + secretName: flyte-pod-webhook + containers: + - name: webhook-server + image: + command: + - flytepropeller + args: + - webhook + - --config + - /etc/flyte/config/*.yaml + env: + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace + volumeMounts: + - name: config-volume + mountPath: /etc/flyte/config + readOnly: true + # Mount certs from a secret + - name: webhook-certs + mountPath: /etc/webhook/certs + readOnly: true +`, + RunE: func(cmd *cobra.Command, args []string) error { + return runWebhook(context.Background(), config.GetConfig(), webhook.GetConfig()) + }, +} + +func init() { + rootCmd.AddCommand(webhookCmd) +} + +func runWebhook(origContext context.Context, propellerCfg *config.Config, cfg *webhook.Config) error { + // set up signals so we handle the first shutdown signal gracefully + ctx := signals.SetupSignalHandler(origContext) + + raw, err := json.Marshal(cfg) + if err != nil { + return err + } + + fmt.Println(string(raw)) + + kubeClient, kubecfg, err := getKubeConfig(ctx, propellerCfg) + if err != nil { + return err + } + + // Add the propeller subscope because the MetricsPrefix only has "flyte:" to get uniform collection of metrics. + propellerScope := promutils.NewScope(cfg.MetricsPrefix).NewSubScope("propeller").NewSubScope(safeMetricName(propellerCfg.LimitNamespace)) + + go func() { + err := profutils.StartProfilingServerWithDefaultHandlers(ctx, propellerCfg.ProfilerPort.Port, nil) + if err != nil { + logger.Panicf(ctx, "Failed to Start profiling and metrics server. Error: %v", err) + } + }() + + limitNamespace := "" + if propellerCfg.LimitNamespace != defaultNamespace { + limitNamespace = propellerCfg.LimitNamespace + } + + secretsWebhook := webhook.NewPodMutator(cfg, propellerScope.NewSubScope("webhook")) + + // Creates a MutationConfig to instruct ApiServer to call this service whenever a Pod is being created. + err = createMutationConfig(ctx, kubeClient, secretsWebhook) + if err != nil { + return err + } + + mgr, err := manager.New(kubecfg, manager.Options{ + Port: cfg.ListenPort, + CertDir: cfg.CertDir, + Namespace: limitNamespace, + SyncPeriod: &propellerCfg.DownstreamEval.Duration, + ClientBuilder: executors.NewFallbackClientBuilder(), + }) + + if err != nil { + logger.Fatalf(ctx, "Failed to initialize controller run-time manager. Error: %v", err) + } + + err = secretsWebhook.Register(ctx, mgr) + if err != nil { + logger.Fatalf(ctx, "Failed to register webhook with manager. Error: %v", err) + } + + logger.Infof(ctx, "Starting controller-runtime manager") + return mgr.Start(ctx) +} + +func createMutationConfig(ctx context.Context, kubeClient *kubernetes.Clientset, webhookObj *webhook.PodMutator) error { + shouldAddOwnerRef := true + podName, found := os.LookupEnv(PodNameEnvVar) + if !found { + shouldAddOwnerRef = false + } + + podNamespace, found := os.LookupEnv(PodNamespaceEnvVar) + if !found { + shouldAddOwnerRef = false + podNamespace = podDefaultNamespace + } + + mutateConfig, err := webhookObj.CreateMutationWebhookConfiguration(podNamespace) + if err != nil { + return err + } + + if shouldAddOwnerRef { + // Lookup the pod to retrieve its UID + p, err := kubeClient.CoreV1().Pods(podNamespace).Get(ctx, podName, metav1.GetOptions{}) + if err != nil { + logger.Infof(ctx, "Failed to get Pod [%v/%v]. Error: %v", podNamespace, podName, err) + return fmt.Errorf("failed to get pod. Error: %w", err) + } + + mutateConfig.OwnerReferences = p.OwnerReferences + } + + logger.Infof(ctx, "Creating MutatingWebhookConfiguration [%v/%v]", mutateConfig.GetNamespace(), mutateConfig.GetName()) + + _, err = kubeClient.AdmissionregistrationV1().MutatingWebhookConfigurations().Create(ctx, mutateConfig, metav1.CreateOptions{}) + var statusErr *apiErrors.StatusError + if err != nil && errors.As(err, &statusErr) && statusErr.Status().Reason == metav1.StatusReasonAlreadyExists { + logger.Infof(ctx, "Failed to create MutatingWebhookConfiguration. Will attempt to update. Error: %v", err) + obj, getErr := kubeClient.AdmissionregistrationV1().MutatingWebhookConfigurations().Get(ctx, mutateConfig.Name, metav1.GetOptions{}) + if getErr != nil { + logger.Infof(ctx, "Failed to get MutatingWebhookConfiguration. Error: %v", getErr) + return err + } + + obj.Webhooks = mutateConfig.Webhooks + _, err = kubeClient.AdmissionregistrationV1().MutatingWebhookConfigurations().Update(ctx, obj, metav1.UpdateOptions{}) + if err == nil { + logger.Infof(ctx, "Successfully updated existing mutating webhook config.") + } + + return err + } + + return nil +} diff --git a/go.mod b/go.mod index 0954f63dd..82df8750c 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,12 @@ require ( github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.10.0 - github.com/flyteorg/flyteidl v0.18.20 + github.com/flyteorg/flyteidl v0.18.24 github.com/flyteorg/flyteplugins v0.5.38 github.com/flyteorg/flytestdlib v0.3.13 github.com/ghodss/yaml v1.0.0 github.com/go-redis/redis v6.15.7+incompatible + github.com/go-test/deep v1.0.7 github.com/golang/protobuf v1.4.3 github.com/google/uuid v1.2.0 github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 diff --git a/go.sum b/go.sum index 8e92f5de2..265cd3279 100644 --- a/go.sum +++ b/go.sum @@ -72,6 +72,7 @@ github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZ github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 h1:xJ0dAkuxJXfwdH7IaSzBEbSQxEDz36YUmt7+CB4zoNA= @@ -231,8 +232,8 @@ github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/flyteorg/flyteidl v0.18.17/go.mod h1:b5Fq4Z8a5b0mF6pEwTd48ufvikUGVkWSjZiMT0ZtqKI= -github.com/flyteorg/flyteidl v0.18.20 h1:OGOb2FOHWL363Qp8uzbJeFbQBKYPT30+afv+8BnBlGs= -github.com/flyteorg/flyteidl v0.18.20/go.mod h1:b5Fq4Z8a5b0mF6pEwTd48ufvikUGVkWSjZiMT0ZtqKI= +github.com/flyteorg/flyteidl v0.18.24 h1:Y4+y/tu6Qsb3jNXxuVsflycfSocfthUi6XsMgJTfGuc= +github.com/flyteorg/flyteidl v0.18.24/go.mod h1:b5Fq4Z8a5b0mF6pEwTd48ufvikUGVkWSjZiMT0ZtqKI= github.com/flyteorg/flyteplugins v0.5.38 h1:xAQ1J23cRxzwNDgzbmRuuvflq2PFetntRCjuM5RBfTw= github.com/flyteorg/flyteplugins v0.5.38/go.mod h1:CxerBGWWEmNYmPxSMHnwQEr9cc1Fbo/g5fcABazU6Jo= github.com/flyteorg/flytestdlib v0.3.13 h1:5ioA/q3ixlyqkFh5kDaHgmPyTP/AHtqq1K/TIbVLUzM= @@ -1229,6 +1230,7 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= k8s.io/api v0.0.0-20210217171935-8e2decd92398/go.mod h1:60tmSUpHxGPFerNHbo/ayI2lKxvtrhbxFyXuEIWJd78= k8s.io/api v0.18.2/go.mod h1:SJCWI7OLzhZSvbY7U8zwNl9UA4o1fizoug34OV/2r78= diff --git a/pkg/compiler/workflow_compiler.go b/pkg/compiler/workflow_compiler.go index dbdb46b8f..e01795c8d 100755 --- a/pkg/compiler/workflow_compiler.go +++ b/pkg/compiler/workflow_compiler.go @@ -38,6 +38,8 @@ import ( c "github.com/flyteorg/flytepropeller/pkg/compiler/common" "github.com/flyteorg/flytepropeller/pkg/compiler/errors" v "github.com/flyteorg/flytepropeller/pkg/compiler/validators" + + // #noSA1019 "github.com/golang/protobuf/proto" "k8s.io/apimachinery/pkg/util/sets" ) diff --git a/pkg/controller/executors/mocks/fake.go b/pkg/controller/executors/mocks/fake.go index 27fb94060..e31da3c61 100644 --- a/pkg/controller/executors/mocks/fake.go +++ b/pkg/controller/executors/mocks/fake.go @@ -7,7 +7,7 @@ import ( func NewFakeKubeClient() *Client { c := Client{} - c.On("GetClient").Return(fake.NewFakeClient()) + c.On("GetClient").Return(fake.NewClientBuilder().WithRuntimeObjects().Build()) c.On("GetCache").Return(&informertest.FakeInformers{}) return &c } diff --git a/pkg/controller/nodes/task/k8s/plugin_manager.go b/pkg/controller/nodes/task/k8s/plugin_manager.go index 2a38ab430..1bf65bd4e 100644 --- a/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -176,12 +176,24 @@ func (e *PluginManager) getPodEffectiveResourceLimits(ctx context.Context, pod * func (e *PluginManager) LaunchResource(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (pluginsCore.Transition, error) { - o, err := e.plugin.BuildResource(ctx, tCtx) + tmpl, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return pluginsCore.Transition{}, err + } + + k8sTaskCtxMetadata, err := newTaskExecutionMetadata(tCtx.TaskExecutionMetadata(), tmpl) + if err != nil { + return pluginsCore.Transition{}, err + } + + k8sTaskCtx := newTaskExecutionContext(tCtx, k8sTaskCtxMetadata) + + o, err := e.plugin.BuildResource(ctx, k8sTaskCtx) if err != nil { return pluginsCore.UnknownTransition, err } - e.AddObjectMetadata(tCtx.TaskExecutionMetadata(), o, config.GetK8sPluginConfig()) + e.AddObjectMetadata(k8sTaskCtxMetadata, o, config.GetK8sPluginConfig()) logger.Infof(ctx, "Creating Object: Type:[%v], Object:[%v/%v]", o.GetObjectKind().GroupVersionKind(), o.GetNamespace(), o.GetName()) key := backoff.ComposeResourceKey(o) diff --git a/pkg/controller/nodes/task/k8s/plugin_manager_test.go b/pkg/controller/nodes/task/k8s/plugin_manager_test.go index b0469528f..242e639cf 100644 --- a/pkg/controller/nodes/task/k8s/plugin_manager_test.go +++ b/pkg/controller/nodes/task/k8s/plugin_manager_test.go @@ -120,10 +120,14 @@ func (d *dummyOutputWriter) Put(ctx context.Context, reader io.OutputReader) err func getMockTaskContext(initPhase PluginPhase, wantPhase PluginPhase) pluginsCore.TaskExecutionContext { taskExecutionContext := &pluginsCoreMock.TaskExecutionContext{} - taskExecutionContext.On("TaskExecutionMetadata").Return(getMockTaskExecutionMetadata()) + taskExecutionContext.OnTaskExecutionMetadata().Return(getMockTaskExecutionMetadata()) + + tReader := &pluginsCoreMock.TaskReader{} + tReader.OnReadMatch(mock.Anything).Return(&core.TaskTemplate{}, nil) + taskExecutionContext.OnTaskReader().Return(tReader) customStateReader := &pluginsCoreMock.PluginStateReader{} - customStateReader.On("Get", mock.MatchedBy(func(i interface{}) bool { + customStateReader.OnGetMatch(mock.MatchedBy(func(i interface{}) bool { ps, ok := i.(*PluginState) if ok { ps.Phase = initPhase @@ -131,18 +135,18 @@ func getMockTaskContext(initPhase PluginPhase, wantPhase PluginPhase) pluginsCor } return false })).Return(uint8(0), nil) - taskExecutionContext.On("PluginStateReader").Return(customStateReader) + taskExecutionContext.OnPluginStateReader().Return(customStateReader) customStateWriter := &pluginsCoreMock.PluginStateWriter{} - customStateWriter.On("Put", mock.Anything, mock.MatchedBy(func(i interface{}) bool { + customStateWriter.OnPutMatch(mock.Anything, mock.MatchedBy(func(i interface{}) bool { ps, ok := i.(*PluginState) return ok && ps.Phase == wantPhase })).Return(nil) - taskExecutionContext.On("PluginStateWriter").Return(customStateWriter) - taskExecutionContext.On("OutputWriter").Return(&dummyOutputWriter{}) + taskExecutionContext.OnPluginStateWriter().Return(customStateWriter) + taskExecutionContext.OnOutputWriter().Return(&dummyOutputWriter{}) - taskExecutionContext.On("DataStore").Return(nil) - taskExecutionContext.On("MaxDatasetSizeBytes").Return(int64(0)) + taskExecutionContext.OnDataStore().Return(nil) + taskExecutionContext.OnMaxDatasetSizeBytes().Return(int64(0)) return taskExecutionContext } @@ -201,11 +205,11 @@ func TestK8sTaskExecutor_Handle_LaunchResource(t *testing.T) { var inputs *core.LiteralMap*/ t.Run("jobQueued", func(t *testing.T) { - tctx := getMockTaskContext(PluginPhaseNotStarted, PluginPhaseStarted) + tCtx := getMockTaskContext(PluginPhaseNotStarted, PluginPhaseStarted) // common setup code mockResourceHandler := &pluginsk8sMock.Plugin{} - mockResourceHandler.On("BuildResource", mock.Anything, tctx).Return(&v1.Pod{}, nil) - fakeClient := fake.NewFakeClient() + mockResourceHandler.OnBuildResourceMatch(mock.Anything, mock.Anything).Return(&v1.Pod{}, nil) + fakeClient := fake.NewClientBuilder().WithRuntimeObjects().Build() pluginManager, err := NewPluginManager(ctx, dummySetupContext(fakeClient), k8s.PluginEntry{ ID: "x", ResourceToWatch: &v1.Pod{}, @@ -213,7 +217,7 @@ func TestK8sTaskExecutor_Handle_LaunchResource(t *testing.T) { }, NewResourceMonitorIndex()) assert.NoError(t, err) - transition, err := pluginManager.Handle(ctx, tctx) + transition, err := pluginManager.Handle(ctx, tCtx) assert.NoError(t, err) assert.NotNil(t, transition) transitionInfo := transition.Info() @@ -221,10 +225,10 @@ func TestK8sTaskExecutor_Handle_LaunchResource(t *testing.T) { assert.Equal(t, pluginsCore.PhaseQueued, transitionInfo.Phase()) createdPod := &v1.Pod{} - pluginManager.AddObjectMetadata(tctx.TaskExecutionMetadata(), createdPod, &config.K8sPluginConfig{}) - assert.NoError(t, fakeClient.Get(ctx, k8stypes.NamespacedName{Namespace: tctx.TaskExecutionMetadata().GetNamespace(), - Name: tctx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()}, createdPod)) - assert.Equal(t, tctx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), createdPod.Name) + pluginManager.AddObjectMetadata(tCtx.TaskExecutionMetadata(), createdPod, &config.K8sPluginConfig{}) + assert.NoError(t, fakeClient.Get(ctx, k8stypes.NamespacedName{Namespace: tCtx.TaskExecutionMetadata().GetNamespace(), + Name: tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()}, createdPod)) + assert.Equal(t, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), createdPod.Name) assert.NoError(t, fakeClient.Delete(ctx, createdPod)) }) @@ -232,8 +236,8 @@ func TestK8sTaskExecutor_Handle_LaunchResource(t *testing.T) { tctx := getMockTaskContext(PluginPhaseNotStarted, PluginPhaseStarted) // common setup code mockResourceHandler := &pluginsk8sMock.Plugin{} - mockResourceHandler.On("BuildResource", mock.Anything, tctx).Return(&v1.Pod{}, nil) - fakeClient := fake.NewFakeClient() + mockResourceHandler.OnBuildResourceMatch(mock.Anything, mock.Anything).Return(&v1.Pod{}, nil) + fakeClient := fake.NewClientBuilder().WithRuntimeObjects().Build() pluginManager, err := NewPluginManager(ctx, dummySetupContext(fakeClient), k8s.PluginEntry{ ID: "x", ResourceToWatch: &v1.Pod{}, @@ -262,9 +266,9 @@ func TestK8sTaskExecutor_Handle_LaunchResource(t *testing.T) { tctx := getMockTaskContext(PluginPhaseNotStarted, PluginPhaseNotStarted) // common setup code mockResourceHandler := &pluginsk8sMock.Plugin{} - mockResourceHandler.On("BuildResource", mock.Anything, tctx).Return(&v1.Pod{}, nil) + mockResourceHandler.OnBuildResourceMatch(mock.Anything, mock.Anything).Return(&v1.Pod{}, nil) fakeClient := extendedFakeClient{ - Client: fake.NewFakeClient(), + Client: fake.NewClientBuilder().WithRuntimeObjects().Build(), CreateError: k8serrors.NewForbidden(schema.GroupResource{}, "", errors.New("exceeded quota")), } @@ -295,9 +299,9 @@ func TestK8sTaskExecutor_Handle_LaunchResource(t *testing.T) { tctx := getMockTaskContext(PluginPhaseNotStarted, PluginPhaseNotStarted) // common setup code mockResourceHandler := &pluginsk8sMock.Plugin{} - mockResourceHandler.On("BuildResource", mock.Anything, tctx).Return(&v1.Pod{}, nil) + mockResourceHandler.OnBuildResourceMatch(mock.Anything, mock.Anything).Return(&v1.Pod{}, nil) fakeClient := extendedFakeClient{ - Client: fake.NewFakeClient(), + Client: fake.NewClientBuilder().WithRuntimeObjects().Build(), CreateError: k8serrors.NewForbidden(schema.GroupResource{}, "", errors.New("auth error")), } @@ -326,7 +330,7 @@ func TestK8sTaskExecutor_Handle_LaunchResource(t *testing.T) { tctx := getMockTaskContext(PluginPhaseNotStarted, PluginPhaseNotStarted) // Creating a mock k8s plugin mockResourceHandler := &pluginsk8sMock.Plugin{} - mockResourceHandler.On("BuildResource", mock.Anything, tctx).Return(&v1.Pod{ + mockResourceHandler.OnBuildResourceMatch(mock.Anything, mock.Anything).Return(&v1.Pod{ TypeMeta: metav1.TypeMeta{ Kind: flytek8s.PodKind, APIVersion: v1.SchemeGroupVersion.String(), @@ -346,7 +350,7 @@ func TestK8sTaskExecutor_Handle_LaunchResource(t *testing.T) { }, }, nil) fakeClient := extendedFakeClient{ - Client: fake.NewFakeClient(), + Client: fake.NewClientBuilder().WithRuntimeObjects().Build(), CreateError: k8serrors.NewForbidden(schema.GroupResource{}, "", errors.New("is forbidden: "+ "exceeded quota: project-quota, requested: limits.memory=3Gi, "+ "used: limits.memory=7976Gi, limited: limits.memory=8000Gi")), diff --git a/pkg/controller/nodes/task/k8s/task_exec_context.go b/pkg/controller/nodes/task/k8s/task_exec_context.go new file mode 100644 index 000000000..90c6af7eb --- /dev/null +++ b/pkg/controller/nodes/task/k8s/task_exec_context.go @@ -0,0 +1,66 @@ +package k8s + +import ( + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flytepropeller/pkg/utils/secrets" +) + +// TaskExecutionContext provides a layer on top of core TaskExecutionContext with a custom TaskExecutionMetadata. +type TaskExecutionContext struct { + pluginsCore.TaskExecutionContext + metadataOverride pluginsCore.TaskExecutionMetadata +} + +func (t TaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata { + return t.metadataOverride +} + +func newTaskExecutionContext(tCtx pluginsCore.TaskExecutionContext, metadataOverride pluginsCore.TaskExecutionMetadata) TaskExecutionContext { + return TaskExecutionContext{ + TaskExecutionContext: tCtx, + metadataOverride: metadataOverride, + } +} + +// TaskExecutionMetadata provides a layer on top of the core TaskExecutionMetadata with customized annotations and labels +// for k8s plugins. +type TaskExecutionMetadata struct { + pluginsCore.TaskExecutionMetadata + + annotations map[string]string + labels map[string]string +} + +func (t TaskExecutionMetadata) GetLabels() map[string]string { + return t.labels +} + +func (t TaskExecutionMetadata) GetAnnotations() map[string]string { + return t.annotations +} + +// newTaskExecutionMetadata creates a TaskExecutionMetadata with secrets serialized as annotations and a label added +// to trigger the flyte pod webhook +func newTaskExecutionMetadata(tCtx pluginsCore.TaskExecutionMetadata, taskTmpl *core.TaskTemplate) (TaskExecutionMetadata, error) { + var err error + secretsMap := make(map[string]string) + injectSecretsLabel := make(map[string]string) + if taskTmpl.SecurityContext != nil && len(taskTmpl.SecurityContext.Secrets) > 0 { + secretsMap, err = secrets.MarshalSecretsToMapStrings(taskTmpl.SecurityContext.Secrets) + if err != nil { + return TaskExecutionMetadata{}, err + } + + injectSecretsLabel = map[string]string{ + secrets.PodLabel: secrets.PodLabelValue, + } + } + + return TaskExecutionMetadata{ + TaskExecutionMetadata: tCtx, + annotations: utils.UnionMaps(tCtx.GetAnnotations(), secretsMap), + labels: utils.UnionMaps(tCtx.GetLabels(), injectSecretsLabel), + }, nil +} diff --git a/pkg/controller/nodes/task/k8s/task_exec_context_test.go b/pkg/controller/nodes/task/k8s/task_exec_context_test.go new file mode 100644 index 000000000..24836bd4f --- /dev/null +++ b/pkg/controller/nodes/task/k8s/task_exec_context_test.go @@ -0,0 +1,75 @@ +package k8s + +import ( + "testing" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/stretchr/testify/assert" +) + +func Test_newTaskExecutionMetadata(t *testing.T) { + t.Run("No Secret", func(t *testing.T) { + existingMetadata := &mocks.TaskExecutionMetadata{} + existingAnnotations := map[string]string{ + "existingKey": "existingValue", + } + existingMetadata.OnGetAnnotations().Return(existingAnnotations) + + existingLabels := map[string]string{ + "existingLabel": "existingLabelValue", + } + existingMetadata.OnGetLabels().Return(existingLabels) + + actual, err := newTaskExecutionMetadata(existingMetadata, &core.TaskTemplate{}) + assert.NoError(t, err) + + assert.Equal(t, existingAnnotations, actual.GetAnnotations()) + assert.Equal(t, existingLabels, actual.GetLabels()) + }) + + t.Run("Secret", func(t *testing.T) { + existingMetadata := &mocks.TaskExecutionMetadata{} + existingAnnotations := map[string]string{ + "existingKey": "existingValue", + } + existingMetadata.OnGetAnnotations().Return(existingAnnotations) + + existingLabels := map[string]string{ + "existingLabel": "existingLabelValue", + } + existingMetadata.OnGetLabels().Return(existingLabels) + + actual, err := newTaskExecutionMetadata(existingMetadata, &core.TaskTemplate{ + SecurityContext: &core.SecurityContext{ + Secrets: []*core.Secret{ + { + Group: "my_group", + Key: "my_key", + MountRequirement: core.Secret_ENV_VAR, + }, + }, + }, + }) + assert.NoError(t, err) + + assert.Equal(t, map[string]string{ + "existingKey": "existingValue", + "flyte.secrets/s0": "m4zg54lqhiqce2lzl4txe22voarau12fpe4caitnpfpwwzlzeifg122vnz1f53tfof1ws3tfnvsw34b1ebcu3vs6kzavecq", + }, actual.GetAnnotations()) + + assert.Equal(t, map[string]string{ + "existingLabel": "existingLabelValue", + "inject-flyte-secrets": "true", + }, actual.GetLabels()) + }) +} + +func Test_newTaskExecutionContext(t *testing.T) { + existing := &mocks.TaskExecutionContext{} + existing.OnTaskExecutionMetadata().Panic("Unexpected") + + newMetadata := &mocks.TaskExecutionMetadata{} + actualCtx := newTaskExecutionContext(existing, newMetadata) + assert.Equal(t, newMetadata, actualCtx.TaskExecutionMetadata()) +} diff --git a/pkg/controller/nodes/task/secretmanager/secrets.go b/pkg/controller/nodes/task/secretmanager/secrets.go index 5bc1e72a1..e097ed9af 100644 --- a/pkg/controller/nodes/task/secretmanager/secrets.go +++ b/pkg/controller/nodes/task/secretmanager/secrets.go @@ -6,11 +6,17 @@ import ( "io/ioutil" "os" "path/filepath" + "strings" + + coreIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flytestdlib/logger" ) +// Env Var Lookup based on Prefix + SecretGroup + _ + SecretKey +const envVarLookupFormatter = "%s%s_%s" + +// FileEnvSecretManager allows retrieving secrets mounted to this process through Env Vars or Files. type FileEnvSecretManager struct { secretPath string envPrefix string @@ -23,6 +29,7 @@ func (f FileEnvSecretManager) Get(ctx context.Context, key string) (string, erro logger.Debugf(ctx, "Secret found %s", v) return v, nil } + secretFile := filepath.Join(f.secretPath, key) if _, err := os.Stat(secretFile); err != nil { if os.IsNotExist(err) { @@ -30,15 +37,50 @@ func (f FileEnvSecretManager) Get(ctx context.Context, key string) (string, erro } return "", err } + + logger.Debugf(ctx, "reading secrets from filePath [%s]", secretFile) + b, err := ioutil.ReadFile(secretFile) + if err != nil { + return "", err + } + return string(b), err +} + +// GetForSecret retrieves a secret from the environment of the running process. To lookup secret, both secret's key and +// group must be non-empty. GetForSecret will first lookup env variables using the configured +// Prefix+SecretGroup+_+SecretKey. If the secret is not found in environment, it'll lookup the secret from files using +// the configured SecretPath / SecretGroup / SecretKey. +func (f FileEnvSecretManager) GetForSecret(ctx context.Context, secret *coreIdl.Secret) (string, error) { + if len(secret.Group) == 0 || len(secret.Key) == 0 { + return "", fmt.Errorf("both key and group are required parameters. Secret: [%v]", secret.String()) + } + + envVar := fmt.Sprintf(envVarLookupFormatter, f.envPrefix, strings.ToUpper(secret.Group), strings.ToUpper(secret.Key)) + v, ok := os.LookupEnv(envVar) + if ok { + logger.Debugf(ctx, "Secret found %s", v) + return v, nil + } + + secretFile := filepath.Join(f.secretPath, filepath.Join(secret.Group, secret.Key)) + if _, err := os.Stat(secretFile); err != nil { + if os.IsNotExist(err) { + return "", fmt.Errorf("secrets not found - Env [%s], file [%s]", envVar, secretFile) + } + + return "", err + } + logger.Debugf(ctx, "reading secrets from filePath [%s]", secretFile) b, err := ioutil.ReadFile(secretFile) if err != nil { return "", err } + return string(b), err } -func NewFileEnvSecretManager(cfg *Config) core.SecretManager { +func NewFileEnvSecretManager(cfg *Config) FileEnvSecretManager { return FileEnvSecretManager{ secretPath: cfg.SecretFilePrefix, envPrefix: cfg.EnvironmentPrefix, diff --git a/pkg/utils/encoder.go b/pkg/utils/encoder.go index 397fe406b..db28be3eb 100644 --- a/pkg/utils/encoder.go +++ b/pkg/utils/encoder.go @@ -9,7 +9,7 @@ import ( const specialEncoderKey = "abcdefghijklmnopqrstuvwxyz123456" -var base32Encoder = base32.NewEncoding(specialEncoderKey).WithPadding(base32.NoPadding) +var Base32Encoder = base32.NewEncoding(specialEncoderKey).WithPadding(base32.NoPadding) // Creates a new UniqueID that is based on the inputID and of a specified length, if the given id is longer than the // maxLength. @@ -22,8 +22,8 @@ func FixedLengthUniqueID(inputID string, maxLength int) (string, error) { // Using 32a an error can never happen, so this will always remain not covered by a unit test _, _ = hasher.Write([]byte(inputID)) // #nosec b := hasher.Sum(nil) - // expected length after this step is 8 chars (1 + 7 chars from base32Encoder.EncodeToString(b)) - finalStr := "f" + base32Encoder.EncodeToString(b) + // expected length after this step is 8 chars (1 + 7 chars from Base32Encoder.EncodeToString(b)) + finalStr := "f" + Base32Encoder.EncodeToString(b) if len(finalStr) > maxLength { return finalStr, fmt.Errorf("max Length is too small, cannot create an encoded string that is so small") } diff --git a/pkg/utils/secrets/marshaler.go b/pkg/utils/secrets/marshaler.go new file mode 100644 index 000000000..8a5e72b1c --- /dev/null +++ b/pkg/utils/secrets/marshaler.go @@ -0,0 +1,86 @@ +package secrets + +import ( + "fmt" + "strconv" + "strings" + + "github.com/flyteorg/flytepropeller/pkg/utils" + + "github.com/golang/protobuf/proto" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" +) + +const ( + annotationPrefix = "flyte.secrets/s" + PodLabel = "inject-flyte-secrets" + PodLabelValue = "true" +) + +// Copied from: +// https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/apimachinery/pkg/api/validation/objectmeta.go#L36 +const totalAnnotationSizeLimitB int = 256 * (1 << 10) // 256 kB + +func encodeSecret(secretAsString string) string { + res := utils.Base32Encoder.EncodeToString([]byte(secretAsString)) + return strings.TrimSuffix(res, "=") +} + +func decodeSecret(encoded string) (string, error) { + decodedRaw, err := utils.Base32Encoder.DecodeString(encoded) + if err != nil { + return encoded, err + } + + return string(decodedRaw), nil +} + +func marshalSecret(s *core.Secret) string { + return encodeSecret(proto.MarshalTextString(s)) +} + +func unmarshalSecret(encoded string) (*core.Secret, error) { + decoded, err := decodeSecret(encoded) + if err != nil { + return nil, err + } + + s := &core.Secret{} + err = proto.UnmarshalText(decoded, s) + return s, err +} + +func MarshalSecretsToMapStrings(secrets []*core.Secret) (map[string]string, error) { + res := make(map[string]string, len(secrets)) + for index, s := range secrets { + if _, found := core.Secret_MountType_name[int32(s.MountRequirement)]; !found { + return nil, fmt.Errorf("invalid mount requirement [%v]", s.MountRequirement) + } + + encodedSecret := marshalSecret(s) + res[annotationPrefix+strconv.Itoa(index)] = encodedSecret + + if len(encodedSecret) > totalAnnotationSizeLimitB { + return nil, fmt.Errorf("secret descriptor cannot exceed [%v]", totalAnnotationSizeLimitB) + } + } + + return res, nil +} + +func UnmarshalStringMapToSecrets(m map[string]string) ([]*core.Secret, error) { + res := make([]*core.Secret, 0, len(m)) + for key, val := range m { + if strings.HasPrefix(key, annotationPrefix) { + s, err := unmarshalSecret(val) + if err != nil { + return nil, fmt.Errorf("error unmarshaling secret [%v]. Error: %w", key, err) + } + + res = append(res, s) + } + } + + return res, nil +} diff --git a/pkg/utils/secrets/marshaler_test.go b/pkg/utils/secrets/marshaler_test.go new file mode 100644 index 000000000..86f510459 --- /dev/null +++ b/pkg/utils/secrets/marshaler_test.go @@ -0,0 +1,77 @@ +package secrets + +import ( + "reflect" + "testing" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/stretchr/testify/assert" +) + +func TestEncodeSecretGroup(t *testing.T) { + input := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz01234567890._-/" + encoded := encodeSecret(input) + t.Log(input + " -> " + encoded) + decoded, err := decodeSecret(encoded) + assert.NoError(t, err) + assert.Equal(t, input, decoded) +} + +func TestMarshalSecretsToMapStrings(t *testing.T) { + type args struct { + secrets []*core.Secret + } + tests := []struct { + name string + args args + want map[string]string + wantErr bool + }{ + {name: "empty", args: args{secrets: []*core.Secret{}}, want: map[string]string{}, wantErr: false}, + {name: "nil", args: args{secrets: nil}, want: map[string]string{}, wantErr: false}, + {name: "forbidden characters", args: args{secrets: []*core.Secret{ + { + Group: ";':/\\", + }, + }}, want: map[string]string{ + "flyte.secrets/s0": "m4zg54lqhiqceozhhixvyxbcbi", + }, wantErr: false}, + {name: "Without group", args: args{secrets: []*core.Secret{ + { + Key: "my_key", + }, + }}, want: map[string]string{ + "flyte.secrets/s0": "nnsxsoraejwxsx2lmv3secq", + }, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MarshalSecretsToMapStrings(tt.args.secrets) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalSecretsToMapStrings() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err != nil { + return + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalSecretsToMapStrings() got = %v, want %v", got, tt.want) + } + }) + + t.Run(tt.name+"_unmarshal", func(t *testing.T) { + got, err := UnmarshalStringMapToSecrets(tt.want) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalSecretsToMapStrings() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err != nil { + return + } + + if tt.args.secrets != nil && !reflect.DeepEqual(got, tt.args.secrets) { + t.Errorf("UnmarshalSecretsToMapStrings() got = %v, want %v", got, tt.args.secrets) + } + }) + } +} diff --git a/pkg/webhook/config.go b/pkg/webhook/config.go new file mode 100644 index 000000000..9c08878dc --- /dev/null +++ b/pkg/webhook/config.go @@ -0,0 +1,29 @@ +package webhook + +import "github.com/flyteorg/flytestdlib/config" + +//go:generate pflags Config --default-var=defaultConfig + +var ( + defaultConfig = &Config{ + SecretName: "flyte-pod-webhook", + ServiceName: "flyte-pod-webhook", + MetricsPrefix: "flyte:", + CertDir: "/etc/webhook/certs", + ListenPort: 9443, + } + + configSection = config.MustRegisterSection("webhook", defaultConfig) +) + +type Config struct { + MetricsPrefix string `json:"metrics-prefix" pflag:",An optional prefix for all published metrics."` + CertDir string `json:"certDir" pflag:",Certificate directory to use to write generated certs. Defaults to /etc/webhook/certs/"` + ListenPort int `json:"listenPort" pflag:",The port to use to listen to webhook calls. Defaults to 9443"` + ServiceName string `json:"serviceName" pflag:",The name of the webhook service."` + SecretName string `json:"secretName" pflag:",Secret name to write generated certs to."` +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} diff --git a/pkg/webhook/config_flags.go b/pkg/webhook/config_flags.go new file mode 100755 index 000000000..af96a6f71 --- /dev/null +++ b/pkg/webhook/config_flags.go @@ -0,0 +1,50 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package webhook + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "metrics-prefix"), defaultConfig.MetricsPrefix, "An optional prefix for all published metrics.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "certDir"), defaultConfig.CertDir, "Certificate directory to use to write generated certs. Defaults to /etc/webhook/certs/") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "listenPort"), defaultConfig.ListenPort, "The port to use to listen to webhook calls. Defaults to 9443") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "serviceName"), defaultConfig.ServiceName, "The name of the webhook service.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "secretName"), defaultConfig.SecretName, "Secret name to write generated certs to.") + return cmdFlags +} diff --git a/pkg/webhook/config_flags_test.go b/pkg/webhook/config_flags_test.go new file mode 100755 index 000000000..c012d7504 --- /dev/null +++ b/pkg/webhook/config_flags_test.go @@ -0,0 +1,212 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package webhook + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_metrics-prefix", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("metrics-prefix"); err == nil { + assert.Equal(t, string(defaultConfig.MetricsPrefix), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("metrics-prefix", testValue) + if vString, err := cmdFlags.GetString("metrics-prefix"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.MetricsPrefix) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_certDir", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("certDir"); err == nil { + assert.Equal(t, string(defaultConfig.CertDir), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("certDir", testValue) + if vString, err := cmdFlags.GetString("certDir"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.CertDir) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_listenPort", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("listenPort"); err == nil { + assert.Equal(t, int(defaultConfig.ListenPort), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("listenPort", testValue) + if vInt, err := cmdFlags.GetInt("listenPort"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ListenPort) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_serviceName", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("serviceName"); err == nil { + assert.Equal(t, string(defaultConfig.ServiceName), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("serviceName", testValue) + if vString, err := cmdFlags.GetString("serviceName"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.ServiceName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_secretName", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("secretName"); err == nil { + assert.Equal(t, string(defaultConfig.SecretName), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("secretName", testValue) + if vString, err := cmdFlags.GetString("secretName"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.SecretName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/pkg/webhook/global_secrets.go b/pkg/webhook/global_secrets.go new file mode 100644 index 000000000..f1f86ac96 --- /dev/null +++ b/pkg/webhook/global_secrets.go @@ -0,0 +1,76 @@ +package webhook + +import ( + "context" + "fmt" + "strings" + + coreIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytestdlib/logger" + corev1 "k8s.io/api/core/v1" +) + +//go:generate mockery -all -case=underscore + +type GlobalSecretProvider interface { + GetForSecret(ctx context.Context, secret *coreIdl.Secret) (string, error) +} + +// GlobalSecrets allows the injection of secrets from the process memory space (env vars) or mounted files into pods +// intercepted through this admission webhook. Secrets injected through this type will be mounted as environment +// variables. If a secret has a mounting requirement that does not allow Env Vars, it'll fail to inject the secret. +type GlobalSecrets struct { + envSecretManager GlobalSecretProvider +} + +func (g GlobalSecrets) ID() string { + return "global" +} + +func (g GlobalSecrets) Inject(ctx context.Context, secret *coreIdl.Secret, p *corev1.Pod) (newP *corev1.Pod, injected bool, err error) { + v, err := g.envSecretManager.GetForSecret(ctx, secret) + if err != nil { + return p, false, err + } + + switch secret.MountRequirement { + case coreIdl.Secret_FILE: + return nil, false, fmt.Errorf("global secrets can only be injected as environment "+ + "variables [%v/%v]", secret.Group, secret.Key) + case coreIdl.Secret_ANY: + fallthrough + case coreIdl.Secret_ENV_VAR: + if len(secret.Group) == 0 { + return nil, false, fmt.Errorf("mounting a secret to env var requires selecting the "+ + "secret and a single key within. Key [%v]", secret.Key) + } + + envVar := corev1.EnvVar{ + Name: strings.ToUpper(K8sDefaultEnvVarPrefix + secret.Group + EnvVarGroupKeySeparator + secret.Key), + Value: v, + } + + prefixEnvVar := corev1.EnvVar{ + Name: K8sEnvVarPrefix, + Value: K8sDefaultEnvVarPrefix, + } + + p.Spec.InitContainers = UpdateEnvVars(p.Spec.InitContainers, prefixEnvVar) + p.Spec.Containers = UpdateEnvVars(p.Spec.Containers, prefixEnvVar) + + p.Spec.InitContainers = UpdateEnvVars(p.Spec.InitContainers, envVar) + p.Spec.Containers = UpdateEnvVars(p.Spec.Containers, envVar) + default: + err := fmt.Errorf("unrecognized mount requirement [%v] for secret [%v]", secret.MountRequirement.String(), secret.Key) + logger.Error(ctx, err) + return p, false, err + } + + return p, true, nil +} + +func NewGlobalSecrets(provider GlobalSecretProvider) GlobalSecrets { + return GlobalSecrets{ + envSecretManager: provider, + } +} diff --git a/pkg/webhook/global_secrets_test.go b/pkg/webhook/global_secrets_test.go new file mode 100644 index 000000000..306b4a74e --- /dev/null +++ b/pkg/webhook/global_secrets_test.go @@ -0,0 +1,99 @@ +package webhook + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/go-test/deep" + + coreIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytepropeller/pkg/webhook/mocks" + "github.com/stretchr/testify/mock" + corev1 "k8s.io/api/core/v1" +) + +func TestGlobalSecrets_Inject(t *testing.T) { + secretFound := &mocks.GlobalSecretProvider{} + secretFound.OnGetForSecretMatch(mock.Anything, mock.Anything).Return("my_password", nil) + + secretNotFound := &mocks.GlobalSecretProvider{} + secretNotFound.OnGetForSecretMatch(mock.Anything, mock.Anything).Return("", fmt.Errorf("secret not found")) + + inputPod := corev1.Pod{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "container1", + }, + }, + }, + } + + successPod := corev1.Pod{ + Spec: corev1.PodSpec{ + InitContainers: []corev1.Container{}, + Containers: []corev1.Container{ + { + Name: "container1", + Env: []corev1.EnvVar{ + { + Name: "FLYTE_SECRETS_ENV_PREFIX", + Value: "_FSEC_", + }, + { + Name: "_FSEC_GROUP_HELLO", + Value: "my_password", + }, + }, + }, + }, + }, + } + + type args struct { + secret *coreIdl.Secret + p *corev1.Pod + } + tests := []struct { + name string + envSecretManager GlobalSecretProvider + args args + want *corev1.Pod + wantErr bool + }{ + {name: "require group", envSecretManager: secretFound, args: args{secret: &coreIdl.Secret{Key: "hello"}, p: &corev1.Pod{}}, + want: &corev1.Pod{}, wantErr: true}, + {name: "simple", envSecretManager: secretFound, args: args{secret: &coreIdl.Secret{Group: "group", Key: "hello"}, p: &inputPod}, + want: &successPod, wantErr: false}, + {name: "require file", envSecretManager: secretFound, args: args{secret: &coreIdl.Secret{Key: "hello", MountRequirement: coreIdl.Secret_FILE}, + p: &corev1.Pod{}}, + want: &corev1.Pod{}, wantErr: true}, + {name: "not found", envSecretManager: secretNotFound, args: args{secret: &coreIdl.Secret{Key: "hello", MountRequirement: coreIdl.Secret_FILE}, + p: &corev1.Pod{}}, + want: &corev1.Pod{}, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := GlobalSecrets{ + envSecretManager: tt.envSecretManager, + } + + assert.NotEmpty(t, g.ID()) + + got, _, err := g.Inject(context.Background(), tt.args.secret, tt.args.p) + if (err != nil) != tt.wantErr { + t.Errorf("Inject() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err != nil { + return + } + + if diff := deep.Equal(got, tt.want); diff != nil { + t.Errorf("Inject() Diff = %v\r\n got = %v\r\n want = %v", diff, got, tt.want) + } + }) + } +} diff --git a/pkg/webhook/k8s_secrets.go b/pkg/webhook/k8s_secrets.go new file mode 100644 index 000000000..0efe3ede9 --- /dev/null +++ b/pkg/webhook/k8s_secrets.go @@ -0,0 +1,102 @@ +package webhook + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytestdlib/logger" + corev1 "k8s.io/api/core/v1" +) + +const ( + K8sPathDefaultDirEnvVar = "FLYTE_SECRETS_DEFAULT_DIR" + K8sPathFilePrefixEnvVar = "FLYTE_SECRETS_FILE_PREFIX" + K8sEnvVarPrefix = "FLYTE_SECRETS_ENV_PREFIX" + K8sDefaultEnvVarPrefix = "_FSEC_" + EnvVarGroupKeySeparator = "_" +) + +var ( + K8sSecretPathPrefix = []string{string(os.PathSeparator), "etc", "flyte", "secrets"} +) + +// K8sSecretInjector allows injecting of secrets into pods by specifying either EnvVarSource or SecretVolumeSource in +// the Pod Spec. It'll, by default, mount secrets as files into pods. +// The current version does not allow mounting an entire secret object (with all keys inside it). It only supports mounting +// a single key from the referenced secret object. +// The secret.Group will be used to reference the k8s secret object, the Secret.Key will be used to reference a key inside +// and the secret.Version will be ignored. +// Environment variables will be named _FSEC__. Files will be mounted on +// /etc/flyte/secrets// +type K8sSecretInjector struct { +} + +func (i K8sSecretInjector) ID() string { + return "K8s" +} + +func (i K8sSecretInjector) Inject(ctx context.Context, secret *core.Secret, p *corev1.Pod) (newP *corev1.Pod, injected bool, err error) { + if len(secret.Group) == 0 || len(secret.Key) == 0 { + return nil, false, fmt.Errorf("k8s Secrets Webhook require both key and group to be set. "+ + "Secret: [%v]", secret) + } + + switch secret.MountRequirement { + case core.Secret_ANY: + fallthrough + case core.Secret_FILE: + // Inject a Volume that to the pod and all of its containers and init containers that mounts the secret into a + // file. + + volume := CreateVolumeForSecret(secret) + p.Spec.Volumes = append(p.Spec.Volumes, volume) + + // Mount the secret to all containers in the given pod. + mount := CreateVolumeMountForSecret(volume.Name, secret) + p.Spec.InitContainers = UpdateVolumeMounts(p.Spec.InitContainers, mount) + p.Spec.Containers = UpdateVolumeMounts(p.Spec.Containers, mount) + + // Set environment variable to let the container know where to find the mounted files. + defaultDirEnvVar := corev1.EnvVar{ + Name: K8sPathDefaultDirEnvVar, + Value: filepath.Join(K8sSecretPathPrefix...), + } + + p.Spec.InitContainers = UpdateEnvVars(p.Spec.InitContainers, defaultDirEnvVar) + p.Spec.Containers = UpdateEnvVars(p.Spec.Containers, defaultDirEnvVar) + + // Sets an empty prefix to let the containers know the file names will match the secret keys as-is. + prefixEnvVar := corev1.EnvVar{ + Name: K8sPathFilePrefixEnvVar, + Value: "", + } + + p.Spec.InitContainers = UpdateEnvVars(p.Spec.InitContainers, prefixEnvVar) + p.Spec.Containers = UpdateEnvVars(p.Spec.Containers, prefixEnvVar) + case core.Secret_ENV_VAR: + envVar := CreateEnvVarForSecret(secret) + p.Spec.InitContainers = UpdateEnvVars(p.Spec.InitContainers, envVar) + p.Spec.Containers = UpdateEnvVars(p.Spec.Containers, envVar) + + prefixEnvVar := corev1.EnvVar{ + Name: K8sEnvVarPrefix, + Value: K8sDefaultEnvVarPrefix, + } + + p.Spec.InitContainers = UpdateEnvVars(p.Spec.InitContainers, prefixEnvVar) + p.Spec.Containers = UpdateEnvVars(p.Spec.Containers, prefixEnvVar) + default: + err := fmt.Errorf("unrecognized mount requirement [%v] for secret [%v]", secret.MountRequirement.String(), secret.Key) + logger.Error(ctx, err) + return p, false, err + } + + return p, true, nil +} + +func NewK8sSecretsInjector() K8sSecretInjector { + return K8sSecretInjector{} +} diff --git a/pkg/webhook/k8s_secrets_test.go b/pkg/webhook/k8s_secrets_test.go new file mode 100644 index 000000000..c855bda01 --- /dev/null +++ b/pkg/webhook/k8s_secrets_test.go @@ -0,0 +1,171 @@ +package webhook + +import ( + "context" + "testing" + + "github.com/go-test/deep" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + coreIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + corev1 "k8s.io/api/core/v1" +) + +func TestK8sSecretInjector_Inject(t *testing.T) { + inputPod := corev1.Pod{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "container1", + }, + }, + }, + } + + successPodEnv := corev1.Pod{ + Spec: corev1.PodSpec{ + InitContainers: []corev1.Container{}, + Containers: []corev1.Container{ + { + Name: "container1", + Env: []corev1.EnvVar{ + { + Name: "_FSEC_GROUP_HELLO", + ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: &corev1.SecretKeySelector{ + Key: "hello", + LocalObjectReference: corev1.LocalObjectReference{ + Name: "group", + }, + }, + }, + }, + { + Name: "FLYTE_SECRETS_ENV_PREFIX", + Value: "_FSEC_", + }, + }, + }, + }, + }, + } + + successPodFile := corev1.Pod{ + Spec: corev1.PodSpec{ + Volumes: []corev1.Volume{ + { + Name: "m4zg54lql4ugk2dmn4pq", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "group", + Items: []corev1.KeyToPath{ + { + Key: "hello", + Path: "hello", + }, + }, + }, + }, + }, + }, + InitContainers: []corev1.Container{}, + Containers: []corev1.Container{ + { + Name: "container1", + VolumeMounts: []corev1.VolumeMount{ + { + Name: "m4zg54lql4ugk2dmn4pq", + MountPath: "/etc/flyte/secrets/group", + ReadOnly: true, + }, + }, + Env: []corev1.EnvVar{ + { + Name: "FLYTE_SECRETS_DEFAULT_DIR", + Value: "/etc/flyte/secrets", + }, + { + Name: "FLYTE_SECRETS_FILE_PREFIX", + }, + }, + }, + }, + }, + } + + successPodFileAllKeys := corev1.Pod{ + Spec: corev1.PodSpec{ + Volumes: []corev1.Volume{ + { + Name: "hello", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: "hello", + }, + }, + }, + }, + InitContainers: []corev1.Container{}, + Containers: []corev1.Container{ + { + Name: "container1", + VolumeMounts: []corev1.VolumeMount{ + { + Name: "hello", + MountPath: "/etc/flyte/secrets/hello", + ReadOnly: true, + }, + }, + Env: []corev1.EnvVar{ + { + Name: "FLYTE_SECRETS_DEFAULT_DIR", + Value: "/etc/flyte/secrets", + }, + { + Name: "FLYTE_SECRETS_FILE_PREFIX", + }, + }, + }, + }, + }, + } + + ctx := context.Background() + type args struct { + secret *core.Secret + p *corev1.Pod + } + tests := []struct { + name string + args args + want *corev1.Pod + wantErr bool + }{ + {name: "require group", args: args{secret: &coreIdl.Secret{Key: "hello", MountRequirement: coreIdl.Secret_ENV_VAR}, p: &corev1.Pod{}}, + want: &corev1.Pod{}, wantErr: true}, + {name: "simple", args: args{secret: &coreIdl.Secret{Group: "group", Key: "hello", MountRequirement: coreIdl.Secret_ENV_VAR}, p: inputPod.DeepCopy()}, + want: &successPodEnv, wantErr: false}, + {name: "require file single", args: args{secret: &coreIdl.Secret{Group: "group", Key: "hello", MountRequirement: coreIdl.Secret_FILE}, + p: inputPod.DeepCopy()}, + want: &successPodFile, wantErr: false}, + {name: "require file all keys", args: args{secret: &coreIdl.Secret{Key: "hello", MountRequirement: coreIdl.Secret_FILE}, + p: inputPod.DeepCopy()}, + want: &successPodFileAllKeys, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := K8sSecretInjector{} + got, _, err := i.Inject(ctx, tt.args.secret, tt.args.p) + if (err != nil) != tt.wantErr { + t.Errorf("Inject() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err != nil { + return + } + + if diff := deep.Equal(got, tt.want); diff != nil { + t.Errorf("Inject() Diff = %v\r\n got = %v\r\n want = %v", diff, got, tt.want) + } + }) + } +} diff --git a/pkg/webhook/mocks/global_secret_provider.go b/pkg/webhook/mocks/global_secret_provider.go new file mode 100644 index 000000000..af7376496 --- /dev/null +++ b/pkg/webhook/mocks/global_secret_provider.go @@ -0,0 +1,54 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + mock "github.com/stretchr/testify/mock" +) + +// GlobalSecretProvider is an autogenerated mock type for the GlobalSecretProvider type +type GlobalSecretProvider struct { + mock.Mock +} + +type GlobalSecretProvider_GetForSecret struct { + *mock.Call +} + +func (_m GlobalSecretProvider_GetForSecret) Return(_a0 string, _a1 error) *GlobalSecretProvider_GetForSecret { + return &GlobalSecretProvider_GetForSecret{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *GlobalSecretProvider) OnGetForSecret(ctx context.Context, secret *core.Secret) *GlobalSecretProvider_GetForSecret { + c := _m.On("GetForSecret", ctx, secret) + return &GlobalSecretProvider_GetForSecret{Call: c} +} + +func (_m *GlobalSecretProvider) OnGetForSecretMatch(matchers ...interface{}) *GlobalSecretProvider_GetForSecret { + c := _m.On("GetForSecret", matchers...) + return &GlobalSecretProvider_GetForSecret{Call: c} +} + +// GetForSecret provides a mock function with given fields: ctx, secret +func (_m *GlobalSecretProvider) GetForSecret(ctx context.Context, secret *core.Secret) (string, error) { + ret := _m.Called(ctx, secret) + + var r0 string + if rf, ok := ret.Get(0).(func(context.Context, *core.Secret) string); ok { + r0 = rf(ctx, secret) + } else { + r0 = ret.Get(0).(string) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *core.Secret) error); ok { + r1 = rf(ctx, secret) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/webhook/mocks/mutator.go b/pkg/webhook/mocks/mutator.go new file mode 100644 index 000000000..f182f1715 --- /dev/null +++ b/pkg/webhook/mocks/mutator.go @@ -0,0 +1,95 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + v1 "k8s.io/api/core/v1" +) + +// Mutator is an autogenerated mock type for the Mutator type +type Mutator struct { + mock.Mock +} + +type Mutator_ID struct { + *mock.Call +} + +func (_m Mutator_ID) Return(_a0 string) *Mutator_ID { + return &Mutator_ID{Call: _m.Call.Return(_a0)} +} + +func (_m *Mutator) OnID() *Mutator_ID { + c := _m.On("ID") + return &Mutator_ID{Call: c} +} + +func (_m *Mutator) OnIDMatch(matchers ...interface{}) *Mutator_ID { + c := _m.On("ID", matchers...) + return &Mutator_ID{Call: c} +} + +// ID provides a mock function with given fields: +func (_m *Mutator) ID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type Mutator_Mutate struct { + *mock.Call +} + +func (_m Mutator_Mutate) Return(newP *v1.Pod, changed bool, err error) *Mutator_Mutate { + return &Mutator_Mutate{Call: _m.Call.Return(newP, changed, err)} +} + +func (_m *Mutator) OnMutate(ctx context.Context, p *v1.Pod) *Mutator_Mutate { + c := _m.On("Mutate", ctx, p) + return &Mutator_Mutate{Call: c} +} + +func (_m *Mutator) OnMutateMatch(matchers ...interface{}) *Mutator_Mutate { + c := _m.On("Mutate", matchers...) + return &Mutator_Mutate{Call: c} +} + +// Mutate provides a mock function with given fields: ctx, p +func (_m *Mutator) Mutate(ctx context.Context, p *v1.Pod) (*v1.Pod, bool, error) { + ret := _m.Called(ctx, p) + + var r0 *v1.Pod + if rf, ok := ret.Get(0).(func(context.Context, *v1.Pod) *v1.Pod); ok { + r0 = rf(ctx, p) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Pod) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(context.Context, *v1.Pod) bool); ok { + r1 = rf(ctx, p) + } else { + r1 = ret.Get(1).(bool) + } + + var r2 error + if rf, ok := ret.Get(2).(func(context.Context, *v1.Pod) error); ok { + r2 = rf(ctx, p) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} diff --git a/pkg/webhook/mocks/secrets_injector.go b/pkg/webhook/mocks/secrets_injector.go new file mode 100644 index 000000000..2fc5da7f0 --- /dev/null +++ b/pkg/webhook/mocks/secrets_injector.go @@ -0,0 +1,97 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + mock "github.com/stretchr/testify/mock" + + v1 "k8s.io/api/core/v1" +) + +// SecretsInjector is an autogenerated mock type for the SecretsInjector type +type SecretsInjector struct { + mock.Mock +} + +type SecretsInjector_ID struct { + *mock.Call +} + +func (_m SecretsInjector_ID) Return(_a0 string) *SecretsInjector_ID { + return &SecretsInjector_ID{Call: _m.Call.Return(_a0)} +} + +func (_m *SecretsInjector) OnID() *SecretsInjector_ID { + c := _m.On("ID") + return &SecretsInjector_ID{Call: c} +} + +func (_m *SecretsInjector) OnIDMatch(matchers ...interface{}) *SecretsInjector_ID { + c := _m.On("ID", matchers...) + return &SecretsInjector_ID{Call: c} +} + +// ID provides a mock function with given fields: +func (_m *SecretsInjector) ID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type SecretsInjector_Inject struct { + *mock.Call +} + +func (_m SecretsInjector_Inject) Return(newP *v1.Pod, injected bool, err error) *SecretsInjector_Inject { + return &SecretsInjector_Inject{Call: _m.Call.Return(newP, injected, err)} +} + +func (_m *SecretsInjector) OnInject(ctx context.Context, secrets *core.Secret, p *v1.Pod) *SecretsInjector_Inject { + c := _m.On("Inject", ctx, secrets, p) + return &SecretsInjector_Inject{Call: c} +} + +func (_m *SecretsInjector) OnInjectMatch(matchers ...interface{}) *SecretsInjector_Inject { + c := _m.On("Inject", matchers...) + return &SecretsInjector_Inject{Call: c} +} + +// Inject provides a mock function with given fields: ctx, secrets, p +func (_m *SecretsInjector) Inject(ctx context.Context, secrets *core.Secret, p *v1.Pod) (*v1.Pod, bool, error) { + ret := _m.Called(ctx, secrets, p) + + var r0 *v1.Pod + if rf, ok := ret.Get(0).(func(context.Context, *core.Secret, *v1.Pod) *v1.Pod); ok { + r0 = rf(ctx, secrets, p) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Pod) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(context.Context, *core.Secret, *v1.Pod) bool); ok { + r1 = rf(ctx, secrets, p) + } else { + r1 = ret.Get(1).(bool) + } + + var r2 error + if rf, ok := ret.Get(2).(func(context.Context, *core.Secret, *v1.Pod) error); ok { + r2 = rf(ctx, secrets, p) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} diff --git a/pkg/webhook/pod.go b/pkg/webhook/pod.go new file mode 100644 index 000000000..a1ab50c47 --- /dev/null +++ b/pkg/webhook/pod.go @@ -0,0 +1,231 @@ +// The PodMutator is a controller-runtime webhook that intercepts Pod Creation events and mutates them. Currently, there +// is only one registered Mutator, that's the SecretsMutator. It works as follows: +// +// - The Webhook only works on Pods. If propeller/plugins launch a resource outside of K8s (or in a separate k8s +// cluster), it's the responsibility of the plugin to correctly pass secret injection information. +// - When a k8s-plugin builds a resource, propeller's PluginManager will automatically inject a label `inject-flyte +// -secrets: true` and serialize the secret injection information into the annotations. +// - If a plugin does not use the K8sPlugin interface, it's its responsibility to pass secret injection information. +// - If a k8s plugin creates a CRD that launches other Pods (e.g. Spark/PyTorch... etc.), it's its responsibility to +// make sure the labels/annotations set on the CRD by PluginManager are propagated to those launched Pods. This +// ensures secret injection happens no matter how many levels of indirections there are. +// - The Webhook expects 'inject-flyte-secrets: true' as a label on the Pod. Otherwise it won't listen/observe that pod. +// - Once it intercepts the admission request, it goes over all registered Mutators and invoke them in the order they +// are registered as. If a Mutator fails and it's marked as `required`, the operation will fail and the admission +// will be rejected. +// - The SecretsMutator will attempt to lookup the requested secret from the process environment. If the secret is +// already mounted, it'll inject it as plain-text into the Pod Spec (Less secure). +// - If it's not found in the environment it'll, instead, fallback to the enabled Secrets Injector (K8s, Confidant, +// Vault... etc.). +// - Each SecretsInjector will mutate the Pod differently depending on how its backend secrets system injects the secrets +// for example: +// - For K8s secrets, it'll either add EnvFromSecret or VolumeMountSource (depending on the MountRequirement +// stated in the flyteIdl.Secret object) into the Pod. There is no validation that the secret exist and is available +// to the Pod at this point. If the secret is not accessible, the Pod will fail with ContainerCreationConfigError and +// will be retried. +// - For Vault secrets, it'll inject the right annotations to trigger Vault's own sidecar/webhook to mount the secret. +package webhook + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/flyteorg/flytepropeller/pkg/utils/secrets" + + admissionregistrationv1 "k8s.io/api/admissionregistration/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" + "k8s.io/apimachinery/pkg/runtime/schema" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/manager" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + corev1 "k8s.io/api/core/v1" +) + +const webhookName = "flyte-pod-webhook.flyte.org" + +// PodMutator implements controller-runtime WebHook interface. +type PodMutator struct { + decoder *admission.Decoder + cfg *Config + Mutators []MutatorConfig +} + +type MutatorConfig struct { + Mutator Mutator + Required bool +} + +type Mutator interface { + ID() string + Mutate(ctx context.Context, p *corev1.Pod) (newP *corev1.Pod, changed bool, err error) +} + +func (pm *PodMutator) InjectClient(_ client.Client) error { + return nil +} + +// InjectDecoder injects the decoder into a mutatingHandler. +func (pm *PodMutator) InjectDecoder(d *admission.Decoder) error { + pm.decoder = d + return nil +} + +func (pm *PodMutator) Handle(ctx context.Context, request admission.Request) admission.Response { + // Get the object in the request + obj := &corev1.Pod{} + err := pm.decoder.Decode(request, obj) + if err != nil { + return admission.Errored(http.StatusBadRequest, err) + } + + newObj, changed, err := pm.Mutate(ctx, obj) + if err != nil { + return admission.Errored(http.StatusBadRequest, err) + } + + if changed { + marshalled, err := json.Marshal(newObj) + if err != nil { + return admission.Errored(http.StatusInternalServerError, err) + } + + // Create the patch + return admission.PatchResponseFromRaw(request.Object.Raw, marshalled) + } + + return admission.Allowed("No changes") +} + +func (pm PodMutator) Mutate(ctx context.Context, p *corev1.Pod) (newP *corev1.Pod, changed bool, err error) { + newP = p + for _, m := range pm.Mutators { + tempP := newP + tempChanged := false + tempP, tempChanged, err = m.Mutator.Mutate(ctx, tempP) + if err != nil { + if m.Required { + err = fmt.Errorf("failed to mutate using [%v]. Since it's a required mutator, failing early. Error: %v", m.Mutator.ID(), err) + logger.Info(ctx, err) + return p, false, err + } + + logger.Infof(ctx, "Failed to mutate using [%v]. Since it's not a required mutator, skipping. Error: %v", m.Mutator.ID(), err) + continue + } + + newP = tempP + if tempChanged { + changed = true + } + } + + return newP, changed, nil +} + +func (pm *PodMutator) Register(ctx context.Context, mgr manager.Manager) error { + wh := &admission.Webhook{ + Handler: pm, + } + + mutatePath := getPodMutatePath() + logger.Infof(ctx, "Registering path [%v]", mutatePath) + mgr.GetWebhookServer().Register(mutatePath, wh) + return nil +} + +func (pm PodMutator) GetMutatePath() string { + return getPodMutatePath() +} + +func getPodMutatePath() string { + pod := flytek8s.BuildIdentityPod() + return generateMutatePath(pod.GroupVersionKind()) +} + +func generateMutatePath(gvk schema.GroupVersionKind) string { + return "/mutate-" + strings.Replace(gvk.Group, ".", "-", -1) + "-" + + gvk.Version + "-" + strings.ToLower(gvk.Kind) +} + +func (pm PodMutator) CreateMutationWebhookConfiguration(namespace string) (*admissionregistrationv1.MutatingWebhookConfiguration, error) { + caBytes, err := ioutil.ReadFile(filepath.Join(pm.cfg.CertDir, "ca.crt")) + if err != nil { + // ca.crt is optional. If not provided, API Server will assume the webhook is serving SSL using a certificate + // issued by a known Cert Authority. + if os.IsNotExist(err) { + caBytes = make([]byte, 0) + } else { + return nil, err + } + } + + path := pm.GetMutatePath() + fail := admissionregistrationv1.Ignore + sideEffects := admissionregistrationv1.SideEffectClassNoneOnDryRun + + mutateConfig := &admissionregistrationv1.MutatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: pm.cfg.ServiceName, + Namespace: namespace, + }, + Webhooks: []admissionregistrationv1.MutatingWebhook{ + { + Name: webhookName, + ClientConfig: admissionregistrationv1.WebhookClientConfig{ + CABundle: caBytes, // CA bundle created earlier + Service: &admissionregistrationv1.ServiceReference{ + Name: pm.cfg.ServiceName, + Namespace: namespace, + Path: &path, + }, + }, + Rules: []admissionregistrationv1.RuleWithOperations{ + { + Operations: []admissionregistrationv1.OperationType{ + admissionregistrationv1.Create, + }, + Rule: admissionregistrationv1.Rule{ + APIGroups: []string{"*"}, + APIVersions: []string{"v1"}, + Resources: []string{"pods"}, + }, + }, + }, + FailurePolicy: &fail, + SideEffects: &sideEffects, + AdmissionReviewVersions: []string{ + "v1", + "v1beta1", + }, + ObjectSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + secrets.PodLabel: secrets.PodLabelValue, + }, + }, + }}, + } + + return mutateConfig, nil +} + +func NewPodMutator(cfg *Config, scope promutils.Scope) *PodMutator { + return &PodMutator{ + cfg: cfg, + Mutators: []MutatorConfig{ + { + Mutator: NewSecretsMutator(scope.NewSubScope("secrets")), + }, + }, + } +} diff --git a/pkg/webhook/pod_test.go b/pkg/webhook/pod_test.go new file mode 100644 index 000000000..d0af79cd1 --- /dev/null +++ b/pkg/webhook/pod_test.go @@ -0,0 +1,158 @@ +package webhook + +import ( + "context" + "fmt" + "testing" + + "k8s.io/client-go/tools/clientcmd/api/latest" + + "k8s.io/apimachinery/pkg/runtime" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/flyteorg/flytepropeller/pkg/webhook/mocks" + "github.com/stretchr/testify/mock" + + "github.com/stretchr/testify/assert" + admissionv1 "k8s.io/api/admission/v1" + corev1 "k8s.io/api/core/v1" +) + +func TestPodMutator_Mutate(t *testing.T) { + inputPod := &corev1.Pod{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "container1", + }, + }, + }, + } + + successMutator := &mocks.Mutator{} + successMutator.OnID().Return("SucceedingMutator") + successMutator.OnMutateMatch(mock.Anything, mock.Anything).Return(nil, false, nil) + + failedMutator := &mocks.Mutator{} + failedMutator.OnID().Return("FailingMutator") + failedMutator.OnMutateMatch(mock.Anything, mock.Anything).Return(nil, false, fmt.Errorf("failing mock")) + + t.Run("Required Mutator Succeeded", func(t *testing.T) { + pm := &PodMutator{ + Mutators: []MutatorConfig{ + { + Mutator: successMutator, + Required: true, + }, + }, + } + ctx := context.Background() + _, changed, err := pm.Mutate(ctx, inputPod.DeepCopy()) + assert.NoError(t, err) + assert.False(t, changed) + }) + + t.Run("Required Mutator Failed", func(t *testing.T) { + pm := &PodMutator{ + Mutators: []MutatorConfig{ + { + Mutator: failedMutator, + Required: true, + }, + }, + } + ctx := context.Background() + _, _, err := pm.Mutate(ctx, inputPod.DeepCopy()) + assert.Error(t, err) + }) + + t.Run("Non-required Mutator Failed", func(t *testing.T) { + pm := &PodMutator{ + Mutators: []MutatorConfig{ + { + Mutator: failedMutator, + Required: false, + }, + }, + } + ctx := context.Background() + _, _, err := pm.Mutate(ctx, inputPod.DeepCopy()) + assert.NoError(t, err) + }) +} + +func Test_CreateMutationWebhookConfiguration(t *testing.T) { + pm := NewPodMutator(&Config{ + CertDir: "testdata", + ServiceName: "my-service", + }, promutils.NewTestScope()) + + t.Run("Empty namespace", func(t *testing.T) { + c, err := pm.CreateMutationWebhookConfiguration("") + assert.NoError(t, err) + assert.NotNil(t, c) + }) + + t.Run("With namespace", func(t *testing.T) { + c, err := pm.CreateMutationWebhookConfiguration("my-namespace") + assert.NoError(t, err) + assert.NotNil(t, c) + }) +} + +func Test_Handle(t *testing.T) { + pm := NewPodMutator(&Config{ + CertDir: "testdata", + ServiceName: "my-service", + }, promutils.NewTestScope()) + + decoder, err := admission.NewDecoder(latest.Scheme) + assert.NoError(t, err) + assert.NoError(t, pm.InjectDecoder(decoder)) + + req := admission.Request{ + AdmissionRequest: admissionv1.AdmissionRequest{ + Object: runtime.RawExtension{ + Raw: []byte(`{ + "apiVersion": "v1", + "kind": "Pod", + "metadata": { + "name": "foo", + "namespace": "default" + }, + "spec": { + "containers": [ + { + "image": "bar:v2", + "name": "bar" + } + ] + } +}`), + }, + OldObject: runtime.RawExtension{ + Raw: []byte(`{ + "apiVersion": "v1", + "kind": "Pod", + "metadata": { + "name": "foo", + "namespace": "default" + }, + "spec": { + "containers": [ + { + "image": "bar:v1", + "name": "bar" + } + ] + } +}`), + }, + }, + } + + resp := pm.Handle(context.Background(), req) + assert.True(t, resp.Allowed) +} diff --git a/pkg/webhook/secrets.go b/pkg/webhook/secrets.go new file mode 100644 index 000000000..3609e99f9 --- /dev/null +++ b/pkg/webhook/secrets.go @@ -0,0 +1,62 @@ +package webhook + +import ( + "context" + + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" + + "github.com/flyteorg/flytestdlib/logger" + + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + secretUtils "github.com/flyteorg/flytepropeller/pkg/utils/secrets" + + corev1 "k8s.io/api/core/v1" +) + +type SecretsMutator struct { + injectors []SecretsInjector +} + +type SecretsInjector interface { + ID() string + Inject(ctx context.Context, secrets *core.Secret, p *corev1.Pod) (newP *corev1.Pod, injected bool, err error) +} + +func (s SecretsMutator) ID() string { + return "secrets" +} + +func (s *SecretsMutator) Mutate(ctx context.Context, p *corev1.Pod) (newP *corev1.Pod, injected bool, err error) { + secrets, err := secretUtils.UnmarshalStringMapToSecrets(p.GetAnnotations()) + if err != nil { + return p, false, err + } + + for _, secret := range secrets { + for _, injector := range s.injectors { + p, injected, err = injector.Inject(ctx, secret, p) + if err != nil { + logger.Infof(ctx, "Failed to inject a secret using injector [%v]. Error: %v", injector.ID(), err) + } else if injected { + break + } + } + + if err != nil { + return p, false, err + } + } + + return p, injected, nil +} + +func NewSecretsMutator(_ promutils.Scope) *SecretsMutator { + return &SecretsMutator{ + injectors: []SecretsInjector{ + NewGlobalSecrets(secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig())), + NewK8sSecretsInjector(), + }, + } +} diff --git a/pkg/webhook/secrets_test.go b/pkg/webhook/secrets_test.go new file mode 100644 index 000000000..246d926d9 --- /dev/null +++ b/pkg/webhook/secrets_test.go @@ -0,0 +1,61 @@ +package webhook + +import ( + "context" + "fmt" + "testing" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flytepropeller/pkg/webhook/mocks" + "github.com/stretchr/testify/mock" + + corev1 "k8s.io/api/core/v1" +) + +func TestSecretsWebhook_Mutate(t *testing.T) { + t.Run("No injectors", func(t *testing.T) { + m := SecretsMutator{} + _, changed, err := m.Mutate(context.Background(), &corev1.Pod{}) + assert.NoError(t, err) + assert.False(t, changed) + }) + + podWithAnnotations := &corev1.Pod{ + ObjectMeta: v1.ObjectMeta{ + Annotations: map[string]string{ + "flyte.secrets/s0": "nnsxsorcnv4v623fperca", + }, + }, + } + + t.Run("First fail", func(t *testing.T) { + mutator := &mocks.SecretsInjector{} + mutator.OnInjectMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil, false, fmt.Errorf("failed")) + mutator.OnID().Return("my id") + + m := SecretsMutator{ + injectors: []SecretsInjector{mutator}, + } + + _, changed, err := m.Mutate(context.Background(), podWithAnnotations.DeepCopy()) + assert.Error(t, err) + assert.False(t, changed) + }) + + t.Run("added", func(t *testing.T) { + mutator := &mocks.SecretsInjector{} + mutator.OnInjectMatch(mock.Anything, mock.Anything, mock.Anything).Return(&corev1.Pod{}, true, nil) + mutator.OnID().Return("my id") + + m := SecretsMutator{ + injectors: []SecretsInjector{mutator}, + } + + _, changed, err := m.Mutate(context.Background(), podWithAnnotations.DeepCopy()) + assert.NoError(t, err) + assert.True(t, changed) + }) +} diff --git a/pkg/webhook/testdata/ca.crt b/pkg/webhook/testdata/ca.crt new file mode 100644 index 000000000..5066d5a55 --- /dev/null +++ b/pkg/webhook/testdata/ca.crt @@ -0,0 +1 @@ +SGVsbG8gV29ybGQK diff --git a/pkg/webhook/utils.go b/pkg/webhook/utils.go new file mode 100644 index 000000000..471b1c212 --- /dev/null +++ b/pkg/webhook/utils.go @@ -0,0 +1,84 @@ +package webhook + +import ( + "path/filepath" + "strings" + + "github.com/flyteorg/flytepropeller/pkg/utils" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + corev1 "k8s.io/api/core/v1" +) + +func hasEnvVar(envVars []corev1.EnvVar, envVarKey string) bool { + for _, e := range envVars { + if e.Name == envVarKey { + return true + } + } + + return false +} + +func CreateEnvVarForSecret(secret *core.Secret) corev1.EnvVar { + return corev1.EnvVar{ + Name: strings.ToUpper(K8sDefaultEnvVarPrefix + secret.Group + EnvVarGroupKeySeparator + secret.Key), + ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: secret.Group, + }, + Key: secret.Key, + }, + }, + } +} + +func CreateVolumeForSecret(secret *core.Secret) corev1.Volume { + return corev1.Volume{ + Name: utils.Base32Encoder.EncodeToString([]byte(secret.Group + EnvVarGroupKeySeparator + secret.Key + EnvVarGroupKeySeparator + secret.GroupVersion)), + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: secret.Group, + Items: []corev1.KeyToPath{ + { + Key: secret.Key, + Path: secret.Key, + }, + }, + }, + }, + } +} + +func CreateVolumeMountForSecret(volumeName string, secret *core.Secret) corev1.VolumeMount { + return corev1.VolumeMount{ + Name: volumeName, + ReadOnly: true, + MountPath: filepath.Join(filepath.Join(K8sSecretPathPrefix...), secret.Group), + } +} + +func UpdateVolumeMounts(containers []corev1.Container, mount corev1.VolumeMount) []corev1.Container { + res := make([]corev1.Container, 0, len(containers)) + for _, c := range containers { + c.VolumeMounts = append(c.VolumeMounts, mount) + res = append(res, c) + } + + return res +} + +func UpdateEnvVars(containers []corev1.Container, envVar corev1.EnvVar) []corev1.Container { + res := make([]corev1.Container, 0, len(containers)) + for _, c := range containers { + if !hasEnvVar(c.Env, envVar.Name) { + c.Env = append(c.Env, envVar) + } + + res = append(res, c) + } + + return res +} diff --git a/pkg/webhook/utils_test.go b/pkg/webhook/utils_test.go new file mode 100644 index 000000000..f8e938f1b --- /dev/null +++ b/pkg/webhook/utils_test.go @@ -0,0 +1,178 @@ +package webhook + +import ( + "path/filepath" + "reflect" + "testing" + + corev1 "k8s.io/api/core/v1" +) + +func Test_hasEnvVar(t *testing.T) { + type args struct { + envVars []corev1.EnvVar + envVarKey string + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "exists", + args: args{ + envVars: []corev1.EnvVar{ + { + Name: "ENV_VAR_1", + }, + }, + envVarKey: "ENV_VAR_1", + }, + want: true, + }, + + { + name: "doesn't exist", + args: args{ + envVars: []corev1.EnvVar{ + { + Name: "ENV_VAR_1", + }, + }, + envVarKey: "ENV_VAR", + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := hasEnvVar(tt.args.envVars, tt.args.envVarKey); got != tt.want { + t.Errorf("hasEnvVar() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUpdateVolumeMounts(t *testing.T) { + type args struct { + containers []corev1.Container + volumeMount corev1.VolumeMount + } + tests := []struct { + name string + args args + want []corev1.Container + }{ + { + name: "volume", + args: args{ + containers: []corev1.Container{ + { + Name: "my_container", + }, + }, + volumeMount: corev1.VolumeMount{ + Name: "my_secret", + ReadOnly: true, + MountPath: filepath.Join(filepath.Join(K8sSecretPathPrefix...), "my_secret"), + }, + }, + want: []corev1.Container{ + { + Name: "my_container", + VolumeMounts: []corev1.VolumeMount{ + { + Name: "my_secret", + ReadOnly: true, + MountPath: filepath.Join(filepath.Join(K8sSecretPathPrefix...), "my_secret"), + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := UpdateVolumeMounts(tt.args.containers, tt.args.volumeMount); !reflect.DeepEqual(got, tt.want) { + t.Errorf("UpdateVolumeMounts() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUpdateEnvVars(t *testing.T) { + type args struct { + containers []corev1.Container + envVar corev1.EnvVar + } + tests := []struct { + name string + args args + want []corev1.Container + }{ + { + name: "env vars already exists", + args: args{ + containers: []corev1.Container{ + { + Name: "my_container", + Env: []corev1.EnvVar{ + { + Name: "my_secret", + Value: "my_val_already", + }, + }, + }, + }, + envVar: corev1.EnvVar{ + Name: "my_secret", + Value: "my_val", + }, + }, + want: []corev1.Container{ + { + Name: "my_container", + Env: []corev1.EnvVar{ + { + Name: "my_secret", + Value: "my_val_already", + }, + }, + }, + }, + }, + { + name: "env vars already added", + args: args{ + containers: []corev1.Container{ + { + Name: "my_container", + }, + }, + envVar: corev1.EnvVar{ + Name: "my_secret", + Value: "my_val", + }, + }, + want: []corev1.Container{ + { + Name: "my_container", + Env: []corev1.EnvVar{ + { + Name: "my_secret", + Value: "my_val", + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := UpdateEnvVars(tt.args.containers, tt.args.envVar); !reflect.DeepEqual(got, tt.want) { + t.Errorf("UpdateEnvVars() = %v, want %v", got, tt.want) + } + }) + } +}