diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 960ec6148e8..9d392d1175f 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + package driver import ( @@ -448,19 +449,28 @@ func initPodSpecPatch( accelerator := container.GetResources().GetAccelerator() if accelerator != nil { if accelerator.GetType() != "" && accelerator.GetCount() > 0 { - q, err := k8sres.ParseQuantity(fmt.Sprintf("%v", accelerator.GetCount())) - if err != nil { - return nil, fmt.Errorf("failed to init podSpecPatch: %w", err) + acceleratorType, err1 := resolvePodSpecRuntimeParameter(accelerator.GetType(), executorInput) + if err1 != nil { + return nil, err1 } - res.Limits[k8score.ResourceName(accelerator.GetType())] = q + q, err1 := k8sres.ParseQuantity(fmt.Sprintf("%v", accelerator.GetCount())) + if err1 != nil { + return nil, fmt.Errorf("failed to init podSpecPatch: %w", err1) + } + res.Limits[k8score.ResourceName(acceleratorType)] = q } } + + containerImage, err := resolvePodSpecRuntimeParameter(container.Image, executorInput) + if err != nil { + return nil, err + } podSpec := &k8score.PodSpec{ Containers: []k8score.Container{{ Name: "main", // argo task user container is always called "main" Command: launcherCmd, Args: userCmdArgs, - Image: container.Image, + Image: containerImage, Resources: res, Env: userEnvVar, }}, diff --git a/backend/src/v2/driver/util.go b/backend/src/v2/driver/util.go new file mode 100644 index 00000000000..d124d463cee --- /dev/null +++ b/backend/src/v2/driver/util.go @@ -0,0 +1,78 @@ +// Copyright 2021-2024 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
// InputPipelineChannelPattern is a regular expression that matches an input
// parameter channel placeholder and captures the parameter name found between
// the single quotes. An example input channel looks like
// "{{$.inputs.parameters['pipelinechannel--val']}}".
//
// The dots are escaped so that they only match a literal "." rather than any
// character.
const InputPipelineChannelPattern = `\$\.inputs\.parameters\['(.+?)'\]`

// inputPipelineChannelRegexp is InputPipelineChannelPattern compiled once at
// package initialization instead of on every call.
var inputPipelineChannelRegexp = regexp.MustCompile(InputPipelineChannelPattern)

// isInputParameterChannel reports whether inputChannel contains an input
// parameter channel placeholder matching InputPipelineChannelPattern.
//
// The match is unanchored: it only checks that a placeholder is present. It
// does not verify that the placeholder is the sole content of inputChannel,
// nor that it occurs only once.
func isInputParameterChannel(inputChannel string) bool {
	return inputPipelineChannelRegexp.MatchString(inputChannel)
}

// extractInputParameterFromChannel takes an inputChannel that adheres to
// InputPipelineChannelPattern and extracts the channel parameter name.
// For example, given an input channel of the form
// "{{$.inputs.parameters['pipelinechannel--val']}}", the channel parameter
// name "pipelinechannel--val" is returned.
//
// When inputChannel holds several placeholders, the name from the first one
// is returned. An error is returned when no placeholder is found.
func extractInputParameterFromChannel(inputChannel string) (string, error) {
	match := inputPipelineChannelRegexp.FindStringSubmatch(inputChannel)
	// A successful match yields the full match plus the one capture group.
	if len(match) < 2 {
		return "", fmt.Errorf("failed to extract input parameter from channel: %s", inputChannel)
	}
	return match[1], nil
}
parameterValue takes the form of: +// "{{$.inputs.parameters['pipelinechannel--someParameterName']}}" +// +// parameterValue is a runtime parameter value that has been resolved and included within +// the executor input. Since the pod spec patch cannot dynamically update the underlying +// container template's inputs in an Argo Workflow, this is a workaround for resolving +// such parameters. +// +// If parameter value is not a parameter channel, then a constant value is assumed and +// returned as is. +func resolvePodSpecRuntimeParameter(parameterValue string, executorInput *pipelinespec.ExecutorInput) (string, error) { + if isInputParameterChannel(parameterValue) { + inputImage, err1 := extractInputParameterFromChannel(parameterValue) + if err1 != nil { + return "", err1 + } + if val, ok := executorInput.Inputs.ParameterValues[inputImage]; ok { + return val.GetStringValue(), nil + } else { + return "", fmt.Errorf("executorInput did not contain container Image input parameter") + } + } + return parameterValue, nil +} diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index 2433f09bc6d..b09b1b52c5d 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -909,6 +909,70 @@ def my_pipeline() -> NamedTuple('Outputs', [ ]): task = print_and_return(text='Hello') + def test_pipeline_with_parameterized_container_image(self): + with tempfile.TemporaryDirectory() as tmpdir: + + @dsl.component(base_image='docker.io/python:3.9.17') + def empty_component(): + pass + + @dsl.pipeline() + def simple_pipeline(img: str): + task = empty_component() + # overwrite base_image="docker.io/python:3.9.17" + task.set_container_image(img) + + output_yaml = os.path.join(tmpdir, 'result.yaml') + compiler.Compiler().compile( + pipeline_func=simple_pipeline, + package_path=output_yaml, + pipeline_parameters={'img': 'someimage'}) + self.assertTrue(os.path.exists(output_yaml)) + + with open(output_yaml, 
def test_pipeline_with_parameterized_container_image(self):
    """A pipeline-parameter image compiles to a runtime placeholder."""
    with tempfile.TemporaryDirectory() as tmpdir:

        @dsl.component(base_image='docker.io/python:3.9.17')
        def empty_component():
            pass

        @dsl.pipeline()
        def simple_pipeline(img: str):
            task = empty_component()
            # overwrite base_image="docker.io/python:3.9.17"
            task.set_container_image(img)

        output_yaml = os.path.join(tmpdir, 'result.yaml')
        compiler.Compiler().compile(
            pipeline_func=simple_pipeline,
            package_path=output_yaml,
            pipeline_parameters={'img': 'someimage'})
        self.assertTrue(os.path.exists(output_yaml))

        with open(output_yaml) as f:
            pipeline_spec = yaml.safe_load(f)
        container = pipeline_spec['deploymentSpec']['executors'][
            'exec-empty-component']['container']
        self.assertEqual(
            container['image'],
            "{{$.inputs.parameters['pipelinechannel--img']}}")
        # A parameterized image is expected to add two task inputs:
        # 'base_image' and 'pipelinechannel--img', the latter being the
        # key named by the image placeholder asserted above.
        input_parameters = pipeline_spec['root']['dag']['tasks'][
            'empty-component']['inputs']['parameters']
        self.assertIn('base_image', input_parameters)
        self.assertIn('pipelinechannel--img', input_parameters)


def test_pipeline_with_constant_container_image(self):
    """A constant image is compiled verbatim and adds no task inputs."""
    with tempfile.TemporaryDirectory() as tmpdir:

        @dsl.component(base_image='docker.io/python:3.9.17')
        def empty_component():
            pass

        @dsl.pipeline()
        def simple_pipeline():
            task = empty_component()
            # overwrite base_image="docker.io/python:3.9.17"
            task.set_container_image('constant-value')

        output_yaml = os.path.join(tmpdir, 'result.yaml')
        compiler.Compiler().compile(
            pipeline_func=simple_pipeline, package_path=output_yaml)

        self.assertTrue(os.path.exists(output_yaml))

        with open(output_yaml) as f:
            pipeline_spec = yaml.safe_load(f)
        container = pipeline_spec['deploymentSpec']['executors'][
            'exec-empty-component']['container']
        self.assertEqual(container['image'], 'constant-value')
        # A constant value should yield no parameters
        dag_task = pipeline_spec['root']['dag']['tasks'][
            'empty-component']
        self.assertNotIn('inputs', dag_task)
@block_if_final()
def set_container_image(
        self,
        name: Union[str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
    """Sets the container image stored on this task's container spec.

    Args:
        name: The image to use, given either as a string or as a
            PipelineChannel. A PipelineChannel is converted with ``str()``
            before it is stored.

    Returns:
        Self return to allow chained setting calls.
    """
    self._ensure_container_spec_exists()
    is_channel = isinstance(name, pipeline_channel.PipelineChannel)
    self.container_spec.image = str(name) if is_channel else name
    return self