Commit

feat(sdk): support collecting outputs from conditional branches using `dsl.OneOf` (kubeflow#10067)

* support dsl.OneOf

* address review feedback

* address review feedback
connor-mccarthy authored and stijntratsaertit committed Feb 16, 2024
1 parent 124686b commit 4106993
Showing 20 changed files with 2,742 additions and 408 deletions.
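
For orientation, a minimal sketch of the user-facing feature this commit adds (component and pipeline names are illustrative, not taken from this PR's tests): `dsl.OneOf` consolidates the outputs of mutually exclusive `dsl.If` / `dsl.Elif` / `dsl.Else` branches into a single channel that can be consumed downstream or surfaced as a pipeline output.

from kfp import dsl

@dsl.component
def flip_coin() -> str:
    import random
    return 'heads' if random.randint(0, 1) == 0 else 'tails'

@dsl.component
def print_and_return(text: str) -> str:
    print(text)
    return text

@dsl.pipeline
def flip_coin_pipeline() -> str:
    flip_task = flip_coin()
    with dsl.If(flip_task.output == 'heads'):
        heads_task = print_and_return(text='Got heads!')
    with dsl.Else():
        tails_task = print_and_return(text='Got tails!')
    # Exactly one branch runs at runtime; dsl.OneOf resolves to that branch's output.
    return dsl.OneOf(heads_task.output, tails_task.output)
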
1 change: 1 addition & 0 deletions sdk/RELEASE.md
@@ -3,6 +3,7 @@
## Features

## Breaking changes
* Support collecting outputs from conditional branches using `dsl.OneOf` [\#10067](https://github.com/kubeflow/pipelines/pull/10067)

## Deprecations

799 changes: 799 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py

Large diffs are not rendered by default.

247 changes: 157 additions & 90 deletions sdk/python/kfp/compiler/compiler_utils.py

Large diffs are not rendered by default.

122 changes: 116 additions & 6 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
@@ -417,12 +417,13 @@ def _build_component_spec_from_component_spec_structure(
return component_spec


def _connect_dag_outputs(
def connect_single_dag_output(
component_spec: pipeline_spec_pb2.ComponentSpec,
output_name: str,
output_channel: pipeline_channel.PipelineChannel,
) -> None:
"""Connects dag output to a subtask output.
"""Connects a DAG output to a subtask output when the subtask output
contains only one channel (i.e., not OneOfMixin).
Args:
component_spec: The component spec to modify its dag outputs.
@@ -451,14 +452,71 @@ def _connect_dag_outputs(
output_name].value_from_parameter.output_parameter_key = output_channel.name


def connect_oneof_dag_output(
component_spec: pipeline_spec_pb2.ComponentSpec,
output_name: str,
oneof_output: pipeline_channel.OneOfMixin,
) -> None:
"""Connects a output to the OneOf output returned by the DAG's internal
condition-branches group.
Args:
component_spec: The component spec to modify its DAG outputs.
output_name: The name of the DAG output.
oneof_output: The OneOfMixin object returned by the pipeline (OneOf in user code).
"""
if isinstance(oneof_output, pipeline_channel.OneOfArtifact):
if output_name not in component_spec.output_definitions.artifacts:
raise ValueError(
f'Pipeline or component output not defined: {output_name}. You may be missing a type annotation.'
)
for channel in oneof_output.channels:
component_spec.dag.outputs.artifacts[
output_name].artifact_selectors.append(
pipeline_spec_pb2.DagOutputsSpec.ArtifactSelectorSpec(
producer_subtask=channel.task_name,
output_artifact_key=channel.name,
))
if isinstance(oneof_output, pipeline_channel.OneOfParameter):
if output_name not in component_spec.output_definitions.parameters:
raise ValueError(
f'Pipeline or component output not defined: {output_name}. You may be missing a type annotation.'
)
for channel in oneof_output.channels:
component_spec.dag.outputs.parameters[
output_name].value_from_oneof.parameter_selectors.append(
pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
producer_subtask=channel.task_name,
output_parameter_key=channel.name,
))


def _build_dag_outputs(
component_spec: pipeline_spec_pb2.ComponentSpec,
dag_outputs: Dict[str, pipeline_channel.PipelineChannel],
) -> None:
"""Builds DAG output spec."""
"""Connects the DAG's outputs to a TaskGroup's ComponentSpec and validates
it is present in the component interface.
Args:
component_spec: The ComponentSpec.
dag_outputs: Dictionary of output key to output channel.
"""
for output_name, output_channel in dag_outputs.items():
_connect_dag_outputs(component_spec, output_name, output_channel)
# Validate that dag outputs cover all outputs in the component definition.
if not isinstance(output_channel, pipeline_channel.PipelineChannel):
raise ValueError(
f"Got unknown pipeline output '{output_name}' of type {output_channel}."
)
connect_single_dag_output(component_spec, output_name, output_channel)

validate_dag_outputs(component_spec)


def validate_dag_outputs(
component_spec: pipeline_spec_pb2.ComponentSpec) -> None:
"""Validates the DAG's ComponentSpec specifies the source task for all of
its ComponentSpec inputs (input_definitions) and outputs
(output_definitions)."""
for output_name in component_spec.output_definitions.artifacts:
if output_name not in component_spec.dag.outputs.artifacts:
raise ValueError(f'Missing pipeline output: {output_name}.')
@@ -467,6 +525,31 @@ def _build_dag_outputs(
raise ValueError(f'Missing pipeline output: {output_name}.')


def build_oneof_dag_outputs(
component_spec: pipeline_spec_pb2.ComponentSpec,
oneof_outputs: Dict[str, pipeline_channel.OneOfMixin],
) -> None:
"""Connects the DAG's OneOf outputs to a TaskGroup's ComponentSpec and
validates it is present in the component interface.
Args:
component_spec: The ComponentSpec.
oneof_outputs: Dictionary of output key to OneOf output channel.
"""
for output_name, oneof_output in oneof_outputs.items():
for channel in oneof_output.channels:
if not isinstance(channel, pipeline_channel.PipelineChannel):
raise ValueError(
f"Got unknown pipeline output '{output_name}' of type {type(channel)}."
)
connect_oneof_dag_output(
component_spec,
output_name,
oneof_output,
)
validate_dag_outputs(component_spec)


def build_importer_spec_for_task(
task: pipeline_task.PipelineTask
) -> pipeline_spec_pb2.PipelineDeploymentConfig.ImporterSpec:
@@ -1290,7 +1373,7 @@ def build_spec_by_group(
elif isinstance(subgroup, tasks_group.ConditionBranches):
subgroup_component_spec = build_component_spec_for_group(
input_pipeline_channels=subgroup_input_channels,
output_pipeline_channels={},
output_pipeline_channels=subgroup_output_channels,
)

subgroup_task_spec = build_task_spec_for_group(
@@ -1299,6 +1382,9 @@
tasks_in_current_dag=tasks_in_current_dag,
is_parent_component_root=is_parent_component_root,
)
# oneof is the only type of output a ConditionBranches group can have
build_oneof_dag_outputs(subgroup_component_spec,
subgroup_output_channels)

else:
raise RuntimeError(
@@ -1706,6 +1792,28 @@ def _rename_component_refs(
old_name_to_new_name[old_component_name]].CopyFrom(component_spec)


def validate_pipeline_outputs_dict(
pipeline_outputs_dict: Dict[str, pipeline_channel.PipelineChannel]):
for channel in pipeline_outputs_dict.values():
if isinstance(channel, for_loop.Collected):
# this validation doesn't apply to Collected
continue

elif isinstance(channel, pipeline_channel.OneOfMixin):
if channel.condition_branches_group.parent_task_group.group_type != tasks_group.TasksGroupType.PIPELINE:
raise compiler_utils.InvalidTopologyException(
f'Pipeline outputs may only be returned from the top level of the pipeline function scope. Got pipeline output dsl.{pipeline_channel.OneOf.__name__} from within the control flow group dsl.{channel.condition_branches_group.parent_task_group.__class__.__name__}.'
)

elif isinstance(channel, pipeline_channel.PipelineChannel):
if channel.task.parent_task_group.group_type != tasks_group.TasksGroupType.PIPELINE:
raise compiler_utils.InvalidTopologyException(
f'Pipeline outputs may only be returned from the top level of the pipeline function scope. Got pipeline output from within the control flow group dsl.{channel.task.parent_task_group.__class__.__name__}.'
)
else:
raise ValueError(f'Got unknown pipeline output: {channel}.')
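
As an illustration (not part of the diff), a hypothetical pipeline that the check above rejects, reusing the illustrative components from the sketch near the top of this page: returning a single branch's output directly, instead of a dsl.OneOf, means the output comes from inside a control flow group, so compilation now fails with InvalidTopologyException rather than the error only surfacing at submission time.

@dsl.pipeline
def invalid_pipeline() -> str:
    flip_task = flip_coin()
    with dsl.If(flip_task.output == 'heads'):
        heads_task = print_and_return(text='Got heads!')
    # heads_task lives inside the dsl.If group, not at the pipeline root,
    # so returning its output raises InvalidTopologyException at compile time.
    return heads_task.output
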


def create_pipeline_spec(
pipeline: pipeline_context.Pipeline,
component_spec: structures.ComponentSpec,
@@ -1741,6 +1849,8 @@ def create_pipeline_spec(
# an output from a task in a condition group, for example, which isn't
# caught until submission time using Vertex SDK client
pipeline_outputs_dict = convert_pipeline_outputs_to_dict(pipeline_outputs)
validate_pipeline_outputs_dict(pipeline_outputs_dict)

root_group = pipeline.groups[0]

all_groups = compiler_utils.get_all_groups(root_group)
3 changes: 3 additions & 0 deletions sdk/python/kfp/dsl/__init__.py
@@ -229,8 +229,10 @@ def my_pipeline():
if os.environ.get('_KFP_RUNTIME', 'false') != 'true':
from kfp.dsl.component_decorator import component
from kfp.dsl.container_component_decorator import container_component
# TODO: Collected should be moved to pipeline_channel.py, consistent with OneOf
from kfp.dsl.for_loop import Collected
from kfp.dsl.importer_node import importer
from kfp.dsl.pipeline_channel import OneOf
from kfp.dsl.pipeline_context import pipeline
from kfp.dsl.pipeline_task import PipelineTask
from kfp.dsl.placeholders import ConcatPlaceholder
@@ -252,6 +254,7 @@ def my_pipeline():
'If',
'Elif',
'Else',
'OneOf',
'ExitHandler',
'ParallelFor',
'Collected',
11 changes: 11 additions & 0 deletions sdk/python/kfp/dsl/for_loop.py
@@ -274,6 +274,7 @@ def _get_name_override(self, loop_arg_name: str, subvar_name: str) -> str:
return f'{loop_arg_name}{self.SUBVAR_NAME_DELIMITER}{subvar_name}'


# TODO: migrate Collected to OneOfMixin style implementation
class Collected(pipeline_channel.PipelineChannel):
"""For collecting into a list the output from a task in dsl.ParallelFor
loops.
@@ -313,3 +314,13 @@ def __init__(
channel_type=channel_type,
task_name=output.task_name,
)
self._validate_no_oneof_channel(self.output)

def _validate_no_oneof_channel(
self, channel: Union[pipeline_channel.PipelineParameterChannel,
pipeline_channel.PipelineArtifactChannel]
) -> None:
if isinstance(channel, pipeline_channel.OneOfMixin):
raise ValueError(
f'dsl.{pipeline_channel.OneOf.__name__} cannot be used inside of dsl.{Collected.__name__}.'
)
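
To illustrate the new guard (again a hypothetical sketch, reusing the illustrative print_and_return component), wrapping a dsl.OneOf in dsl.Collected for a fan-in now raises the ValueError above:

from typing import List

@dsl.pipeline
def invalid_fan_in_pipeline() -> List[str]:
    with dsl.ParallelFor([1, 2, 3]) as item:
        with dsl.If(item == 1):
            one_task = print_and_return(text='one')
        with dsl.Else():
            other_task = print_and_return(text='other')
        branch_output = dsl.OneOf(one_task.output, other_task.output)
    # Raises ValueError: dsl.OneOf cannot be used inside of dsl.Collected.
    return dsl.Collected(branch_output)
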