diff --git a/docs/deployment/configuration/general.rst b/docs/deployment/configuration/general.rst index 0b97f8b8ff..5db7786c9a 100644 --- a/docs/deployment/configuration/general.rst +++ b/docs/deployment/configuration/general.rst @@ -128,6 +128,9 @@ as the base container configuration for all primary containers. If both containe names exist in the default PodTemplate, Flyte first applies the default configuration, followed by the primary configuration. +Note: Init containers can be configured with similar granularity using "default-init" +and "primary-init" init container names. + The ``containers`` field is required in each k8s PodSpec. If no default configuration is desired, specifying a container with a name other than "default" or "primary" (for example, "noop") is considered best practice. Since Flyte only diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go index 8e89e58d3d..eaee5bce6c 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go @@ -201,14 +201,15 @@ func AddCoPilotToContainer(ctx context.Context, cfg config.FlyteCoPilotConfig, c return nil } -func AddCoPilotToPod(ctx context.Context, cfg config.FlyteCoPilotConfig, coPilotPod *v1.PodSpec, iFace *core.TypedInterface, taskExecMetadata core2.TaskExecutionMetadata, inputPaths io.InputFilePaths, outputPaths io.OutputFilePaths, pilot *core.DataLoadingConfig) error { +func AddCoPilotToPod(ctx context.Context, cfg config.FlyteCoPilotConfig, coPilotPod *v1.PodSpec, iFace *core.TypedInterface, taskExecMetadata core2.TaskExecutionMetadata, inputPaths io.InputFilePaths, outputPaths io.OutputFilePaths, pilot *core.DataLoadingConfig) (string, error) { if pilot == nil || !pilot.Enabled { - return nil + return "", nil } logger.Infof(ctx, "CoPilot Enabled for task [%s]", taskExecMetadata.GetTaskExecutionID().GetID().TaskId.Name) shareProcessNamespaceEnabled 
:= true coPilotPod.ShareProcessNamespace = &shareProcessNamespaceEnabled + primaryInitContainerName := "" if iFace != nil { if iFace.Inputs != nil && len(iFace.Inputs.Variables) > 0 { inPath := cfg.DefaultInputDataPath @@ -231,13 +232,14 @@ func AddCoPilotToPod(ctx context.Context, cfg config.FlyteCoPilotConfig, coPilot // Lets add the Inputs init container args, err := DownloadCommandArgs(inputPaths.GetInputPath(), outputPaths.GetOutputPrefixPath(), inPath, format, iFace.Inputs) if err != nil { - return err + return primaryInitContainerName, err } downloader, err := FlyteCoPilotContainer(flyteInitContainerName, cfg, args, inputsVolumeMount) if err != nil { - return err + return primaryInitContainerName, err } coPilotPod.InitContainers = append(coPilotPod.InitContainers, downloader) + primaryInitContainerName = downloader.Name } if iFace.Outputs != nil && len(iFace.Outputs.Variables) > 0 { @@ -260,15 +262,15 @@ func AddCoPilotToPod(ctx context.Context, cfg config.FlyteCoPilotConfig, coPilot // Lets add the Inputs init container args, err := SidecarCommandArgs(outPath, outputPaths.GetOutputPrefixPath(), outputPaths.GetRawOutputPrefix(), cfg.StartTimeout.Duration, iFace) if err != nil { - return err + return primaryInitContainerName, err } sidecar, err := FlyteCoPilotContainer(flyteSidecarContainerName, cfg, args, outputsVolumeMount) if err != nil { - return err + return primaryInitContainerName, err } coPilotPod.Containers = append(coPilotPod.Containers, sidecar) } } - return nil + return primaryInitContainerName, nil } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot_test.go index 09a9fbf52b..aba18c85ac 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot_test.go @@ -533,7 +533,9 @@ func TestAddCoPilotToPod(t *testing.T) { InputPath: "in", OutputPath: "out", } - assert.NoError(t, AddCoPilotToPod(ctx, 
cfg, &pod, iface, taskMetadata, inputPaths, opath, pilot)) + primaryInitContainerName, err := AddCoPilotToPod(ctx, cfg, &pod, iface, taskMetadata, inputPaths, opath, pilot) + assert.NoError(t, err) + assert.Equal(t, "test-downloader", primaryInitContainerName) assertPodHasSNPS(t, &pod) assertPodHasCoPilot(t, cfg, pilot, iface, &pod) }) @@ -545,7 +547,9 @@ func TestAddCoPilotToPod(t *testing.T) { InputPath: "in", OutputPath: "out", } - assert.NoError(t, AddCoPilotToPod(ctx, cfg, &pod, nil, taskMetadata, inputPaths, opath, pilot)) + primaryInitContainerName, err := AddCoPilotToPod(ctx, cfg, &pod, nil, taskMetadata, inputPaths, opath, pilot) + assert.NoError(t, err) + assert.Empty(t, primaryInitContainerName) assertPodHasSNPS(t, &pod) assertPodHasCoPilot(t, cfg, pilot, nil, &pod) }) @@ -565,7 +569,9 @@ func TestAddCoPilotToPod(t *testing.T) { InputPath: "in", OutputPath: "out", } - assert.NoError(t, AddCoPilotToPod(ctx, cfg, &pod, iface, taskMetadata, inputPaths, opath, pilot)) + primaryInitContainerName, err := AddCoPilotToPod(ctx, cfg, &pod, iface, taskMetadata, inputPaths, opath, pilot) + assert.NoError(t, err) + assert.Equal(t, "test-downloader", primaryInitContainerName) assertPodHasSNPS(t, &pod) assertPodHasCoPilot(t, cfg, pilot, iface, &pod) }) @@ -584,7 +590,9 @@ func TestAddCoPilotToPod(t *testing.T) { InputPath: "in", OutputPath: "out", } - assert.NoError(t, AddCoPilotToPod(ctx, cfg, &pod, iface, taskMetadata, inputPaths, opath, pilot)) + primaryInitContainerName, err := AddCoPilotToPod(ctx, cfg, &pod, iface, taskMetadata, inputPaths, opath, pilot) + assert.NoError(t, err) + assert.Empty(t, primaryInitContainerName) assertPodHasSNPS(t, &pod) assertPodHasCoPilot(t, cfg, pilot, iface, &pod) }) @@ -603,11 +611,15 @@ func TestAddCoPilotToPod(t *testing.T) { InputPath: "in", OutputPath: "out", } - assert.NoError(t, AddCoPilotToPod(ctx, cfg, &pod, iface, taskMetadata, inputPaths, opath, pilot)) + primaryInitContainerName, err := AddCoPilotToPod(ctx, cfg, &pod, 
iface, taskMetadata, inputPaths, opath, pilot) + assert.NoError(t, err) + assert.Empty(t, primaryInitContainerName) assert.Len(t, pod.Volumes, 0) }) t.Run("nil", func(t *testing.T) { - assert.NoError(t, AddCoPilotToPod(ctx, cfg, nil, nil, taskMetadata, inputPaths, opath, nil)) + primaryInitContainerName, err := AddCoPilotToPod(ctx, cfg, nil, nil, taskMetadata, inputPaths, opath, nil) + assert.NoError(t, err) + assert.Empty(t, primaryInitContainerName) }) } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index e8252090df..229f963968 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -28,7 +28,9 @@ const PrimaryContainerNotFound = "PrimaryContainerNotFound" const SIGKILL = 137 const defaultContainerTemplateName = "default" +const defaultInitContainerTemplateName = "default-init" const primaryContainerTemplateName = "primary" +const primaryInitContainerTemplateName = "primary-init" const PrimaryContainerKey = "primary_container_name" // AddRequiredNodeSelectorRequirements adds the provided v1.NodeSelectorRequirement @@ -387,14 +389,17 @@ func ApplyFlytePodConfiguration(ctx context.Context, tCtx pluginsCore.TaskExecut dataLoadingConfig = pod.GetDataConfig() } + primaryInitContainerName := "" + if dataLoadingConfig != nil { if err := AddCoPilotToContainer(ctx, config.GetK8sPluginConfig().CoPilot, primaryContainer, taskTemplate.Interface, dataLoadingConfig); err != nil { return nil, nil, err } - if err := AddCoPilotToPod(ctx, config.GetK8sPluginConfig().CoPilot, podSpec, taskTemplate.GetInterface(), - tCtx.TaskExecutionMetadata(), tCtx.InputReader(), tCtx.OutputWriter(), dataLoadingConfig); err != nil { + primaryInitContainerName, err = AddCoPilotToPod(ctx, config.GetK8sPluginConfig().CoPilot, podSpec, taskTemplate.GetInterface(), + tCtx.TaskExecutionMetadata(), tCtx.InputReader(), 
tCtx.OutputWriter(), dataLoadingConfig) + if err != nil { return nil, nil, err } } @@ -406,7 +411,7 @@ func ApplyFlytePodConfiguration(ctx context.Context, tCtx pluginsCore.TaskExecut } // merge PodSpec and ObjectMeta with configuration pod template (if exists) - podSpec, objectMeta, err = MergeWithBasePodTemplate(ctx, tCtx, podSpec, objectMeta, primaryContainerName) + podSpec, objectMeta, err = MergeWithBasePodTemplate(ctx, tCtx, podSpec, objectMeta, primaryContainerName, primaryInitContainerName) if err != nil { return nil, nil, err } @@ -495,7 +500,7 @@ func getBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutionConte // MergeWithBasePodTemplate attempts to merge the provided PodSpec and ObjectMeta with the configuration PodTemplate for // this task. func MergeWithBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, - podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, primaryContainerName string) (*v1.PodSpec, *metav1.ObjectMeta, error) { + podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, primaryContainerName string, primaryInitContainerName string) (*v1.PodSpec, *metav1.ObjectMeta, error) { // attempt to retrieve base PodTemplate podTemplate, err := getBasePodTemplate(ctx, tCtx, DefaultPodTemplateStore) @@ -507,7 +512,7 @@ func MergeWithBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutio } // merge podSpec with podTemplate - mergedPodSpec, err := mergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName) + mergedPodSpec, err := mergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName, primaryInitContainerName) if err != nil { return nil, nil, err } @@ -524,7 +529,7 @@ func MergeWithBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutio // mergePodSpecs merges the two provided PodSpecs. This process uses the first as the base configuration, where values // set by the first PodSpec are overwritten by the second in the return value. 
Additionally, this function applies // container-level configuration from the basePodSpec. -func mergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string) (*v1.PodSpec, error) { +func mergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string, primaryInitContainerName string) (*v1.PodSpec, error) { if basePodSpec == nil || podSpec == nil { return nil, errors.New("neither the basePodSpec or the podSpec can be nil") } @@ -539,6 +544,16 @@ func mergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContaine } } + // extract defaultInitContainerTemplate and primaryInitContainerTemplate + var defaultInitContainerTemplate, primaryInitContainerTemplate *v1.Container + for i := 0; i < len(basePodSpec.InitContainers); i++ { + if basePodSpec.InitContainers[i].Name == defaultInitContainerTemplateName { + defaultInitContainerTemplate = &basePodSpec.InitContainers[i] + } else if basePodSpec.InitContainers[i].Name == primaryInitContainerTemplateName { + primaryInitContainerTemplate = &basePodSpec.InitContainers[i] + } + } + // merge PodTemplate PodSpec with podSpec var mergedPodSpec *v1.PodSpec = basePodSpec.DeepCopy() if err := mergo.Merge(mergedPodSpec, podSpec, mergo.WithOverride, mergo.WithAppendSlice); err != nil { @@ -580,6 +595,43 @@ func mergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContaine } mergedPodSpec.Containers = mergedContainers + + // merge PodTemplate init containers + var mergedInitContainers []v1.Container + for _, initContainer := range podSpec.InitContainers { + // if applicable start with defaultInitContainerTemplate + var mergedInitContainer *v1.Container + if defaultInitContainerTemplate != nil { + mergedInitContainer = defaultInitContainerTemplate.DeepCopy() + } + + // if applicable merge with primaryInitContainerTemplate + if initContainer.Name == primaryInitContainerName && primaryInitContainerTemplate != nil { + if mergedInitContainer == nil { + 
mergedInitContainer = primaryInitContainerTemplate.DeepCopy() + } else { + err := mergo.Merge(mergedInitContainer, primaryInitContainerTemplate, mergo.WithOverride, mergo.WithAppendSlice) + if err != nil { + return nil, err + } + } + } + + // if applicable merge with existing init initContainer + if mergedInitContainer == nil { + mergedInitContainers = append(mergedInitContainers, initContainer) + } else { + err := mergo.Merge(mergedInitContainer, initContainer, mergo.WithOverride, mergo.WithAppendSlice) + if err != nil { + return nil, err + } + + mergedInitContainers = append(mergedInitContainers, *mergedInitContainer) + } + } + + mergedPodSpec.InitContainers = mergedInitContainers + return mergedPodSpec, nil } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 0c2e9ef5cc..9797b5e05b 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -1934,6 +1934,14 @@ func TestMergeWithBasePodTemplate(t *testing.T) { Name: "bar", }, }, + InitContainers: []v1.Container{ + v1.Container{ + Name: "foo-init", + }, + v1.Container{ + Name: "foo-bar", + }, + }, } objectMeta := metav1.ObjectMeta{ @@ -1954,7 +1962,7 @@ func TestMergeWithBasePodTemplate(t *testing.T) { tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil, "")) tCtx.OnTaskReader().Return(taskReader) - resultPodSpec, resultObjectMeta, err := MergeWithBasePodTemplate(context.TODO(), tCtx, &podSpec, &objectMeta, "foo") + resultPodSpec, resultObjectMeta, err := MergeWithBasePodTemplate(context.TODO(), tCtx, &podSpec, &objectMeta, "foo", "foo-init") assert.Nil(t, err) assert.True(t, reflect.DeepEqual(podSpec, *resultPodSpec)) assert.True(t, reflect.DeepEqual(objectMeta, *resultObjectMeta)) @@ -1966,6 +1974,11 @@ func TestMergeWithBasePodTemplate(t *testing.T) { TerminationMessagePath: 
"/dev/primary-termination-log", } + primaryInitContainerTemplate := v1.Container{ + Name: primaryInitContainerTemplateName, + TerminationMessagePath: "/dev/primary-init-termination-log", + } + podTemplate := v1.PodTemplate{ ObjectMeta: metav1.ObjectMeta{ Name: "fooTemplate", @@ -1982,6 +1995,9 @@ func TestMergeWithBasePodTemplate(t *testing.T) { Containers: []v1.Container{ primaryContainerTemplate, }, + InitContainers: []v1.Container{ + primaryInitContainerTemplate, + }, }, }, } @@ -2008,13 +2024,16 @@ func TestMergeWithBasePodTemplate(t *testing.T) { tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil, "")) tCtx.OnTaskReader().Return(taskReader) - resultPodSpec, resultObjectMeta, err := MergeWithBasePodTemplate(context.TODO(), tCtx, &podSpec, &objectMeta, "foo") + resultPodSpec, resultObjectMeta, err := MergeWithBasePodTemplate(context.TODO(), tCtx, &podSpec, &objectMeta, "foo", "foo-init") assert.Nil(t, err) // test that template podSpec is merged primaryContainer := resultPodSpec.Containers[0] assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name) assert.Equal(t, primaryContainerTemplate.TerminationMessagePath, primaryContainer.TerminationMessagePath) + primaryInitContainer := resultPodSpec.InitContainers[0] + assert.Equal(t, podSpec.InitContainers[0].Name, primaryInitContainer.Name) + assert.Equal(t, primaryInitContainerTemplate.TerminationMessagePath, primaryInitContainer.TerminationMessagePath) // test that template object metadata is copied assert.Contains(t, resultObjectMeta.Labels, "fooKey") @@ -2027,13 +2046,13 @@ func TestMergeWithBasePodTemplate(t *testing.T) { func TestMergePodSpecs(t *testing.T) { var priority int32 = 1 - podSpec1, _ := mergePodSpecs(nil, nil, "foo") + podSpec1, _ := mergePodSpecs(nil, nil, "foo", "foo-init") assert.Nil(t, podSpec1) - podSpec2, _ := mergePodSpecs(&v1.PodSpec{}, nil, "foo") + podSpec2, _ := mergePodSpecs(&v1.PodSpec{}, nil, "foo", "foo-init") assert.Nil(t, 
podSpec2) - podSpec3, _ := mergePodSpecs(nil, &v1.PodSpec{}, "foo") + podSpec3, _ := mergePodSpecs(nil, &v1.PodSpec{}, "foo", "foo-init") assert.Nil(t, podSpec3) podSpec := v1.PodSpec{ @@ -2051,6 +2070,20 @@ func TestMergePodSpecs(t *testing.T) { Name: "bar", }, }, + InitContainers: []v1.Container{ + v1.Container{ + Name: "primary-init", + VolumeMounts: []v1.VolumeMount{ + { + Name: "nccl", + MountPath: "abc", + }, + }, + }, + v1.Container{ + Name: "bar-init", + }, + }, NodeSelector: map[string]string{ "baz": "bar", }, @@ -2076,11 +2109,25 @@ func TestMergePodSpecs(t *testing.T) { TerminationMessagePath: "/dev/primary-termination-log", } + defaultInitContainerTemplate := v1.Container{ + Name: defaultInitContainerTemplateName, + TerminationMessagePath: "/dev/default-init-termination-log", + } + + primaryInitContainerTemplate := v1.Container{ + Name: primaryInitContainerTemplateName, + TerminationMessagePath: "/dev/primary-init-termination-log", + } + podTemplateSpec := v1.PodSpec{ Containers: []v1.Container{ defaultContainerTemplate, primaryContainerTemplate, }, + InitContainers: []v1.Container{ + defaultInitContainerTemplate, + primaryInitContainerTemplate, + }, HostNetwork: true, NodeSelector: map[string]string{ "foo": "bar", @@ -2093,7 +2140,7 @@ func TestMergePodSpecs(t *testing.T) { }, } - mergedPodSpec, err := mergePodSpecs(&podTemplateSpec, &podSpec, "primary") + mergedPodSpec, err := mergePodSpecs(&podTemplateSpec, &podSpec, "primary", "primary-init") assert.Nil(t, err) // validate a PodTemplate-only field @@ -2117,6 +2164,17 @@ func TestMergePodSpecs(t *testing.T) { defaultContainer := mergedPodSpec.Containers[1] assert.Equal(t, podSpec.Containers[1].Name, defaultContainer.Name) assert.Equal(t, defaultContainerTemplate.TerminationMessagePath, defaultContainer.TerminationMessagePath) + + // validate primary init container + primaryInitContainer := mergedPodSpec.InitContainers[0] + assert.Equal(t, podSpec.InitContainers[0].Name, primaryInitContainer.Name) + 
assert.Equal(t, primaryInitContainerTemplate.TerminationMessagePath, primaryInitContainer.TerminationMessagePath) + assert.Equal(t, 1, len(primaryInitContainer.VolumeMounts)) + + // validate default init container + defaultInitContainer := mergedPodSpec.InitContainers[1] + assert.Equal(t, podSpec.InitContainers[1].Name, defaultInitContainer.Name) + assert.Equal(t, defaultInitContainerTemplate.TerminationMessagePath, defaultInitContainer.TerminationMessagePath) } func TestAddFlyteCustomizationsToContainer_SetConsoleUrl(t *testing.T) {