diff --git a/go.mod b/go.mod index 01686634b..d1d2d2ec0 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/athena v1.0.0 github.com/bstadlbauer/dask-k8s-operator-go-client v0.1.0 github.com/coocood/freecache v1.1.1 - github.com/flyteorg/flyteidl v1.3.2 + github.com/flyteorg/flyteidl v1.3.5 github.com/flyteorg/flytestdlib v1.0.11 github.com/go-test/deep v1.0.7 github.com/golang/protobuf v1.5.2 diff --git a/go.sum b/go.sum index 142b57cf9..e06ba1a08 100644 --- a/go.sum +++ b/go.sum @@ -283,8 +283,8 @@ github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGE github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= -github.com/flyteorg/flyteidl v1.3.2 h1:s4DC8go2ou5LtZ+CFcS31r0mhv3baelNV81C1KZS26U= -github.com/flyteorg/flyteidl v1.3.2/go.mod h1:OJAq333OpInPnMhvVz93AlEjmlQ+t0FAD4aakIYE4OU= +github.com/flyteorg/flyteidl v1.3.5 h1:rSaWMndeENr0QxRKj02kp6N/qQdbgDwpFeZsZbvU45A= +github.com/flyteorg/flyteidl v1.3.5/go.mod h1:OJAq333OpInPnMhvVz93AlEjmlQ+t0FAD4aakIYE4OU= github.com/flyteorg/flytestdlib v1.0.0/go.mod h1:QSVN5wIM1lM9d60eAEbX7NwweQXW96t5x4jbyftn89c= github.com/flyteorg/flytestdlib v1.0.11 h1:f7B8x2/zMuimEVi4Jx0zqzvNhdi7aq7+ZWoqHsbp4F4= github.com/flyteorg/flytestdlib v1.0.11/go.mod h1:nIBmBHtjTJvhZEn3e/EwVC/iMkR2tUX8hEiXjRBpH/s= diff --git a/go/tasks/pluginmachinery/flytek8s/container_helper.go b/go/tasks/pluginmachinery/flytek8s/container_helper.go index 5085fbd6c..60eee5614 100755 --- a/go/tasks/pluginmachinery/flytek8s/container_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/container_helper.go @@ -3,18 +3,19 @@ package flytek8s import ( "context" - "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteplugins/go/tasks/errors" + pluginscore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" - "k8s.io/apimachinery/pkg/util/validation" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + + "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/util/rand" - - "github.com/flyteorg/flyteplugins/go/tasks/errors" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + "k8s.io/apimachinery/pkg/util/validation" ) const resourceGPU = "gpu" @@ -193,22 +194,16 @@ func ApplyResourceOverrides(resources, platformResources v1.ResourceRequirements return resources } -// ToK8sContainer transforms a task template target of type core.Container into a bare-bones kubernetes container, which -// can be further modified with flyte-specific customizations specified by various static and run-time attributes. -func ToK8sContainer(ctx context.Context, taskContainer *core.Container, iFace *core.TypedInterface, parameters template.Parameters) (*v1.Container, error) { - // Perform preliminary validations - if parameters.TaskExecMetadata.GetOverrides() == nil { - return nil, errors.Errorf(errors.BadTaskSpecification, "platform/compiler error, overrides not set for task") - } - if parameters.TaskExecMetadata.GetOverrides() == nil || parameters.TaskExecMetadata.GetOverrides().GetResources() == nil { - return nil, errors.Errorf(errors.BadTaskSpecification, "resource requirements not found for container task, required!") - } +// BuildRawContainer constructs a Container based on the definition passed by the taskContainer and +// TaskExecutionMetadata. +func BuildRawContainer(ctx context.Context, taskContainer *core.Container, taskExecMetadata pluginscore.TaskExecutionMetadata) (*v1.Container, error) { // Make the container name the same as the pod name, unless it violates K8s naming conventions // Container names are subject to the DNS-1123 standard - containerName := parameters.TaskExecMetadata.GetTaskExecutionID().GetGeneratedName() + containerName := taskExecMetadata.GetTaskExecutionID().GetGeneratedName() if errs := validation.IsDNS1123Label(containerName); len(errs) > 0 { containerName = rand.String(4) } + container := &v1.Container{ Name: containerName, Image: taskContainer.GetImage(), @@ -217,12 +212,49 @@ func ToK8sContainer(ctx context.Context, taskContainer *core.Container, iFace *c Env: ToK8sEnvVar(taskContainer.GetEnv()), TerminationMessagePolicy: v1.TerminationMessageFallbackToLogsOnError, } - if err := AddCoPilotToContainer(ctx, config.GetK8sPluginConfig().CoPilot, container, iFace, taskContainer.DataConfig); err != nil { + + return container, nil +} + +// ToK8sContainer builds a Container based on the definition passed by the TaskExecutionContext. This involves applying +// all Flyte configuration including k8s plugins and resource requests. +func ToK8sContainer(ctx context.Context, tCtx pluginscore.TaskExecutionContext) (*v1.Container, error) { + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + logger.Warnf(ctx, "failed to read task information when trying to construct container, err: %s", err.Error()) + return nil, err + } + + // validate arguments + if taskTemplate.GetContainer() == nil { + return nil, errors.Errorf(errors.BadTaskSpecification, "unable to create container with no definition in TaskTemplate") + } + if tCtx.TaskExecutionMetadata().GetOverrides() == nil || tCtx.TaskExecutionMetadata().GetOverrides().GetResources() == nil { + return nil, errors.Errorf(errors.BadTaskSpecification, "resource requirements not found for container task, required!") + } + + // build raw container + container, err := BuildRawContainer(ctx, taskTemplate.GetContainer(), tCtx.TaskExecutionMetadata()) + if err != nil { return nil, err } + if container.SecurityContext == nil && config.GetK8sPluginConfig().DefaultSecurityContext != nil { container.SecurityContext = config.GetK8sPluginConfig().DefaultSecurityContext.DeepCopy() } + + // add flyte resource customizations to the container + templateParameters := template.Parameters{ + TaskExecMetadata: tCtx.TaskExecutionMetadata(), + Inputs: tCtx.InputReader(), + OutputPath: tCtx.OutputWriter(), + Task: tCtx.TaskReader(), + } + + if err := AddFlyteCustomizationsToContainer(ctx, templateParameters, ResourceCustomizationModeAssignResources, container); err != nil { + return nil, err + } + return container, nil } diff --git a/go/tasks/pluginmachinery/flytek8s/container_helper_test.go b/go/tasks/pluginmachinery/flytek8s/container_helper_test.go index 1c45feccb..f0ef5feaf 100755 --- a/go/tasks/pluginmachinery/flytek8s/container_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/container_helper_test.go @@ -335,26 +335,45 @@ func TestMergeResources_PartialResourceKeys(t *testing.T) { } func TestToK8sContainer(t *testing.T) { - taskContainer := &core.Container{ - Image: "myimage", - Args: []string{ - "arg1", - "arg2", - "arg3", - }, - Command: []string{ - "com1", - "com2", - "com3", - }, - Env: []*core.KeyValuePair{ - { - Key: "k", - Value: "v", + taskTemplate := &core.TaskTemplate{ + Type: "test", + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: "myimage", + Args: []string{ + "arg1", + "arg2", + "arg3", + }, + Command: []string{ + "com1", + "com2", + "com3", + }, + Env: []*core.KeyValuePair{ + { + Key: "k", + Value: "v", + }, + }, }, }, } + taskReader := &mocks.TaskReader{} + taskReader.On("Read", mock.Anything).Return(taskTemplate, nil) + + inputReader := &mocks2.InputReader{} + inputReader.OnGetInputPath().Return(storage.DataReference("test-data-reference")) + inputReader.OnGetInputPrefixPath().Return(storage.DataReference("test-data-reference-prefix")) + inputReader.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) + + outputWriter := &mocks2.OutputWriter{} + outputWriter.OnGetOutputPrefixPath().Return("") + outputWriter.OnGetRawOutputPrefix().Return("") + outputWriter.OnGetCheckpointPrefix().Return("/checkpoint") + outputWriter.OnGetPreviousCheckpointsPrefix().Return("/prev") + mockTaskExecMetadata := mocks.TaskExecutionMetadata{} mockTaskOverrides := mocks.TaskOverrides{} mockTaskOverrides.OnGetResources().Return(&v1.ResourceRequirements{ @@ -364,12 +383,16 @@ func TestToK8sContainer(t *testing.T) { }) mockTaskExecMetadata.OnGetOverrides().Return(&mockTaskOverrides) mockTaskExecutionID := mocks.TaskExecutionID{} + mockTaskExecutionID.OnGetID().Return(core.TaskExecutionIdentifier{}) mockTaskExecutionID.OnGetGeneratedName().Return("gen_name") mockTaskExecMetadata.OnGetTaskExecutionID().Return(&mockTaskExecutionID) + mockTaskExecMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) - templateParameters := template.Parameters{ - TaskExecMetadata: &mockTaskExecMetadata, - } + tCtx := &mocks.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(&mockTaskExecMetadata) + tCtx.OnInputReader().Return(inputReader) + tCtx.OnTaskReader().Return(taskReader) + tCtx.OnOutputWriter().Return(outputWriter) cfg := config.GetK8sPluginConfig() allow := false @@ -378,7 +401,7 @@ func TestToK8sContainer(t *testing.T) { } assert.NoError(t, config.SetK8sPluginConfig(cfg)) - container, err := ToK8sContainer(context.TODO(), taskContainer, nil, templateParameters) + container, err := ToK8sContainer(context.TODO(), tCtx) assert.NoError(t, err) assert.Equal(t, container.Image, "myimage") assert.EqualValues(t, []string{ diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index e9736ff6a..e5cf7cefa 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -7,6 +7,9 @@ import ( "strings" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + pluginserrors "github.com/flyteorg/flyteplugins/go/tasks/errors" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" @@ -17,15 +20,17 @@ import ( "github.com/imdario/mergo" v1 "k8s.io/api/core/v1" - v12 "k8s.io/apimachinery/pkg/apis/meta/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) const PodKind = "pod" const OOMKilled = "OOMKilled" const Interrupted = "Interrupted" const SIGKILL = 137 + const defaultContainerTemplateName = "default" const primaryContainerTemplateName = "primary" +const PrimaryContainerKey = "primary_container_name" // ApplyInterruptibleNodeSelectorRequirement configures the node selector requirement of the node-affinity using the configuration specified. func ApplyInterruptibleNodeSelectorRequirement(interruptible bool, affinity *v1.Affinity) { @@ -104,67 +109,237 @@ func UpdatePod(taskExecutionMetadata pluginsCore.TaskExecutionMetadata, ApplyInterruptibleNodeAffinity(taskExecutionMetadata.IsInterruptible(), podSpec) } -// ToK8sPodSpec constructs a pod spec from the given TaskTemplate -func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) { - task, err := tCtx.TaskReader().Read(ctx) +func mergeMapInto(src map[string]string, dst map[string]string) { + for key, value := range src { + dst[key] = value + } +} + +// BuildRawPod constructs a PodSpec and ObjectMeta based on the definition passed by the TaskExecutionContext. This +// definition does not include any configuration injected by Flyte. +func BuildRawPod(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, *metav1.ObjectMeta, string, error) { + taskTemplate, err := tCtx.TaskReader().Read(ctx) if err != nil { logger.Warnf(ctx, "failed to read task information when trying to construct Pod, err: %s", err.Error()) - return nil, err + return nil, nil, "", err + } + + var podSpec *v1.PodSpec + objectMeta := metav1.ObjectMeta{ + Annotations: make(map[string]string), + Labels: make(map[string]string), + } + primaryContainerName := "" + + switch target := taskTemplate.GetTarget().(type) { + case *core.TaskTemplate_Container: + // handles tasks defined by a single container + c, err := ToK8sContainer(ctx, tCtx) + if err != nil { + return nil, nil, "", err + } + + primaryContainerName = c.Name + podSpec = &v1.PodSpec{ + Containers: []v1.Container{ + *c, + }, + } + case *core.TaskTemplate_K8SPod: + // handles pod tasks that marshal the pod spec to the k8s_pod task target. + if target.K8SPod.PodSpec == nil { + return nil, nil, "", pluginserrors.Errorf(pluginserrors.BadTaskSpecification, + "Pod tasks with task type version > 1 should specify their target as a K8sPod with a defined pod spec") + } + + err := utils.UnmarshalStructToObj(target.K8SPod.PodSpec, &podSpec) + if err != nil { + return nil, nil, "", pluginserrors.Errorf(pluginserrors.BadTaskSpecification, + "Unable to unmarshal task k8s pod [%v], Err: [%v]", target.K8SPod.PodSpec, err.Error()) + } + + // get primary container name + var ok bool + if primaryContainerName, ok = taskTemplate.GetConfig()[PrimaryContainerKey]; !ok { + return nil, nil, "", pluginserrors.Errorf(pluginserrors.BadTaskSpecification, + "invalid TaskSpecification, config missing [%s] key in [%v]", PrimaryContainerKey, taskTemplate.GetConfig()) + } + + // update annotations and labels + if taskTemplate.GetK8SPod().Metadata != nil { + mergeMapInto(target.K8SPod.Metadata.Annotations, objectMeta.Annotations) + mergeMapInto(target.K8SPod.Metadata.Labels, objectMeta.Labels) + } + default: + return nil, nil, "", pluginserrors.Errorf(pluginserrors.BadTaskSpecification, + "invalid TaskSpecification, unable to determine Pod configuration") } - if task.GetContainer() == nil { - logger.Errorf(ctx, "Default Pod creation logic works for default container in the task template only.") - return nil, fmt.Errorf("container not specified in task template") + + return podSpec, &objectMeta, primaryContainerName, nil +} + +// ApplyFlytePodConfiguration updates the PodSpec and ObjectMeta with various Flyte configuration. This includes +// applying default k8s configuration, resource requests, injecting copilot containers, and merging with the +// configuration PodTemplate (if exists). +func ApplyFlytePodConfiguration(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, primaryContainerName string) (*v1.PodSpec, *metav1.ObjectMeta, error) { + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + logger.Warnf(ctx, "failed to read task information when trying to construct Pod, err: %s", err.Error()) + return nil, nil, err } + + // add flyte resource customizations to containers templateParameters := template.Parameters{ - Task: tCtx.TaskReader(), Inputs: tCtx.InputReader(), OutputPath: tCtx.OutputWriter(), + Task: tCtx.TaskReader(), TaskExecMetadata: tCtx.TaskExecutionMetadata(), } - c, err := ToK8sContainer(ctx, task.GetContainer(), task.Interface, templateParameters) + + resourceRequests := make([]v1.ResourceRequirements, 0, len(podSpec.Containers)) + var primaryContainer *v1.Container + for index, container := range podSpec.Containers { + var resourceMode = ResourceCustomizationModeEnsureExistingResourcesInRange + if container.Name == primaryContainerName { + resourceMode = ResourceCustomizationModeMergeExistingResources + } + + if err := AddFlyteCustomizationsToContainer(ctx, templateParameters, resourceMode, &podSpec.Containers[index]); err != nil { + return nil, nil, err + } + + resourceRequests = append(resourceRequests, podSpec.Containers[index].Resources) + if container.Name == primaryContainerName { + primaryContainer = &podSpec.Containers[index] + } + } + + if primaryContainer == nil { + return nil, nil, pluginserrors.Errorf(pluginserrors.BadTaskSpecification, "invalid TaskSpecification, primary container [%s] not defined", primaryContainerName) + } + + // add copilot configuration to primaryContainer and PodSpec (if necessary) + if taskTemplate.GetContainer() != nil { + if err := AddCoPilotToContainer(ctx, config.GetK8sPluginConfig().CoPilot, primaryContainer, + taskTemplate.Interface, taskTemplate.GetContainer().DataConfig); err != nil { + return nil, nil, err + } + + if err := AddCoPilotToPod(ctx, config.GetK8sPluginConfig().CoPilot, podSpec, taskTemplate.GetInterface(), + tCtx.TaskExecutionMetadata(), tCtx.InputReader(), tCtx.OutputWriter(), taskTemplate.GetContainer().GetDataConfig()); err != nil { + return nil, nil, err + } + } + + // update primaryContainer and PodSpec with k8s plugin configuration, etc + UpdatePod(tCtx.TaskExecutionMetadata(), resourceRequests, podSpec) + if primaryContainer.SecurityContext == nil && config.GetK8sPluginConfig().DefaultSecurityContext != nil { + primaryContainer.SecurityContext = config.GetK8sPluginConfig().DefaultSecurityContext.DeepCopy() + } + + // merge PodSpec and ObjectMeta with configuration pod template (if exists) + podSpec, objectMeta, err = MergeWithBasePodTemplate(ctx, tCtx, podSpec, objectMeta, primaryContainerName) if err != nil { - return nil, err + return nil, nil, err } - err = AddFlyteCustomizationsToContainer(ctx, templateParameters, ResourceCustomizationModeAssignResources, c) + + return podSpec, objectMeta, nil +} + +// ToK8sPodSpec builds a PodSpec and ObjectMeta based on the definition passed by the TaskExecutionContext. This +// involves parsing the raw PodSpec definition and applying all Flyte configuration options. +func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, *metav1.ObjectMeta, error) { + // build raw PodSpec and ObjectMeta + podSpec, objectMeta, primaryContainerName, err := BuildRawPod(ctx, tCtx) if err != nil { - return nil, err + return nil, nil, err } - containers := []v1.Container{ - *c, + // add flyte configuration + podSpec, objectMeta, err = ApplyFlytePodConfiguration(ctx, tCtx, podSpec, objectMeta, primaryContainerName) + if err != nil { + return nil, nil, err } - pod := &v1.PodSpec{ - Containers: containers, + + return podSpec, objectMeta, nil +} + +// getBasePodTemplate attempts to retrieve the PodTemplate to use as the base for k8s Pod configuration. This value can +// come from one of the following: +// (1) PodTemplate name in the TaskMetadata: This name is then looked up in the PodTemplateStore. +// (2) Default PodTemplate name from configuration: This name is then looked up in the PodTemplateStore. +func getBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, podTemplateStore PodTemplateStore) (*v1.PodTemplate, error) { + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return nil, pluginserrors.Errorf(pluginserrors.BadTaskSpecification, "TaskSpecification cannot be read, Err: [%v]", err.Error()) } - UpdatePod(tCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{c.Resources}, pod) - if err := AddCoPilotToPod(ctx, config.GetK8sPluginConfig().CoPilot, pod, task.GetInterface(), tCtx.TaskExecutionMetadata(), tCtx.InputReader(), tCtx.OutputWriter(), task.GetContainer().GetDataConfig()); err != nil { - return nil, err + var podTemplate *v1.PodTemplate + if taskTemplate.Metadata != nil && len(taskTemplate.Metadata.PodTemplateName) > 0 { + // retrieve PodTemplate by name from PodTemplateStore + podTemplate = podTemplateStore.LoadOrDefault(tCtx.TaskExecutionMetadata().GetNamespace(), taskTemplate.Metadata.PodTemplateName) + if podTemplate == nil { + return nil, pluginserrors.Errorf(pluginserrors.BadTaskSpecification, "PodTemplate '%s' does not exist", taskTemplate.Metadata.PodTemplateName) + } + } else { + // check for default PodTemplate + podTemplate = podTemplateStore.LoadOrDefault(tCtx.TaskExecutionMetadata().GetNamespace(), config.GetK8sPluginConfig().DefaultPodTemplateName) } - return pod, nil + return podTemplate, nil } -func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string) (*v1.PodSpec, error) { - if podTemplatePodSpec == nil || podSpec == nil { - return nil, errors.New("podTemplatePodSpec and podSpec cannot be nil") - } +// MergeWithBasePodTemplate attempts to merge the provided PodSpec and ObjectMeta with the configuration PodTemplate for +// this task. +func MergeWithBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, + podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, primaryContainerName string) (*v1.PodSpec, *metav1.ObjectMeta, error) { - var podTemplatePodSpecCopy *v1.PodSpec = podTemplatePodSpec.DeepCopy() + // attempt to retrieve base PodTemplate + podTemplate, err := getBasePodTemplate(ctx, tCtx, DefaultPodTemplateStore) + if err != nil { + return nil, nil, err + } else if podTemplate == nil { + // if no PodTemplate to merge as base -> return + return podSpec, objectMeta, nil + } - err := mergo.Merge(podTemplatePodSpecCopy, podSpec, mergo.WithOverride, mergo.WithAppendSlice) + // merge podSpec with podTemplate + mergedPodSpec, err := mergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName) if err != nil { + return nil, nil, err + } + + // merge PodTemplate PodSpec with podSpec + var mergedObjectMeta *metav1.ObjectMeta = podTemplate.Template.ObjectMeta.DeepCopy() + if err := mergo.Merge(mergedObjectMeta, objectMeta, mergo.WithOverride, mergo.WithAppendSlice); err != nil { + return nil, nil, err + } + + return mergedPodSpec, mergedObjectMeta, nil +} + +// mergePodSpecs merges the two provided PodSpecs. This process uses the first as the base configuration, where values +// set by the first PodSpec are overwritten by the second in the return value. Additionally, this function applies +// container-level configuration from the basePodSpec. +func mergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string) (*v1.PodSpec, error) { + if basePodSpec == nil || podSpec == nil { + return nil, errors.New("neither the basePodSpec or the podSpec can be nil") + } + + // merge PodTemplate PodSpec with podSpec + var mergedPodSpec *v1.PodSpec = basePodSpec.DeepCopy() + if err := mergo.Merge(mergedPodSpec, podSpec, mergo.WithOverride, mergo.WithAppendSlice); err != nil { return nil, err } // merge template Containers var mergedContainers []v1.Container var defaultContainerTemplate, primaryContainerTemplate *v1.Container - for i := 0; i < len(podTemplatePodSpecCopy.Containers); i++ { - if podTemplatePodSpecCopy.Containers[i].Name == defaultContainerTemplateName { - defaultContainerTemplate = &podTemplatePodSpecCopy.Containers[i] - } else if podTemplatePodSpecCopy.Containers[i].Name == primaryContainerTemplateName { - primaryContainerTemplate = &podTemplatePodSpecCopy.Containers[i] + for i := 0; i < len(mergedPodSpec.Containers); i++ { + if mergedPodSpec.Containers[i].Name == defaultContainerTemplateName { + defaultContainerTemplate = &mergedPodSpec.Containers[i] + } else if mergedPodSpec.Containers[i].Name == primaryContainerTemplateName { + primaryContainerTemplate = &mergedPodSpec.Containers[i] } } @@ -187,10 +362,9 @@ func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryC } } - // if applicable merge with existing container # TODO test + // if applicable merge with existing container if mergedContainer == nil { mergedContainers = append(mergedContainers, container) - } else { err := mergo.Merge(mergedContainer, container, mergo.WithOverride, mergo.WithAppendSlice) if err != nil { @@ -199,43 +373,15 @@ func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryC mergedContainers = append(mergedContainers, *mergedContainer) } - - } - - // update Pod fields - podTemplatePodSpecCopy.Containers = mergedContainers - - return podTemplatePodSpecCopy, nil -} - -func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryContainerName string) (*v1.Pod, error) { - pod := v1.Pod{ - TypeMeta: v12.TypeMeta{ - Kind: PodKind, - APIVersion: v1.SchemeGroupVersion.String(), - }, - } - - if podTemplate != nil { - // merge template PodSpec - mergedPodSpec, err := MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName) - if err != nil { - return nil, err - } - - pod.ObjectMeta = podTemplate.Template.ObjectMeta - pod.Spec = *mergedPodSpec - - } else { - pod.Spec = *podSpec } - return &pod, nil + mergedPodSpec.Containers = mergedContainers + return mergedPodSpec, nil } func BuildIdentityPod() *v1.Pod { return &v1.Pod{ - TypeMeta: v12.TypeMeta{ + TypeMeta: metav1.TypeMeta{ Kind: PodKind, APIVersion: v1.SchemeGroupVersion.String(), }, @@ -506,8 +652,8 @@ func DemystifyFailure(status v1.PodStatus, info pluginsCore.TaskInfo) (pluginsCo return pluginsCore.PhaseInfoRetryableFailure(code, message, &info), nil } -func GetLastTransitionOccurredAt(pod *v1.Pod) v12.Time { - var lastTransitionTime v12.Time +func GetLastTransitionOccurredAt(pod *v1.Pod) metav1.Time { + var lastTransitionTime metav1.Time containerStatuses := append(pod.Status.ContainerStatuses, pod.Status.InitContainerStatuses...) for _, containerStatus := range containerStatuses { if r := containerStatus.LastTerminationState.Running; r != nil { @@ -522,7 +668,7 @@ func GetLastTransitionOccurredAt(pod *v1.Pod) v12.Time { } if lastTransitionTime.IsZero() { - lastTransitionTime = v12.NewTime(time.Now()) + lastTransitionTime = metav1.NewTime(time.Now()) } return lastTransitionTime diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 716807452..59ea6c4dd 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -9,25 +9,24 @@ import ( "testing" "time" - config1 "github.com/flyteorg/flytestdlib/config" - "github.com/flyteorg/flytestdlib/config/viper" - - "github.com/flyteorg/flytestdlib/storage" - "github.com/stretchr/testify/mock" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" + pluginsIOMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + + config1 "github.com/flyteorg/flytestdlib/config" + "github.com/flyteorg/flytestdlib/config/viper" + "github.com/flyteorg/flytestdlib/storage" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" - metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - pluginsIOMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" ) func dummyTaskExecutionMetadata(resources *v1.ResourceRequirements) pluginsCore.TaskExecutionMetadata { @@ -35,7 +34,7 @@ func dummyTaskExecutionMetadata(resources *v1.ResourceRequirements) pluginsCore. taskExecutionMetadata.On("GetNamespace").Return("test-namespace") taskExecutionMetadata.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"}) taskExecutionMetadata.On("GetLabels").Return(map[string]string{"label-1": "val1"}) - taskExecutionMetadata.On("GetOwnerReference").Return(metaV1.OwnerReference{ + taskExecutionMetadata.On("GetOwnerReference").Return(metav1.OwnerReference{ Kind: "node", Name: "blah", }) @@ -325,7 +324,7 @@ func toK8sPodInterruptible(t *testing.T) { }, }) - p, err := ToK8sPodSpec(ctx, x) + p, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.Len(t, p.Tolerations, 2) assert.Equal(t, "x/flyte", p.Tolerations[1].Key) @@ -392,7 +391,7 @@ func TestToK8sPod(t *testing.T) { }, }) - p, err := ToK8sPodSpec(ctx, x) + p, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.Equal(t, len(p.Tolerations), 1) }) @@ -409,7 +408,7 @@ func TestToK8sPod(t *testing.T) { }, }) - p, err := ToK8sPodSpec(ctx, x) + p, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.Equal(t, len(p.Tolerations), 0) assert.Equal(t, "some-acceptable-name", p.Containers[0].Name) @@ -436,7 +435,7 @@ func TestToK8sPod(t *testing.T) { DefaultMemoryRequest: resource.MustParse("1024Mi"), })) - p, err := ToK8sPodSpec(ctx, x) + p, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.Equal(t, 1, len(p.NodeSelector)) assert.Equal(t, "myScheduler", p.SchedulerName) @@ -453,7 +452,7 @@ func TestToK8sPod(t *testing.T) { })) x := dummyExecContext(&v1.ResourceRequirements{}) - p, err := ToK8sPodSpec(ctx, x) + p, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.NotNil(t, p.SecurityContext) assert.Equal(t, *p.SecurityContext.RunAsGroup, v) @@ -465,7 +464,7 @@ func TestToK8sPod(t *testing.T) { EnableHostNetworkingPod: &enabled, })) x := dummyExecContext(&v1.ResourceRequirements{}) - p, err := ToK8sPodSpec(ctx, x) + p, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.True(t, p.HostNetwork) }) @@ -476,7 +475,7 @@ func TestToK8sPod(t *testing.T) { EnableHostNetworkingPod: &enabled, })) x := dummyExecContext(&v1.ResourceRequirements{}) - p, err := ToK8sPodSpec(ctx, x) + p, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.False(t, p.HostNetwork) }) @@ -484,7 +483,7 @@ func TestToK8sPod(t *testing.T) { t.Run("skipSettingHostNetwork", func(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) x := dummyExecContext(&v1.ResourceRequirements{}) - p, err := ToK8sPodSpec(ctx, x) + p, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.False(t, p.HostNetwork) }) @@ -518,7 +517,7 @@ func TestToK8sPod(t *testing.T) { })) x := dummyExecContext(&v1.ResourceRequirements{}) - p, err := ToK8sPodSpec(ctx, x) + p, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.NotNil(t, p.DNSConfig) assert.Equal(t, []string{"8.8.8.8", "8.8.4.4"}, p.DNSConfig.Nameservers) @@ -756,7 +755,7 @@ func TestDemystifyPending(t *testing.T) { t.Run("CreateContainerErrorWithinGracePeriod", func(t *testing.T) { s2 := *s.DeepCopy() - s2.Conditions[0].LastTransitionTime = metaV1.Now() + s2.Conditions[0].LastTransitionTime = metav1.Now() s2.ContainerStatuses = []v1.ContainerStatus{ { Ready: false, @@ -775,7 +774,7 @@ func TestDemystifyPending(t *testing.T) { t.Run("CreateContainerErrorOutsideGracePeriod", func(t *testing.T) { s2 := *s.DeepCopy() - s2.Conditions[0].LastTransitionTime.Time = metaV1.Now().Add(-config.GetK8sPluginConfig().CreateContainerErrorGracePeriod.Duration) + s2.Conditions[0].LastTransitionTime.Time = metav1.Now().Add(-config.GetK8sPluginConfig().CreateContainerErrorGracePeriod.Duration) s2.ContainerStatuses = []v1.ContainerStatus{ { Ready: false, @@ -968,7 +967,7 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) { Name: primaryContainerName, State: v1.ContainerState{ Running: &v1.ContainerStateRunning{ - StartedAt: metaV1.Now(), + StartedAt: metav1.Now(), }, }, }, @@ -1014,16 +1013,232 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) { }) } +func TestGetPodTemplate(t *testing.T) { + ctx := context.TODO() + + podTemplate := v1.PodTemplate{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "bar", + }, + } + + t.Run("PodTemplateDoesNotExist", func(t *testing.T) { + // initialize TaskExecutionContext + task := &core.TaskTemplate{ + Type: "test", + } + + taskReader := &pluginsCoreMock.TaskReader{} + taskReader.On("Read", mock.Anything).Return(task, nil) + + tCtx := &pluginsCoreMock.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{})) + tCtx.OnTaskReader().Return(taskReader) + + // initialize PodTemplateStore + store := NewPodTemplateStore() + store.SetDefaultNamespace(podTemplate.Namespace) + + // validate base PodTemplate + basePodTemplate, err := getBasePodTemplate(ctx, tCtx, store) + assert.Nil(t, err) + assert.Nil(t, basePodTemplate) + }) + + t.Run("PodTemplateFromTaskTemplateNameExists", func(t *testing.T) { + // initialize TaskExecutionContext + task := &core.TaskTemplate{ + Metadata: &core.TaskMetadata{ + PodTemplateName: "foo", + }, + Type: "test", + } + + taskReader := &pluginsCoreMock.TaskReader{} + taskReader.On("Read", mock.Anything).Return(task, nil) + + tCtx := &pluginsCoreMock.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{})) + tCtx.OnTaskReader().Return(taskReader) + + // initialize PodTemplateStore + store := NewPodTemplateStore() + store.SetDefaultNamespace(podTemplate.Namespace) + store.Store(&podTemplate) + + // validate base PodTemplate + basePodTemplate, err := getBasePodTemplate(ctx, tCtx, store) + assert.Nil(t, err) + assert.True(t, reflect.DeepEqual(podTemplate, *basePodTemplate)) + }) + + t.Run("PodTemplateFromTaskTemplateNameDoesNotExist", func(t *testing.T) { + // initialize TaskExecutionContext + task := &core.TaskTemplate{ + Type: "test", + Metadata: &core.TaskMetadata{ + PodTemplateName: "foo", + }, + } + + taskReader := &pluginsCoreMock.TaskReader{} + taskReader.On("Read", mock.Anything).Return(task, nil) + + tCtx := &pluginsCoreMock.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{})) + tCtx.OnTaskReader().Return(taskReader) + + // initialize PodTemplateStore + store := NewPodTemplateStore() + store.SetDefaultNamespace(podTemplate.Namespace) + + // validate base PodTemplate + basePodTemplate, err := getBasePodTemplate(ctx, tCtx, store) + assert.NotNil(t, err) + assert.Nil(t, basePodTemplate) + }) + + t.Run("PodTemplateFromDefaultPodTemplate", func(t *testing.T) { + // set default PodTemplate name configuration + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + DefaultPodTemplateName: "foo", + })) + + // initialize TaskExecutionContext + task := &core.TaskTemplate{ + Type: "test", + } + + taskReader := &pluginsCoreMock.TaskReader{} + taskReader.On("Read", mock.Anything).Return(task, nil) + + tCtx := &pluginsCoreMock.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{})) + tCtx.OnTaskReader().Return(taskReader) + + // initialize PodTemplateStore + store := NewPodTemplateStore() + store.SetDefaultNamespace(podTemplate.Namespace) + store.Store(&podTemplate) + + // validate base PodTemplate + basePodTemplate, err := getBasePodTemplate(ctx, tCtx, store) + assert.Nil(t, err) + assert.True(t, reflect.DeepEqual(podTemplate, *basePodTemplate)) + }) +} + +func TestMergeWithBasePodTemplate(t *testing.T) { + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "foo", + }, + v1.Container{ + Name: "bar", + }, + }, + } + + objectMeta := metav1.ObjectMeta{ + Labels: map[string]string{ + "fooKey": "barValue", + }, + } + + t.Run("BasePodTemplateDoesNotExist", func(t *testing.T) { + task := &core.TaskTemplate{ + Type: "test", + } + + taskReader := &pluginsCoreMock.TaskReader{} + taskReader.On("Read", mock.Anything).Return(task, nil) + + tCtx := &pluginsCoreMock.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{})) + tCtx.OnTaskReader().Return(taskReader) + + resultPodSpec, resultObjectMeta, err := MergeWithBasePodTemplate(context.TODO(), tCtx, &podSpec, &objectMeta, "foo") + assert.Nil(t, err) + assert.True(t, reflect.DeepEqual(podSpec, *resultPodSpec)) + assert.True(t, reflect.DeepEqual(objectMeta, *resultObjectMeta)) + }) + + t.Run("BasePodTemplateExists", func(t *testing.T) { + primaryContainerTemplate := v1.Container{ + Name: primaryContainerTemplateName, + TerminationMessagePath: "/dev/primary-termination-log", + } + + podTemplate := v1.PodTemplate{ + ObjectMeta: metav1.ObjectMeta{ + Name: "fooTemplate", + Namespace: "test-namespace", + }, + Template: v1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + "fooKey": "bazValue", + "barKey": "bazValue", + }, + }, + Spec: v1.PodSpec{ + Containers: []v1.Container{ + primaryContainerTemplate, + }, + }, + }, + } + + DefaultPodTemplateStore.Store(&podTemplate) + + task := &core.TaskTemplate{ + Metadata: &core.TaskMetadata{ + PodTemplateName: "fooTemplate", + }, + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Command: []string{"command"}, + Args: []string{"{{.Input}}"}, + }, + }, + Type: "test", + } + + taskReader := &pluginsCoreMock.TaskReader{} + taskReader.On("Read", mock.Anything).Return(task, nil) + + tCtx := &pluginsCoreMock.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{})) + tCtx.OnTaskReader().Return(taskReader) + + resultPodSpec, resultObjectMeta, err := MergeWithBasePodTemplate(context.TODO(), tCtx, &podSpec, &objectMeta, "foo") + assert.Nil(t, err) + + // test that template podSpec is merged + primaryContainer := resultPodSpec.Containers[0] + assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name) + assert.Equal(t, primaryContainerTemplate.TerminationMessagePath, primaryContainer.TerminationMessagePath) + + // test that template object metadata is copied + assert.Contains(t, resultObjectMeta.Labels, "fooKey") + assert.Equal(t, resultObjectMeta.Labels["fooKey"], "barValue") + assert.Contains(t, resultObjectMeta.Labels, "barKey") + assert.Equal(t, resultObjectMeta.Labels["barKey"], "bazValue") + }) +} + func TestMergePodSpecs(t *testing.T) { var priority int32 = 1 - podSpec1, _ := MergePodSpecs(nil, nil, "foo") + podSpec1, _ := mergePodSpecs(nil, nil, "foo") assert.Nil(t, podSpec1) - podSpec2, _ := MergePodSpecs(&v1.PodSpec{}, nil, "foo") + podSpec2, _ := mergePodSpecs(&v1.PodSpec{}, nil, "foo") assert.Nil(t, podSpec2) - podSpec3, _ := MergePodSpecs(nil, &v1.PodSpec{}, "foo") + podSpec3, _ := mergePodSpecs(nil, &v1.PodSpec{}, "foo") assert.Nil(t, podSpec3) podSpec := v1.PodSpec{ @@ -1077,7 +1292,7 @@ func TestMergePodSpecs(t *testing.T) { }, } - mergedPodSpec, err := MergePodSpecs(&podTemplateSpec, &podSpec, "foo") + mergedPodSpec, err := mergePodSpecs(&podTemplateSpec, &podSpec, "foo") assert.Nil(t, err) // validate a PodTemplate-only field @@ -1100,54 +1315,4 @@ func TestMergePodSpecs(t *testing.T) { defaultContainer := mergedPodSpec.Containers[1] assert.Equal(t, podSpec.Containers[1].Name, defaultContainer.Name) assert.Equal(t, defaultContainerTemplate.TerminationMessagePath, defaultContainer.TerminationMessagePath) - -} - -func TestBuildPodWithSpec(t *testing.T) { - podSpec := v1.PodSpec{ - Containers: []v1.Container{ - v1.Container{ - Name: "foo", - }, - v1.Container{ - Name: "bar", - }, - }, - } - - pod, err := BuildPodWithSpec(nil, &podSpec, "foo") - assert.Nil(t, err) - assert.True(t, reflect.DeepEqual(pod.Spec, podSpec)) - - primaryContainerTemplate := v1.Container{ - Name: primaryContainerTemplateName, - TerminationMessagePath: "/dev/primary-termination-log", - } - - podTemplate := v1.PodTemplate{ - Template: v1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Labels: map[string]string{ - "fooKey": "barVal", - }, - }, - Spec: v1.PodSpec{ - Containers: []v1.Container{ - primaryContainerTemplate, - }, - }, - }, - } - - pod, err = BuildPodWithSpec(&podTemplate, &podSpec, "foo") - assert.Nil(t, err) - - // Test that template podSpec is merged - primaryContainer := pod.Spec.Containers[0] - assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name) - assert.Equal(t, primaryContainerTemplate.TerminationMessagePath, primaryContainer.TerminationMessagePath) - - // Test that template object metadata is copied - assert.Contains(t, pod.ObjectMeta.Labels, "fooKey") - assert.Equal(t, pod.ObjectMeta.Labels["fooKey"], "barVal") } diff --git a/go/tasks/pluginmachinery/flytek8s/pod_template_store.go b/go/tasks/pluginmachinery/flytek8s/pod_template_store.go index f84c83c8c..cb79feee1 100644 --- a/go/tasks/pluginmachinery/flytek8s/pod_template_store.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_template_store.go @@ -1,8 +1,11 @@ package flytek8s import ( + "context" "sync" + "github.com/flyteorg/flytestdlib/logger" + v1 "k8s.io/api/core/v1" "k8s.io/client-go/tools/cache" ) @@ -12,24 +15,42 @@ var DefaultPodTemplateStore PodTemplateStore = NewPodTemplateStore() // PodTemplateStore maintains a thread-safe mapping of active PodTemplates with their associated // namespaces. type PodTemplateStore struct { - sync.Map + *sync.Map defaultNamespace string } // NewPodTemplateStore initializes a new PodTemplateStore func NewPodTemplateStore() PodTemplateStore { - return PodTemplateStore{} + return PodTemplateStore{ + Map: &sync.Map{}, + } } -// LoadOrDefautl returns the PodTemplate associated with the given namespace. If one does not exist -// it attempts to retrieve the one associated with the defaultNamespace parameter. -func (p *PodTemplateStore) LoadOrDefault(namespace string) *v1.PodTemplate { - if podTemplate, ok := p.Load(namespace); ok { - return podTemplate.(*v1.PodTemplate) +// Delete removes the specified PodTemplate from the store. +func (p *PodTemplateStore) Delete(podTemplate *v1.PodTemplate) { + if value, ok := p.Load(podTemplate.Name); ok { + podTemplates := value.(*sync.Map) + podTemplates.Delete(podTemplate.Namespace) + logger.Debugf(context.Background(), "deleted PodTemplate '%s:%s' from store", podTemplate.Namespace, podTemplate.Name) + + // we specifically are not deleting empty maps from the store because this may introduce race + // conditions where a PodTemplate is being added to the 2nd dimension map while the top level map + // is concurrently being deleted. } +} - if podTemplate, ok := p.Load(p.defaultNamespace); ok { - return podTemplate.(*v1.PodTemplate) +// LoadOrDefault returns the PodTemplate with the specified name in the given namespace. If one +// does not exist it attempts to retrieve the one associated with the defaultNamespace. +func (p *PodTemplateStore) LoadOrDefault(namespace string, podTemplateName string) *v1.PodTemplate { + if value, ok := p.Load(podTemplateName); ok { + podTemplates := value.(*sync.Map) + if podTemplate, ok := podTemplates.Load(namespace); ok { + return podTemplate.(*v1.PodTemplate) + } + + if podTemplate, ok := podTemplates.Load(p.defaultNamespace); ok { + return podTemplate.(*v1.PodTemplate) + } } return nil @@ -40,26 +61,31 @@ func (p *PodTemplateStore) SetDefaultNamespace(namespace string) { p.defaultNamespace = namespace } +// Store loads the specified PodTemplate into the store. +func (p *PodTemplateStore) Store(podTemplate *v1.PodTemplate) { + value, _ := p.LoadOrStore(podTemplate.Name, &sync.Map{}) + podTemplates := value.(*sync.Map) + podTemplates.Store(podTemplate.Namespace, podTemplate) + logger.Debugf(context.Background(), "registered PodTemplate '%s:%s' in store", podTemplate.Namespace, podTemplate.Name) +} + // GetPodTemplateUpdatesHandler returns a new ResourceEventHandler which adds / removes -// PodTemplates with the associated podTemplateName to / from the provided PodTemplateStore. -func GetPodTemplateUpdatesHandler(store *PodTemplateStore, podTemplateName string) cache.ResourceEventHandler { +// PodTemplates to / from the provided PodTemplateStore. +func GetPodTemplateUpdatesHandler(store *PodTemplateStore) cache.ResourceEventHandler { return cache.ResourceEventHandlerFuncs{ AddFunc: func(obj interface{}) { - podTemplate, ok := obj.(*v1.PodTemplate) - if ok && podTemplate.Name == podTemplateName { - store.Store(podTemplate.Namespace, podTemplate) + if podTemplate, ok := obj.(*v1.PodTemplate); ok { + store.Store(podTemplate) } }, UpdateFunc: func(old, new interface{}) { - podTemplate, ok := new.(*v1.PodTemplate) - if ok && podTemplate.Name == podTemplateName { - store.Store(podTemplate.Namespace, podTemplate) + if podTemplate, ok := new.(*v1.PodTemplate); ok { + store.Store(podTemplate) } }, DeleteFunc: func(obj interface{}) { - podTemplate, ok := obj.(*v1.PodTemplate) - if ok && podTemplate.Name == podTemplateName { - store.Delete(podTemplate.Namespace) + if podTemplate, ok := obj.(*v1.PodTemplate); ok { + store.Delete(podTemplate) } }, } diff --git a/go/tasks/pluginmachinery/flytek8s/pod_template_store_test.go b/go/tasks/pluginmachinery/flytek8s/pod_template_store_test.go index b86a6ca7b..1b1a163b2 100644 --- a/go/tasks/pluginmachinery/flytek8s/pod_template_store_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_template_store_test.go @@ -40,7 +40,7 @@ func TestPodTemplateStore(t *testing.T) { kubeClient := fake.NewSimpleClientset() informerFactory := informers.NewSharedInformerFactoryWithOptions(kubeClient, 30*time.Second) - updateHandler := GetPodTemplateUpdatesHandler(&store, podTemplate.Name) + updateHandler := GetPodTemplateUpdatesHandler(&store) informerFactory.Core().V1().PodTemplates().Informer().AddEventHandler(updateHandler) go informerFactory.Start(ctx.Done()) @@ -49,7 +49,7 @@ func TestPodTemplateStore(t *testing.T) { assert.NoError(t, err) time.Sleep(50 * time.Millisecond) - createPodTemplate := store.LoadOrDefault(podTemplate.Namespace) + createPodTemplate := store.LoadOrDefault(podTemplate.Namespace, podTemplate.Name) assert.NotNil(t, createPodTemplate) assert.True(t, reflect.DeepEqual(podTemplate, createPodTemplate)) @@ -57,16 +57,23 @@ func TestPodTemplateStore(t *testing.T) { newNamespacePodTemplate := podTemplate.DeepCopy() newNamespacePodTemplate.Namespace = "foo" - nonDefaultPodTemplate := store.LoadOrDefault(newNamespacePodTemplate.Namespace) - assert.NotNil(t, nonDefaultPodTemplate) - assert.True(t, reflect.DeepEqual(podTemplate, nonDefaultPodTemplate)) + nonDefaultNamespacePodTemplate := store.LoadOrDefault(newNamespacePodTemplate.Namespace, newNamespacePodTemplate.Name) + assert.NotNil(t, nonDefaultNamespacePodTemplate) + assert.True(t, reflect.DeepEqual(podTemplate, nonDefaultNamespacePodTemplate)) + + // non-default name podTemplate does not exist + newNamePodTemplate := podTemplate.DeepCopy() + newNamePodTemplate.Name = "foo" + + nonDefaultNamePodTemplate := store.LoadOrDefault(newNamePodTemplate.Namespace, newNamePodTemplate.Name) + assert.Nil(t, nonDefaultNamePodTemplate) // non-default namespace podTemplate exists _, err = kubeClient.CoreV1().PodTemplates(newNamespacePodTemplate.Namespace).Create(ctx, newNamespacePodTemplate, metav1.CreateOptions{}) assert.NoError(t, err) time.Sleep(50 * time.Millisecond) - createNewNamespacePodTemplate := store.LoadOrDefault(newNamespacePodTemplate.Namespace) + createNewNamespacePodTemplate := store.LoadOrDefault(newNamespacePodTemplate.Namespace, newNamespacePodTemplate.Name) assert.NotNil(t, createNewNamespacePodTemplate) assert.True(t, reflect.DeepEqual(newNamespacePodTemplate, createNewNamespacePodTemplate)) @@ -77,7 +84,7 @@ func TestPodTemplateStore(t *testing.T) { assert.NoError(t, err) time.Sleep(50 * time.Millisecond) - updatePodTemplate := store.LoadOrDefault(podTemplate.Namespace) + updatePodTemplate := store.LoadOrDefault(podTemplate.Namespace, podTemplate.Name) assert.NotNil(t, updatePodTemplate) assert.True(t, reflect.DeepEqual(updatedPodTemplate, updatePodTemplate)) @@ -86,6 +93,6 @@ func TestPodTemplateStore(t *testing.T) { assert.NoError(t, err) time.Sleep(50 * time.Millisecond) - deletePodTemplate := store.LoadOrDefault(podTemplate.Namespace) + deletePodTemplate := store.LoadOrDefault(podTemplate.Namespace, podTemplate.Name) assert.Nil(t, deletePodTemplate) } diff --git a/go/tasks/plugins/array/k8s/subtask.go b/go/tasks/plugins/array/k8s/subtask.go index 5cdc80596..6411f06c0 100644 --- a/go/tasks/plugins/array/k8s/subtask.go +++ b/go/tasks/plugins/array/k8s/subtask.go @@ -10,6 +10,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/errors" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" @@ -323,7 +324,7 @@ func getSubtaskPhaseInfo(ctx context.Context, stCtx SubTaskExecutionContext, cfg // getTaskContainerIndex returns the index of the primary container in a k8s pod. func getTaskContainerIndex(pod *v1.Pod) (int, error) { - primaryContainerName, ok := pod.Annotations[podPlugin.PrimaryContainerKey] + primaryContainerName, ok := pod.Annotations[flytek8s.PrimaryContainerKey] // For tasks with a Container target, we only ever build one container as part of the pod if !ok { if len(pod.Spec.Containers) == 1 { diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 1b7624888..f605088e5 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -62,26 +62,12 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu launcherReplicas := mpiTaskExtraArgs.GetNumLauncherReplicas() slots := mpiTaskExtraArgs.GetSlots() - podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + podSpec, objectMeta, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } - common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.MPIJobDefaultContainerName) - podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) - - objectMeta := metav1.ObjectMeta{} - - if podTemplate != nil { - mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.MPIJobDefaultContainerName) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) - } - podSpec = mergedPodSpec - objectMeta = podTemplate.Template.ObjectMeta - } - // workersPodSpec is deepCopy of podSpec submitted by flyte // WorkerPodSpec doesn't need any Argument & command. It will be trigger from launcher pod workersPodSpec := podSpec.DeepCopy() @@ -115,7 +101,7 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu jobSpec.MPIReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{ Replicas: t.replicaNum, Template: v1.PodTemplateSpec{ - ObjectMeta: objectMeta, + ObjectMeta: *objectMeta, Spec: t.podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 01421f1a9..287590ed5 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -62,26 +62,12 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + podSpec, objectMeta, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } - common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.PytorchJobDefaultContainerName) - podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) - - objectMeta := metav1.ObjectMeta{} - - if podTemplate != nil { - mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.PytorchJobDefaultContainerName) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) - } - podSpec = mergedPodSpec - objectMeta = podTemplate.Template.ObjectMeta - } - workers := pytorchTaskExtraArgs.GetWorkers() if workers == 0 { return nil, fmt.Errorf("number of worker should be more then 0") @@ -91,7 +77,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ kubeflowv1.PyTorchJobReplicaTypeMaster: { Template: v1.PodTemplateSpec{ - ObjectMeta: objectMeta, + ObjectMeta: *objectMeta, Spec: *podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, @@ -99,7 +85,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx kubeflowv1.PyTorchJobReplicaTypeWorker: { Replicas: &workers, Template: v1.PodTemplateSpec{ - ObjectMeta: objectMeta, + ObjectMeta: *objectMeta, Spec: *podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index a5c813ca4..ae531085c 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -62,26 +62,12 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + podSpec, objectMeta, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } - common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.TFJobDefaultContainerName) - podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) - - objectMeta := metav1.ObjectMeta{} - - if podTemplate != nil { - mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.TFJobDefaultContainerName) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) - } - podSpec = mergedPodSpec - objectMeta = podTemplate.Template.ObjectMeta - } - workers := tensorflowTaskExtraArgs.GetWorkers() psReplicas := tensorflowTaskExtraArgs.GetPsReplicas() chiefReplicas := tensorflowTaskExtraArgs.GetChiefReplicas() @@ -110,7 +96,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task jobSpec.TFReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{ Replicas: t.replicaNum, Template: v1.PodTemplateSpec{ - ObjectMeta: objectMeta, + ObjectMeta: *objectMeta, Spec: t.podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, diff --git a/go/tasks/plugins/k8s/pod/container.go b/go/tasks/plugins/k8s/pod/container.go deleted file mode 100644 index 8f485b097..000000000 --- a/go/tasks/plugins/k8s/pod/container.go +++ /dev/null @@ -1,43 +0,0 @@ -package pod - -import ( - "context" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - - "github.com/flyteorg/flyteplugins/go/tasks/errors" - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - - v1 "k8s.io/api/core/v1" -) - -const ( - ContainerTaskType = "container" - PythonTaskType = "python-task" - RawContainerTaskType = "raw-container" -) - -type containerPodBuilder struct { -} - -func (containerPodBuilder) buildPodSpec(ctx context.Context, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) { - podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) - if err != nil { - return nil, err - } - - return podSpec, nil -} - -func (containerPodBuilder) getPrimaryContainerName(task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (string, error) { - primaryContainerName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - if primaryContainerName == "" { - return "", errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification, missing generated name") - } - return primaryContainerName, nil -} - -func (containerPodBuilder) updatePodMetadata(ctx context.Context, pod *v1.Pod, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) error { - return nil -} diff --git a/go/tasks/plugins/k8s/pod/plugin.go b/go/tasks/plugins/k8s/pod/plugin.go index 781456747..33f69b3c2 100644 --- a/go/tasks/plugins/k8s/pod/plugin.go +++ b/go/tasks/plugins/k8s/pod/plugin.go @@ -3,88 +3,135 @@ package pod import ( "context" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - - "github.com/flyteorg/flyteplugins/go/tasks/errors" + pluginserrors "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/logs" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + + "github.com/flyteorg/flytestdlib/logger" v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" ) const ( - podTaskType = "pod" - PrimaryContainerKey = "primary_container_name" -) - -var ( - DefaultPodPlugin = plugin{ - defaultPodBuilder: containerPodBuilder{}, - podBuilders: map[string]podBuilder{ - SidecarTaskType: sidecarPodBuilder{}, - }, - } + ContainerTaskType = "container" + podTaskType = "pod" + pythonTaskType = "python-task" + rawContainerTaskType = "raw-container" + SidecarTaskType = "sidecar" ) -type podBuilder interface { - buildPodSpec(ctx context.Context, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) - getPrimaryContainerName(task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (string, error) - updatePodMetadata(ctx context.Context, pod *v1.Pod, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) error +// Why, you might wonder do we recreate the generated go struct generated from the plugins.SidecarJob proto? Because +// although we unmarshal the task custom json, the PodSpec itself is not generated from a proto definition, +// but a proper go struct defined in k8s libraries. Therefore we only unmarshal the sidecar as a json, rather than jsonpb. +type sidecarJob struct { + PodSpec *v1.PodSpec + PrimaryContainerName string + Annotations map[string]string + Labels map[string]string } +var DefaultPodPlugin = plugin{} + type plugin struct { - defaultPodBuilder podBuilder - podBuilders map[string]podBuilder } -func (plugin) BuildIdentityResource(_ context.Context, _ pluginsCore.TaskExecutionMetadata) ( - client.Object, error) { +func (plugin) BuildIdentityResource(_ context.Context, _ pluginsCore.TaskExecutionMetadata) (client.Object, error) { return flytek8s.BuildIdentityPod(), nil } func (p plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { - // read TaskTemplate - task, err := taskCtx.TaskReader().Read(ctx) + taskTemplate, err := taskCtx.TaskReader().Read(ctx) if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, - "TaskSpecification cannot be read, Err: [%v]", err.Error()) + logger.Warnf(ctx, "failed to read task information when trying to construct Pod, err: %s", err.Error()) + return nil, err } - // initialize PodBuilder - builder, exists := p.podBuilders[task.Type] - if !exists { - builder = p.defaultPodBuilder + var podSpec *v1.PodSpec + objectMeta := &metav1.ObjectMeta{ + Annotations: make(map[string]string), + Labels: make(map[string]string), } + primaryContainerName := "" - podSpec, err := builder.buildPodSpec(ctx, task, taskCtx) - if err != nil { - return nil, err - } + if taskTemplate.Type == SidecarTaskType && taskTemplate.TaskTypeVersion == 0 { + // handles pod tasks when they are defined as Sidecar tasks and marshal the podspec using k8s proto. + sidecarJob := sidecarJob{} + err := utils.UnmarshalStructToObj(taskTemplate.GetCustom(), &sidecarJob) + if err != nil { + return nil, pluginserrors.Errorf(pluginserrors.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } - podSpec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) + if sidecarJob.PodSpec == nil { + return nil, pluginserrors.Errorf(pluginserrors.BadTaskSpecification, "invalid TaskSpecification, nil PodSpec [%v]", taskTemplate.GetCustom()) + } - podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) - primaryContainerName, err := builder.getPrimaryContainerName(task, taskCtx) - if err != nil { - return nil, err + podSpec = sidecarJob.PodSpec + + // get primary container name + primaryContainerName = sidecarJob.PrimaryContainerName + + // update annotations and labels + objectMeta.Annotations = utils.UnionMaps(objectMeta.Annotations, sidecarJob.Annotations) + objectMeta.Labels = utils.UnionMaps(objectMeta.Labels, sidecarJob.Labels) + } else if taskTemplate.Type == SidecarTaskType && taskTemplate.TaskTypeVersion == 1 { + // handles pod tasks that marshal the pod spec to the task custom. + err := utils.UnmarshalStructToObj(taskTemplate.GetCustom(), &podSpec) + if err != nil { + return nil, pluginserrors.Errorf(pluginserrors.BadTaskSpecification, + "Unable to unmarshal task custom [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + // get primary container name + if len(taskTemplate.GetConfig()) == 0 { + return nil, pluginserrors.Errorf(pluginserrors.BadTaskSpecification, + "invalid TaskSpecification, config needs to be non-empty and include missing [%s] key", flytek8s.PrimaryContainerKey) + } + + var ok bool + if primaryContainerName, ok = taskTemplate.GetConfig()[flytek8s.PrimaryContainerKey]; !ok { + return nil, pluginserrors.Errorf(pluginserrors.BadTaskSpecification, + "invalid TaskSpecification, config missing [%s] key in [%v]", flytek8s.PrimaryContainerKey, taskTemplate.GetConfig()) + } + + // update annotations and labels + if taskTemplate.GetK8SPod() != nil && taskTemplate.GetK8SPod().Metadata != nil { + objectMeta.Annotations = utils.UnionMaps(objectMeta.Annotations, taskTemplate.GetK8SPod().Metadata.Annotations) + objectMeta.Labels = utils.UnionMaps(objectMeta.Labels, taskTemplate.GetK8SPod().Metadata.Labels) + } + } else { + // handles both container / pod tasks that use the TaskTemplate Container and K8sPod fields + var err error + podSpec, objectMeta, primaryContainerName, err = flytek8s.BuildRawPod(ctx, taskCtx) + if err != nil { + return nil, err + } } - pod, err := flytek8s.BuildPodWithSpec(podTemplate, podSpec, primaryContainerName) + // update podSpec and objectMeta with Flyte customizations + podSpec, objectMeta, err = flytek8s.ApplyFlytePodConfiguration(ctx, taskCtx, podSpec, objectMeta, primaryContainerName) if err != nil { return nil, err } - // update pod metadata - if err = builder.updatePodMetadata(ctx, pod, task, taskCtx); err != nil { - return nil, err + // set primary container name if this is executed as a sidecar + if taskTemplate.Type == SidecarTaskType { + objectMeta.Annotations[flytek8s.PrimaryContainerKey] = primaryContainerName } + podSpec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) + + pod := flytek8s.BuildIdentityPod() + pod.ObjectMeta = *objectMeta + pod.Spec = *podSpec + return pod, nil } @@ -126,7 +173,7 @@ func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.Plugin return pluginsCore.PhaseInfoUndefined, nil } - primaryContainerName, exists := r.GetAnnotations()[PrimaryContainerKey] + primaryContainerName, exists := r.GetAnnotations()[flytek8s.PrimaryContainerKey] if !exists { // if the primary container annotation dos not exist, then the task requires all containers // to succeed to declare success. therefore, if the pod is not in one of the above states we @@ -157,7 +204,7 @@ func init() { pluginmachinery.PluginRegistry().RegisterK8sPlugin( k8s.PluginEntry{ ID: ContainerTaskType, - RegisteredTaskTypes: []pluginsCore.TaskType{ContainerTaskType, PythonTaskType, RawContainerTaskType}, + RegisteredTaskTypes: []pluginsCore.TaskType{ContainerTaskType, pythonTaskType, rawContainerTaskType}, ResourceToWatch: &v1.Pod{}, Plugin: DefaultPodPlugin, IsDefault: true, @@ -176,7 +223,7 @@ func init() { pluginmachinery.PluginRegistry().RegisterK8sPlugin( k8s.PluginEntry{ ID: podTaskType, - RegisteredTaskTypes: []pluginsCore.TaskType{ContainerTaskType, PythonTaskType, RawContainerTaskType, SidecarTaskType}, + RegisteredTaskTypes: []pluginsCore.TaskType{ContainerTaskType, pythonTaskType, rawContainerTaskType, SidecarTaskType}, ResourceToWatch: &v1.Pod{}, Plugin: DefaultPodPlugin, IsDefault: true, diff --git a/go/tasks/plugins/k8s/pod/sidecar.go b/go/tasks/plugins/k8s/pod/sidecar.go deleted file mode 100644 index 1c4cf0276..000000000 --- a/go/tasks/plugins/k8s/pod/sidecar.go +++ /dev/null @@ -1,184 +0,0 @@ -package pod - -import ( - "context" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - - "github.com/flyteorg/flyteplugins/go/tasks/errors" - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" - - v1 "k8s.io/api/core/v1" -) - -const ( - SidecarTaskType = "sidecar" -) - -// Why, you might wonder do we recreate the generated go struct generated from the plugins.SidecarJob proto? Because -// although we unmarshal the task custom json, the PodSpec itself is not generated from a proto definition, -// but a proper go struct defined in k8s libraries. Therefore we only unmarshal the sidecar as a json, rather than jsonpb. -type sidecarJob struct { - PodSpec *v1.PodSpec - PrimaryContainerName string - Annotations map[string]string - Labels map[string]string -} - -type sidecarPodBuilder struct { -} - -func (sidecarPodBuilder) buildPodSpec(ctx context.Context, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) { - var podSpec v1.PodSpec - switch task.TaskTypeVersion { - case 0: - // Handles pod tasks when they are defined as Sidecar tasks and marshal the podspec using k8s proto. - sidecarJob := sidecarJob{} - err := utils.UnmarshalStructToObj(task.GetCustom(), &sidecarJob) - if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) - } - - if sidecarJob.PodSpec == nil { - return nil, errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification, nil PodSpec [%v]", task.GetCustom()) - } - - podSpec = *sidecarJob.PodSpec - case 1: - // Handles pod tasks that marshal the pod spec to the task custom. - err := utils.UnmarshalStructToObj(task.GetCustom(), &podSpec) - if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, - "Unable to unmarshal task custom [%v], Err: [%v]", task.GetCustom(), err.Error()) - } - default: - // Handles pod tasks that marshal the pod spec to the k8s_pod task target. - if task.GetK8SPod() == nil || task.GetK8SPod().PodSpec == nil { - return nil, errors.Errorf(errors.BadTaskSpecification, - "Pod tasks with task type version > 1 should specify their target as a K8sPod with a defined pod spec") - } - - err := utils.UnmarshalStructToObj(task.GetK8SPod().PodSpec, &podSpec) - if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, - "Unable to unmarshal task custom [%v], Err: [%v]", task.GetCustom(), err.Error()) - } - } - - // Set the restart policy to *not* inherit from the default so that a completed pod doesn't get caught in a - // CrashLoopBackoff after the initial job completion. - podSpec.RestartPolicy = v1.RestartPolicyNever - - return &podSpec, nil -} - -func (sidecarPodBuilder) getPrimaryContainerName(task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (string, error) { - switch task.TaskTypeVersion { - case 0: - // Handles pod tasks when they are defined as Sidecar tasks and marshal the podspec using k8s proto. - sidecarJob := sidecarJob{} - err := utils.UnmarshalStructToObj(task.GetCustom(), &sidecarJob) - if err != nil { - return "", errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) - } - - return sidecarJob.PrimaryContainerName, nil - default: - if len(task.GetConfig()) == 0 { - return "", errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification, config needs to be non-empty and include missing [%s] key", PrimaryContainerKey) - } - - primaryContainerName, ok := task.GetConfig()[PrimaryContainerKey] - if !ok { - return "", errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification, config missing [%s] key in [%v]", PrimaryContainerKey, task.GetConfig()) - } - - return primaryContainerName, nil - } -} - -func mergeMapInto(src map[string]string, dst map[string]string) { - for key, value := range src { - dst[key] = value - } -} - -func (s sidecarPodBuilder) updatePodMetadata(ctx context.Context, pod *v1.Pod, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) error { - pod.Annotations = make(map[string]string) - pod.Labels = make(map[string]string) - - switch task.TaskTypeVersion { - case 0: - // Handles pod tasks when they are defined as Sidecar tasks and marshal the podspec using k8s proto. - sidecarJob := sidecarJob{} - err := utils.UnmarshalStructToObj(task.GetCustom(), &sidecarJob) - if err != nil { - return errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) - } - - mergeMapInto(sidecarJob.Annotations, pod.Annotations) - mergeMapInto(sidecarJob.Labels, pod.Labels) - default: - // Handles pod tasks that marshal the pod spec to the k8s_pod task target. - if task.GetK8SPod() != nil && task.GetK8SPod().Metadata != nil { - mergeMapInto(task.GetK8SPod().Metadata.Annotations, pod.Annotations) - mergeMapInto(task.GetK8SPod().Metadata.Labels, pod.Labels) - } - } - - // validate pod and update resource requirements - primaryContainerName, err := s.getPrimaryContainerName(task, taskCtx) - if err != nil { - return err - } - - if err := validateAndFinalizePodSpec(ctx, taskCtx, primaryContainerName, &pod.Spec); err != nil { - return err - } - - pod.Annotations[PrimaryContainerKey] = primaryContainerName - return nil -} - -// This method handles templatizing primary container input args, env variables and adds a GPU toleration to the pod -// spec if necessary. -func validateAndFinalizePodSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, primaryContainerName string, podSpec *v1.PodSpec) error { - var hasPrimaryContainer bool - - resReqs := make([]v1.ResourceRequirements, 0, len(podSpec.Containers)) - for index, container := range podSpec.Containers { - var resourceMode = flytek8s.ResourceCustomizationModeEnsureExistingResourcesInRange - if container.Name == primaryContainerName { - hasPrimaryContainer = true - resourceMode = flytek8s.ResourceCustomizationModeMergeExistingResources - } - - templateParameters := template.Parameters{ - TaskExecMetadata: taskCtx.TaskExecutionMetadata(), - Inputs: taskCtx.InputReader(), - OutputPath: taskCtx.OutputWriter(), - Task: taskCtx.TaskReader(), - } - - err := flytek8s.AddFlyteCustomizationsToContainer(ctx, templateParameters, resourceMode, &podSpec.Containers[index]) - if err != nil { - return err - } - - resReqs = append(resReqs, container.Resources) - } - - if !hasPrimaryContainer { - return errors.Errorf(errors.BadTaskSpecification, "invalid Sidecar task, primary container [%s] not defined", primaryContainerName) - } - - flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), resReqs, podSpec) - return nil -} diff --git a/go/tasks/plugins/k8s/pod/sidecar_test.go b/go/tasks/plugins/k8s/pod/sidecar_test.go index e301e5bf2..ecb0e36f8 100644 --- a/go/tasks/plugins/k8s/pod/sidecar_test.go +++ b/go/tasks/plugins/k8s/pod/sidecar_test.go @@ -9,25 +9,27 @@ import ( "path" "testing" - "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" errors2 "github.com/flyteorg/flyteplugins/go/tasks/errors" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + pluginsIOMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" + + structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + + v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - structpb "github.com/golang/protobuf/ptypes/struct" - "github.com/stretchr/testify/assert" - v1 "k8s.io/api/core/v1" - - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" - pluginsIOMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" + "sigs.k8s.io/controller-runtime/pkg/client" ) const ResourceNvidiaGPU = "nvidia.com/gpu" @@ -202,7 +204,7 @@ func TestBuildSidecarResource_TaskType2(t *testing.T) { Type: SidecarTaskType, TaskTypeVersion: 2, Config: map[string]string{ - PrimaryContainerKey: "primary container", + flytek8s.PrimaryContainerKey: "primary container", }, Target: &core.TaskTemplate_K8SPod{ K8SPod: &core.K8SPod{ @@ -245,8 +247,8 @@ func TestBuildSidecarResource_TaskType2(t *testing.T) { res, err := DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) assert.Nil(t, err) assert.EqualValues(t, map[string]string{ - PrimaryContainerKey: "primary container", - "anno": "bar", + flytek8s.PrimaryContainerKey: "primary container", + "anno": "bar", }, res.GetAnnotations()) assert.EqualValues(t, map[string]string{ "label": "foo", @@ -286,7 +288,7 @@ func TestBuildSidecarResource_TaskType2_Invalid_Spec(t *testing.T) { Type: SidecarTaskType, TaskTypeVersion: 2, Config: map[string]string{ - PrimaryContainerKey: "primary container", + flytek8s.PrimaryContainerKey: "primary container", }, Target: &core.TaskTemplate_K8SPod{ K8SPod: &core.K8SPod{ @@ -325,7 +327,7 @@ func TestBuildSidecarResource_TaskType1(t *testing.T) { Custom: structObj, TaskTypeVersion: 1, Config: map[string]string{ - PrimaryContainerKey: "primary container", + flytek8s.PrimaryContainerKey: "primary container", }, } @@ -354,7 +356,7 @@ func TestBuildSidecarResource_TaskType1(t *testing.T) { res, err := DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) assert.Nil(t, err) assert.EqualValues(t, map[string]string{ - PrimaryContainerKey: "primary container", + flytek8s.PrimaryContainerKey: "primary container", }, res.GetAnnotations()) assert.EqualValues(t, map[string]string{}, res.GetLabels()) @@ -471,8 +473,8 @@ func TestBuildSidecarResource(t *testing.T) { res, err := DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) assert.Nil(t, err) assert.EqualValues(t, map[string]string{ - PrimaryContainerKey: "a container", - "a1": "a1", + flytek8s.PrimaryContainerKey: "a container", + "a1": "a1", }, res.GetAnnotations()) assert.EqualValues(t, map[string]string{ @@ -488,14 +490,8 @@ func TestBuildSidecarResource(t *testing.T) { assert.Equal(t, "service-account", res.(*v1.Pod).Spec.ServiceAccountName) - assert.Len(t, res.(*v1.Pod).Spec.Tolerations, 1) - for _, tol := range res.(*v1.Pod).Spec.Tolerations { - if tol.Key == "my toleration key" { - assert.Equal(t, tol.Value, "my toleration value") - } else { - t.Fatalf("unexpected toleration [%+v]", tol) - } - } + checkTolerations(t, res, tolGPU) + // Assert resource requirements are correctly set expectedCPURequest := resource.MustParse("2048m") assert.Equal(t, expectedCPURequest.Value(), res.(*v1.Pod).Spec.Containers[0].Resources.Requests.Cpu().Value()) @@ -577,7 +573,7 @@ func TestGetTaskSidecarStatus(t *testing.T) { }, } res.SetAnnotations(map[string]string{ - PrimaryContainerKey: "PrimaryContainer", + flytek8s.PrimaryContainerKey: "PrimaryContainer", }) taskCtx := getDummySidecarTaskContext(task, sidecarResourceRequirements) phaseInfo, err := DefaultPodPlugin.GetTaskPhase(context.TODO(), taskCtx, res) @@ -604,7 +600,7 @@ func TestDemystifiedSidecarStatus_PrimaryFailed(t *testing.T) { }, } res.SetAnnotations(map[string]string{ - PrimaryContainerKey: "Primary", + flytek8s.PrimaryContainerKey: "Primary", }) taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, sidecarResourceRequirements) phaseInfo, err := DefaultPodPlugin.GetTaskPhase(context.TODO(), taskCtx, res) @@ -629,7 +625,7 @@ func TestDemystifiedSidecarStatus_PrimarySucceeded(t *testing.T) { }, } res.SetAnnotations(map[string]string{ - PrimaryContainerKey: "Primary", + flytek8s.PrimaryContainerKey: "Primary", }) taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, sidecarResourceRequirements) phaseInfo, err := DefaultPodPlugin.GetTaskPhase(context.TODO(), taskCtx, res) @@ -654,7 +650,7 @@ func TestDemystifiedSidecarStatus_PrimaryRunning(t *testing.T) { }, } res.SetAnnotations(map[string]string{ - PrimaryContainerKey: "Primary", + flytek8s.PrimaryContainerKey: "Primary", }) taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, sidecarResourceRequirements) phaseInfo, err := DefaultPodPlugin.GetTaskPhase(context.TODO(), taskCtx, res) @@ -674,7 +670,7 @@ func TestDemystifiedSidecarStatus_PrimaryMissing(t *testing.T) { }, } res.SetAnnotations(map[string]string{ - PrimaryContainerKey: "Primary", + flytek8s.PrimaryContainerKey: "Primary", }) taskCtx := getDummySidecarTaskContext(&core.TaskTemplate{}, sidecarResourceRequirements) phaseInfo, err := DefaultPodPlugin.GetTaskPhase(context.TODO(), taskCtx, res) diff --git a/go/tasks/plugins/k8s/ray/ray.go b/go/tasks/plugins/k8s/ray/ray.go index 02b3f060b..d7e274250 100644 --- a/go/tasks/plugins/k8s/ray/ray.go +++ b/go/tasks/plugins/k8s/ray/ray.go @@ -7,9 +7,6 @@ import ( "strings" "time" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" - "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyteplugins/go/tasks/errors" @@ -58,25 +55,10 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC return nil, errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - if taskTemplate.GetContainer() == nil { - logger.Errorf(ctx, "Default Pod creation logic works for default container in the task template only.") - return nil, fmt.Errorf("container not specified in task template") - } - - templateParameters := template.Parameters{ - Task: taskCtx.TaskReader(), - Inputs: taskCtx.InputReader(), - OutputPath: taskCtx.OutputWriter(), - TaskExecMetadata: taskCtx.TaskExecutionMetadata(), - } - container, err := flytek8s.ToK8sContainer(ctx, taskTemplate.GetContainer(), taskTemplate.Interface, templateParameters) + container, err := flytek8s.ToK8sContainer(ctx, taskCtx) if err != nil { return nil, errors.Errorf(errors.BadTaskSpecification, "Unable to create container spec: [%v]", err.Error()) } - err = flytek8s.AddFlyteCustomizationsToContainer(ctx, templateParameters, flytek8s.ResourceCustomizationModeAssignResources, container) - if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, "Unable to update container resource and command: [%v]", err.Error()) - } headReplicas := int32(1) headNodeRayStartParams := make(map[string]string)