Skip to content

Commit

Permalink
#minor Global Pod Security Context and Security context support for a…
Browse files Browse the repository at this point in the history
…ll pods (flyteorg#223)

* Security Context added

Signed-off-by: Ketan Umare <[email protected]>

* Spark and all pods in Flyte now can have global Pod and Container security context

Signed-off-by: Ketan Umare <[email protected]>

* updated docs

Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 authored Dec 3, 2021
1 parent ba6908a commit 592f234
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 21 deletions.
6 changes: 3 additions & 3 deletions flyteplugins/go/tasks/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright (c) 2018 Lyft. All rights reserved.
*/

// AWS-specific logic to handle execution and monitoring of batch jobs.
// Package aws contains AWS-specific logic to handle execution and monitoring of batch jobs.
package aws

import (
Expand All @@ -22,7 +22,7 @@ import (
)

const (
EnvSharedCredFilePath = "AWS_SHARED_CREDENTIALS_FILE"
EnvSharedCredFilePath = "AWS_SHARED_CREDENTIALS_FILE" // #nosec
EnvAwsProfile = "AWS_PROFILE"
ErrEmptyCredentials errors.ErrorCode = "EMPTY_CREDS"
ErrUnknownHost errors.ErrorCode = "UNKNOWN_HOST"
Expand All @@ -37,7 +37,7 @@ var single = singleton{
lock: sync.RWMutex{},
}

// A generic AWS Client that can be used for all AWS Client libraries.
// Client is a generic AWS Client that can be used for all AWS Client libraries.
type Client interface {
GetSession() *session.Session
GetSdkConfig() *aws.Config
Expand Down
10 changes: 10 additions & 0 deletions flyteplugins/go/tasks/config_load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ func TestLoadConfig(t *testing.T) {
assert.Equal(t, map[string]string{"x/interruptible": "true"}, k8sConfig.InterruptibleNodeSelector)
assert.Equal(t, "x/flyte", k8sConfig.InterruptibleTolerations[0].Key)
assert.Equal(t, "interruptible", k8sConfig.InterruptibleTolerations[0].Value)
assert.NotNil(t, k8sConfig.DefaultPodSecurityContext)
assert.NotNil(t, k8sConfig.DefaultPodSecurityContext.FSGroup)
assert.Equal(t, *k8sConfig.DefaultPodSecurityContext.FSGroup, int64(2000))
assert.NotNil(t, k8sConfig.DefaultPodSecurityContext.RunAsGroup)
assert.Equal(t, *k8sConfig.DefaultPodSecurityContext.RunAsGroup, int64(3000))
assert.NotNil(t, k8sConfig.DefaultPodSecurityContext.RunAsUser)
assert.Equal(t, *k8sConfig.DefaultPodSecurityContext.RunAsUser, int64(1000))
assert.NotNil(t, k8sConfig.DefaultSecurityContext)
assert.NotNil(t, k8sConfig.DefaultSecurityContext.AllowPrivilegeEscalation)
assert.False(t, *k8sConfig.DefaultSecurityContext.AllowPrivilegeEscalation)
})

t.Run("logs-config-test", func(t *testing.T) {
Expand Down
24 changes: 18 additions & 6 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// This package contains configuration for the flytek8s module.
// Package config contains configuration for the flytek8s module - which is global configuration for all Flyte K8s interactions.
// This config is under the subsection `k8s` and registered under the Plugin config
// All K8s based plugins can optionally use the flytek8s module and this configuration allows controlling the defaults
// For example if for every container execution if some default Environment Variables or Annotations should be used, then they can be configured here
Expand Down Expand Up @@ -58,9 +58,11 @@ var (
K8sPluginConfigSection = config.MustRegisterSubSection(k8sPluginConfigSectionKey, &defaultK8sConfig)
)

// Top level k8s plugin config.
// K8sPluginConfig should be used to configure per-pod defaults for the entire platform. This allows adding global defaults
// for pods that are being launched. For example, default annotations, labels, if a finalizer should be injected,
// if taints/tolerations should be used for certain resource types etc.
type K8sPluginConfig struct {
// Boolean flag that indicates if a finalizer should be injected into every K8s resource launched
// InjectFinalizer is a boolean flag that indicates if a finalizer should be injected into every K8s resource launched
InjectFinalizer bool `json:"inject-finalizer" pflag:",Instructs the plugin to inject a finalizer on startTask and remove it on task termination."`

// -------------------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -128,9 +130,19 @@ type K8sPluginConfig struct {
CreateContainerErrorGracePeriod config2.Duration `json:"create-container-error-grace-period" pflag:"-,Time to wait for transient CreateContainerError errors to be resolved."`

// The name of the GPU resource to use when the task resource requests GPUs.
GpuResourceName v1.ResourceName `json:"gpu-resource-name" pflag:",The name of the GPU resource to use when the task resource requests GPUs."`
GpuResourceName v1.ResourceName `json:"gpu-resource-name" pflag:"-,The name of the GPU resource to use when the task resource requests GPUs."`

// DefaultPodSecurityContext provides a default pod security context that should be applied for every pod that is launched by FlytePropeller. This may not be applicable to all plugins. For
// downstream plugins - i.e. TensorflowOperators may not support setting this, but Spark does.
DefaultPodSecurityContext *v1.PodSecurityContext `json:"default-pod-security-context" pflag:"-,Optionally specify any default pod security context that should be applied to every Pod launched by FlytePropeller."`

// DefaultSecurityContext provides a default container security context that should be applied for the primary container launched and created by FlytePropeller. This may not be applicable to all plugins. For
// downstream plugins - i.e. TensorflowOperators may not support setting this, but Spark does.
DefaultSecurityContext *v1.SecurityContext `json:"default-security-context" pflag:"-,Optionally specify a default security context that should be applied to every container launched/created by FlytePropeller. This will not be applied to plugins that do not support it or to user supplied containers in pod tasks."`
}

// FlyteCoPilotConfig specifies configuration for the Flyte CoPilot system. FlyteCoPilot, allows running flytekit-less containers
// in K8s, where the IO is managed by the FlyteCoPilot sidecar process.
type FlyteCoPilotConfig struct {
// Co-pilot sidecar container name
NamePrefix string `json:"name" pflag:",Flyte co-pilot sidecar container name prefix. (additional bits will be added after this)"`
Expand All @@ -153,12 +165,12 @@ type FlyteCoPilotConfig struct {
Storage string `json:"storage" pflag:",Default storage limit for individual inputs / outputs"`
}

// Retrieves the current k8s plugin config or default.
// GetK8sPluginConfig returns the k8s plugin configuration currently held by
// K8sPluginConfigSection. The section's value is type-asserted to
// *K8sPluginConfig, so the call panics if the section holds any other type.
func GetK8sPluginConfig() *K8sPluginConfig {
	current := K8sPluginConfigSection.GetConfig()
	return current.(*K8sPluginConfig)
}

// [FOR TESTING ONLY] Sets current value for the config.
// SetK8sPluginConfig replaces the value held by K8sPluginConfigSection with cfg
// and returns whatever error the section reports. It is intended for TESTING ONLY.
func SetK8sPluginConfig(cfg *K8sPluginConfig) error {
	err := K8sPluginConfigSection.SetConfig(cfg)
	return err
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ func ToK8sContainer(ctx context.Context, taskContainer *core.Container, iFace *c
if err := AddCoPilotToContainer(ctx, config.GetK8sPluginConfig().CoPilot, container, iFace, taskContainer.DataConfig); err != nil {
return nil, err
}
if container.SecurityContext == nil && config.GetK8sPluginConfig().DefaultSecurityContext != nil {
container.SecurityContext = config.GetK8sPluginConfig().DefaultSecurityContext.DeepCopy()
}
return container, nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ import (
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flytestdlib/storage"
"github.com/stretchr/testify/mock"
"k8s.io/apimachinery/pkg/util/validation"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/util/validation"
)

var zeroQuantity = resource.MustParse("0")
Expand Down Expand Up @@ -369,6 +371,13 @@ func TestToK8sContainer(t *testing.T) {
TaskExecMetadata: &mockTaskExecMetadata,
}

cfg := config.GetK8sPluginConfig()
allow := false
cfg.DefaultSecurityContext = &v1.SecurityContext{
AllowPrivilegeEscalation: &allow,
}
assert.NoError(t, config.SetK8sPluginConfig(cfg))

container, err := ToK8sContainer(context.TODO(), taskContainer, nil, templateParameters)
assert.NoError(t, err)
assert.Equal(t, container.Image, "myimage")
Expand All @@ -390,6 +399,8 @@ func TestToK8sContainer(t *testing.T) {
}, container.Env)
errs := validation.IsDNS1123Label(container.Name)
assert.Nil(t, errs)
assert.NotNil(t, container.SecurityContext)
assert.False(t, *container.SecurityContext.AllowPrivilegeEscalation)
}

func getTemplateParametersForTest(resourceRequirements, platformResources *v1.ResourceRequirements) template.Parameters {
Expand Down
19 changes: 17 additions & 2 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const OOMKilled = "OOMKilled"
const Interrupted = "Interrupted"
const SIGKILL = 137

// ApplyInterruptibleNodeAffinity configures the node-affinity for the pod using the configuration specified.
func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) {
// Determine node selector terms to add to node affinity
var nodeSelectorRequirement v1.NodeSelectorRequirement
Expand Down Expand Up @@ -58,13 +59,13 @@ func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) {
}
}

// Updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options
// UpdatePod updates the base pod spec used to execute tasks, using plugin
// configuration and task metadata-specific options. It delegates to
// UpdatePodWithInterruptibleFlag with omitInterruptible set to false.
func UpdatePod(taskExecutionMetadata pluginsCore.TaskExecutionMetadata,
	resourceRequirements []v1.ResourceRequirements, podSpec *v1.PodSpec) {
	// Interruptible handling is never omitted on this entry point.
	const omitInterruptible = false
	UpdatePodWithInterruptibleFlag(taskExecutionMetadata, resourceRequirements, podSpec, omitInterruptible)
}

// Updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options
// UpdatePodWithInterruptibleFlag updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options
func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecutionMetadata,
resourceRequirements []v1.ResourceRequirements, podSpec *v1.PodSpec, omitInterruptible bool) {
isInterruptible := !omitInterruptible && taskExecutionMetadata.IsInterruptible()
Expand All @@ -87,12 +88,18 @@ func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecut
if podSpec.Affinity == nil && config.GetK8sPluginConfig().DefaultAffinity != nil {
podSpec.Affinity = config.GetK8sPluginConfig().DefaultAffinity.DeepCopy()
}
if podSpec.SecurityContext == nil && config.GetK8sPluginConfig().DefaultPodSecurityContext != nil {
podSpec.SecurityContext = config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy()
}
ApplyInterruptibleNodeAffinity(isInterruptible, podSpec)
}

// ToK8sPodSpec constructs a pod spec for the task read through tCtx. It
// delegates to ToK8sPodSpecWithInterruptible with omitInterruptible set to false
// and returns that call's pod spec and error unchanged.
func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) {
	// Interruptible handling is never omitted on this entry point.
	const omitInterruptible = false
	return ToK8sPodSpecWithInterruptible(ctx, tCtx, omitInterruptible)
}

// ToK8sPodSpecWithInterruptible constructs a pod spec from the given TaskTemplate and optionally adds (interruptible instance) support.
func ToK8sPodSpecWithInterruptible(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, omitInterruptible bool) (*v1.PodSpec, error) {
task, err := tCtx.TaskReader().Read(ctx)
if err != nil {
Expand Down Expand Up @@ -154,6 +161,10 @@ func BuildIdentityPod() *v1.Pod {
}
}

// DemystifyPending is one of the core functions that helps FlytePropeller determine if a pending pod is indeed pending,
// or if it is actually stuck in an unrecoverable state. In such a case the pod should be marked as dead and the task should
// be retried. This has to be handled sadly, as K8s is still largely designed for long running services that should
// recover from failures, but Flyte pods are completely automated and should either run or fail
// Important considerations.
// Pending Status in Pod could be for various reasons and sometimes could signal a problem
// Case I: Pending because the Image pull is failing and it is backing off
Expand Down Expand Up @@ -308,6 +319,9 @@ func DemystifySuccess(status v1.PodStatus, info pluginsCore.TaskInfo) (pluginsCo
return pluginsCore.PhaseInfoSuccess(&info), nil
}

// DeterminePrimaryContainerPhase as the name suggests, given all the containers, will return a pluginsCore.PhaseInfo object
// corresponding to the phase of the primaryContainer which is identified using the provided name.
// This is useful in case of sidecars or pod jobs, where Flyte will monitor successful exit of a single container.
func DeterminePrimaryContainerPhase(primaryContainerName string, statuses []v1.ContainerStatus, info *pluginsCore.TaskInfo) pluginsCore.PhaseInfo {
for _, s := range statuses {
if s.Name == primaryContainerName {
Expand All @@ -330,6 +344,7 @@ func DeterminePrimaryContainerPhase(primaryContainerName string, statuses []v1.C
fmt.Sprintf("Primary container [%s] not found in pod's container statuses", primaryContainerName), info)
}

// ConvertPodFailureToError returns a legible error message and code from a failed v1.PodStatus field
func ConvertPodFailureToError(status v1.PodStatus) (code, message string) {
code = "UnknownError"
message = "Pod failed. No message received from kubernetes."
Expand Down
16 changes: 16 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,22 @@ func TestToK8sPod(t *testing.T) {
assert.Equal(t, 1, len(p.NodeSelector))
assert.Equal(t, "myScheduler", p.SchedulerName)
assert.Equal(t, "some-acceptable-name", p.Containers[0].Name)
assert.Nil(t, p.SecurityContext)
})

t.Run("default-pod-sec-ctx", func(t *testing.T) {
v := int64(1000)
assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
DefaultPodSecurityContext: &v1.PodSecurityContext{
RunAsGroup: &v,
},
}))

x := dummyExecContext(&v1.ResourceRequirements{})
p, err := ToK8sPodSpec(ctx, x)
assert.NoError(t, err)
assert.NotNil(t, p.SecurityContext)
assert.Equal(t, *p.SecurityContext.RunAsGroup, v)
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,9 @@ plugins:
- FLYTE_AWS_ENDPOINT: "http://minio.flyte:9000"
- FLYTE_AWS_ACCESS_KEY_ID: minio
- FLYTE_AWS_SECRET_ACCESS_KEY: miniostorage
default-pod-security-context:
runAsUser: 1000
runAsGroup: 3000
fsGroup: 2000
default-security-context:
allowPrivilegeEscalation: false
18 changes: 10 additions & 8 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,22 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
}
driverSpec := sparkOp.DriverSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Annotations: annotations,
Labels: labels,
EnvVars: sparkEnvVars,
Image: &container.Image,
Annotations: annotations,
Labels: labels,
EnvVars: sparkEnvVars,
Image: &container.Image,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
},
ServiceAccount: &serviceAccountName,
}

executorSpec := sparkOp.ExecutorSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Annotations: annotations,
Labels: labels,
Image: &container.Image,
EnvVars: sparkEnvVars,
Annotations: annotations,
Labels: labels,
Image: &container.Image,
EnvVars: sparkEnvVars,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
},
}

Expand Down
6 changes: 6 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,11 @@ func TestBuildResourceSpark(t *testing.T) {
}))

// Set Interruptible Config
runAsUser := int64(1000)
assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
DefaultPodSecurityContext: &corev1.PodSecurityContext{
RunAsUser: &runAsUser,
},
InterruptibleNodeSelector: map[string]string{
"x/interruptible": "true",
},
Expand All @@ -373,6 +377,8 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, sj.PythonApplicationType, sparkApp.Spec.Type)
assert.Equal(t, testArgs, sparkApp.Spec.Arguments)
assert.Equal(t, testImage, *sparkApp.Spec.Image)
assert.NotNil(t, sparkApp.Spec.Driver.SparkPodSpec.SecurityContenxt)
assert.Equal(t, *sparkApp.Spec.Driver.SparkPodSpec.SecurityContenxt.RunAsUser, runAsUser)

//Validate Driver/Executor Spec.
driverCores, _ := strconv.ParseInt(dummySparkConf["spark.driver.cores"], 10, 32)
Expand Down
6 changes: 6 additions & 0 deletions flyteplugins/go/tasks/testdata/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ plugins:
- FLYTE_AWS_ENDPOINT: "http://minio.flyte:9000"
- FLYTE_AWS_ACCESS_KEY_ID: minio
- FLYTE_AWS_SECRET_ACCESS_KEY: miniostorage
default-pod-security-context:
runAsUser: 1000
runAsGroup: 3000
fsGroup: 2000
default-security-context:
allowPrivilegeEscalation: false
# Spark Plugin configuration
spark:
spark-config-default:
Expand Down

0 comments on commit 592f234

Please sign in to comment.