Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk): support collecting outputs from conditional branches using dsl.OneOf #10067

Merged
merged 3 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Features

## Breaking changes
* Support collecting outputs from conditional branches using `dsl.OneOf` [\#10067](https://github.com/kubeflow/pipelines/pull/10067)

## Deprecations

Expand Down
799 changes: 799 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py

Large diffs are not rendered by default.

247 changes: 157 additions & 90 deletions sdk/python/kfp/compiler/compiler_utils.py

Large diffs are not rendered by default.

122 changes: 116 additions & 6 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,12 +417,13 @@ def _build_component_spec_from_component_spec_structure(
return component_spec


def _connect_dag_outputs(
def connect_single_dag_output(
component_spec: pipeline_spec_pb2.ComponentSpec,
output_name: str,
output_channel: pipeline_channel.PipelineChannel,
) -> None:
"""Connects dag output to a subtask output.
"""Connects a DAG output to a subtask output when the subtask output
contains only one channel (i.e., not OneOfMixin).

Args:
component_spec: The component spec to modify its dag outputs.
Expand Down Expand Up @@ -451,14 +452,71 @@ def _connect_dag_outputs(
output_name].value_from_parameter.output_parameter_key = output_channel.name


def connect_oneof_dag_output(
    component_spec: pipeline_spec_pb2.ComponentSpec,
    output_name: str,
    oneof_output: pipeline_channel.OneOfMixin,
) -> None:
    """Connects an output to the OneOf output returned by the DAG's internal
    condition-branches group.

    Args:
        component_spec: The component spec to modify its DAG outputs.
        output_name: The name of the DAG output.
        oneof_output: The OneOfMixin object returned by the pipeline (OneOf
            in user code).

    Raises:
        ValueError: If output_name is not declared in the component's output
            definitions (typically caused by a missing type annotation).
    """
    if isinstance(oneof_output, pipeline_channel.OneOfArtifact):
        if output_name not in component_spec.output_definitions.artifacts:
            raise ValueError(
                f'Pipeline or component output not defined: {output_name}. You may be missing a type annotation.'
            )
        # One selector per conditional branch; at runtime only the branch
        # that executed provides the artifact.
        for channel in oneof_output.channels:
            component_spec.dag.outputs.artifacts[
                output_name].artifact_selectors.append(
                    pipeline_spec_pb2.DagOutputsSpec.ArtifactSelectorSpec(
                        producer_subtask=channel.task_name,
                        output_artifact_key=channel.name,
                    ))
    # OneOfArtifact and OneOfParameter are mutually exclusive, so use elif
    # rather than a second independent isinstance test.
    elif isinstance(oneof_output, pipeline_channel.OneOfParameter):
        if output_name not in component_spec.output_definitions.parameters:
            raise ValueError(
                f'Pipeline or component output not defined: {output_name}. You may be missing a type annotation.'
            )
        for channel in oneof_output.channels:
            component_spec.dag.outputs.parameters[
                output_name].value_from_oneof.parameter_selectors.append(
                    pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
                        producer_subtask=channel.task_name,
                        output_parameter_key=channel.name,
                    ))


def _build_dag_outputs(
    component_spec: pipeline_spec_pb2.ComponentSpec,
    dag_outputs: Dict[str, pipeline_channel.PipelineChannel],
) -> None:
    """Connects the DAG's outputs to a TaskGroup's ComponentSpec and validates
    it is present in the component interface.

    Args:
        component_spec: The ComponentSpec.
        dag_outputs: Dictionary of output key to output channel.

    Raises:
        ValueError: If an output is not a PipelineChannel, or (via
            validate_dag_outputs) if a declared component output has no
            corresponding DAG output.
    """
    for output_name, output_channel in dag_outputs.items():
        if not isinstance(output_channel, pipeline_channel.PipelineChannel):
            # Report the offending type (not the object repr), consistent
            # with the equivalent check in build_oneof_dag_outputs.
            raise ValueError(
                f"Got unknown pipeline output '{output_name}' of type {type(output_channel)}."
            )
        connect_single_dag_output(component_spec, output_name, output_channel)

    validate_dag_outputs(component_spec)


def validate_dag_outputs(
component_spec: pipeline_spec_pb2.ComponentSpec) -> None:
"""Validates the DAG's ComponentSpec specifies the source task for all of
its ComponentSpec inputs (input_definitions) and outputs
(output_definitions)."""
for output_name in component_spec.output_definitions.artifacts:
if output_name not in component_spec.dag.outputs.artifacts:
raise ValueError(f'Missing pipeline output: {output_name}.')
Expand All @@ -467,6 +525,31 @@ def _build_dag_outputs(
raise ValueError(f'Missing pipeline output: {output_name}.')


def build_oneof_dag_outputs(
    component_spec: pipeline_spec_pb2.ComponentSpec,
    oneof_outputs: Dict[str, pipeline_channel.OneOfMixin],
) -> None:
    """Connects the DAG's OneOf outputs to a TaskGroup's ComponentSpec and
    validates it is present in the component interface.

    Args:
        component_spec: The ComponentSpec.
        oneof_outputs: Dictionary of output key to OneOf output channel.

    Raises:
        ValueError: If any branch channel inside a OneOf is not a
            PipelineChannel, or if a declared output is missing a DAG output.
    """
    for name, oneof in oneof_outputs.items():
        # Every branch of the OneOf must itself be a pipeline channel.
        for branch_channel in oneof.channels:
            if isinstance(branch_channel, pipeline_channel.PipelineChannel):
                continue
            raise ValueError(
                f"Got unknown pipeline output '{name}' of type {type(branch_channel)}."
            )
        connect_oneof_dag_output(component_spec, name, oneof)
    validate_dag_outputs(component_spec)


def build_importer_spec_for_task(
task: pipeline_task.PipelineTask
) -> pipeline_spec_pb2.PipelineDeploymentConfig.ImporterSpec:
Expand Down Expand Up @@ -1290,7 +1373,7 @@ def build_spec_by_group(
elif isinstance(subgroup, tasks_group.ConditionBranches):
subgroup_component_spec = build_component_spec_for_group(
input_pipeline_channels=subgroup_input_channels,
output_pipeline_channels={},
output_pipeline_channels=subgroup_output_channels,
)

subgroup_task_spec = build_task_spec_for_group(
Expand All @@ -1299,6 +1382,9 @@ def build_spec_by_group(
tasks_in_current_dag=tasks_in_current_dag,
is_parent_component_root=is_parent_component_root,
)
# oneof is the only type of output a ConditionBranches group can have
build_oneof_dag_outputs(subgroup_component_spec,
subgroup_output_channels)

else:
raise RuntimeError(
Expand Down Expand Up @@ -1694,6 +1780,28 @@ def _rename_component_refs(
old_name_to_new_name[old_component_name]].CopyFrom(component_spec)


def validate_pipeline_outputs_dict(
    pipeline_outputs_dict: Dict[str, pipeline_channel.PipelineChannel]):
    """Validates that each pipeline output is returned from the top level of
    the pipeline function scope, not from inside a control flow group.

    NOTE: the isinstance chain is ordered most- to least-specific, since
    Collected subclasses PipelineChannel; keep that order.

    Args:
        pipeline_outputs_dict: Dictionary of output key to output channel.

    Raises:
        compiler_utils.InvalidTopologyException: If an output originates from
            within a control flow group.
        ValueError: If an output is not a recognized channel type.
    """
    for output in pipeline_outputs_dict.values():
        # Collected has its own validation path; nothing to check here.
        if isinstance(output, for_loop.Collected):
            continue

        if isinstance(output, pipeline_channel.OneOfMixin):
            enclosing_group = output.condition_branches_group.parent_task_group
            if enclosing_group.group_type != tasks_group.TasksGroupType.PIPELINE:
                raise compiler_utils.InvalidTopologyException(
                    f'Pipeline outputs may only be returned from the top level of the pipeline function scope. Got pipeline output dsl.{pipeline_channel.OneOf.__name__} from within the control flow group dsl.{enclosing_group.__class__.__name__}.'
                )
        elif isinstance(output, pipeline_channel.PipelineChannel):
            enclosing_group = output.task.parent_task_group
            if enclosing_group.group_type != tasks_group.TasksGroupType.PIPELINE:
                raise compiler_utils.InvalidTopologyException(
                    f'Pipeline outputs may only be returned from the top level of the pipeline function scope. Got pipeline output from within the control flow group dsl.{enclosing_group.__class__.__name__}.'
                )
        else:
            raise ValueError(f'Got unknown pipeline output: {output}.')


def create_pipeline_spec(
pipeline: pipeline_context.Pipeline,
component_spec: structures.ComponentSpec,
Expand Down Expand Up @@ -1729,6 +1837,8 @@ def create_pipeline_spec(
# an output from a task in a condition group, for example, which isn't
# caught until submission time using Vertex SDK client
pipeline_outputs_dict = convert_pipeline_outputs_to_dict(pipeline_outputs)
validate_pipeline_outputs_dict(pipeline_outputs_dict)

root_group = pipeline.groups[0]

all_groups = compiler_utils.get_all_groups(root_group)
Expand Down
3 changes: 3 additions & 0 deletions sdk/python/kfp/dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,10 @@ def my_pipeline():
if os.environ.get('_KFP_RUNTIME', 'false') != 'true':
from kfp.dsl.component_decorator import component
from kfp.dsl.container_component_decorator import container_component
# TODO: Collected should be moved to pipeline_channel.py, consistent with OneOf
from kfp.dsl.for_loop import Collected
from kfp.dsl.importer_node import importer
from kfp.dsl.pipeline_channel import OneOf
from kfp.dsl.pipeline_context import pipeline
from kfp.dsl.pipeline_task import PipelineTask
from kfp.dsl.placeholders import ConcatPlaceholder
Expand All @@ -252,6 +254,7 @@ def my_pipeline():
'If',
'Elif',
'Else',
'OneOf',
'ExitHandler',
'ParallelFor',
'Collected',
Expand Down
11 changes: 11 additions & 0 deletions sdk/python/kfp/dsl/for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def _get_name_override(self, loop_arg_name: str, subvar_name: str) -> str:
return f'{loop_arg_name}{self.SUBVAR_NAME_DELIMITER}{subvar_name}'


# TODO: migrate Collected to OneOfMixin style implementation
class Collected(pipeline_channel.PipelineChannel):
"""For collecting into a list the output from a task in dsl.ParallelFor
loops.
Expand Down Expand Up @@ -313,3 +314,13 @@ def __init__(
channel_type=channel_type,
task_name=output.task_name,
)
self._validate_no_oneof_channel(self.output)

def _validate_no_oneof_channel(
    self, channel: Union[pipeline_channel.PipelineParameterChannel,
                         pipeline_channel.PipelineArtifactChannel]
) -> None:
    """Raises if the provided channel is a OneOf output, which cannot be
    collected by dsl.Collected.

    Args:
        channel: The upstream output channel being collected.

    Raises:
        ValueError: If the channel is a OneOfMixin.
    """
    if not isinstance(channel, pipeline_channel.OneOfMixin):
        return
    raise ValueError(
        f'dsl.{pipeline_channel.OneOf.__name__} cannot be used inside of dsl.{Collected.__name__}.'
    )
Loading