Skip to content

Commit

Permalink
support dsl.OneOf
Browse files Browse the repository at this point in the history
  • Loading branch information
connor-mccarthy committed Oct 7, 2023
1 parent 2b0dcf3 commit 35df06d
Show file tree
Hide file tree
Showing 20 changed files with 2,295 additions and 408 deletions.
1 change: 1 addition & 0 deletions sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Features

## Breaking changes
* Support collecting outputs from conditional branches using `dsl.OneOf` [\#10067](https://github.com/kubeflow/pipelines/pull/10067)

## Deprecations

Expand Down
366 changes: 366 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py

Large diffs are not rendered by default.

243 changes: 153 additions & 90 deletions sdk/python/kfp/compiler/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
"""Utility methods for compiler implementation that is IR-agnostic."""

import collections
from copy import deepcopy
import copy
from typing import DefaultDict, Dict, List, Mapping, Set, Tuple, Union

from kfp import dsl
from kfp.dsl import for_loop
from kfp.dsl import pipeline_channel
from kfp.dsl import pipeline_context
Expand Down Expand Up @@ -258,10 +259,9 @@ def get_inputs_for_all_groups(
if isinstance(channel_to_add, pipeline_channel.PipelineChannel):
channels_to_add.append(channel_to_add)

if channel.task_name:
if channel.task:
# The PipelineChannel is produced by a task.

upstream_task = pipeline.tasks[channel.task_name]
upstream_task = channel.task
upstream_groups, downstream_groups = (
_get_uncommon_ancestors(
task_name_to_parent_groups=task_name_to_parent_groups,
Expand Down Expand Up @@ -462,46 +462,114 @@ def get_outputs_for_all_groups(
}

outputs = collections.defaultdict(dict)

processed_oneofs: Set[pipeline_channel.OneOfMixin] = set()
# handle dsl.Collected consumed by tasks
for task in pipeline.tasks.values():
for channel in task.channel_inputs:
if not isinstance(channel, for_loop.Collected):
continue
producer_task = pipeline.tasks[channel.task_name]
consumer_task = task

upstream_groups, downstream_groups = (
_get_uncommon_ancestors(
task_name_to_parent_groups=task_name_to_parent_groups,
group_name_to_parent_groups=group_name_to_parent_groups,
task1=producer_task,
task2=consumer_task,
))
validate_parallel_for_fan_in_consumption_legal(
consumer_task_name=consumer_task.name,
upstream_groups=upstream_groups,
group_name_to_group=group_name_to_group,
)
if isinstance(channel, dsl.Collected):
producer_task = pipeline.tasks[channel.task_name]
consumer_task = task

upstream_groups, downstream_groups = (
_get_uncommon_ancestors(
task_name_to_parent_groups=task_name_to_parent_groups,
group_name_to_parent_groups=group_name_to_parent_groups,
task1=producer_task,
task2=consumer_task,
))
validate_parallel_for_fan_in_consumption_legal(
consumer_task_name=consumer_task.name,
upstream_groups=upstream_groups,
group_name_to_group=group_name_to_group,
)

# producer_task's immediate parent group and the name by which
# to surface the channel
surfaced_output_name = additional_input_name_for_pipeline_channel(
channel)

# the highest-level task group that "consumes" the
# collected output
parent_consumer = downstream_groups[0]
producer_task_name = upstream_groups.pop()

# process from the upstream groups from the inside out
for upstream_name in reversed(upstream_groups):
outputs[upstream_name][
surfaced_output_name] = make_new_channel_for_collected_outputs(
channel_name=channel.name,
starting_channel=channel.output,
task_name=producer_task_name,
)

# on each iteration, mutate the channel being consumed so
# that it references the last parent group surfacer
channel.name = surfaced_output_name
channel.task_name = upstream_name

# for the next iteration, set the producer to the current
# surfacer (parent group)
producer_task_name = upstream_name

# producer_task's immediate parent group and the name by which
# to surface the channel
parent_of_current_surfacer = group_name_to_parent_groups[
upstream_name][-2]
if parent_consumer in group_name_to_children[
parent_of_current_surfacer]:
break

elif isinstance(channel, pipeline_channel.OneOfMixin):
for inner_channel in channel.channels:
producer_task = pipeline.tasks[inner_channel.task_name]
consumer_task = task
upstream_groups, downstream_groups = (
_get_uncommon_ancestors(
task_name_to_parent_groups=task_name_to_parent_groups,
group_name_to_parent_groups=group_name_to_parent_groups,
task1=producer_task,
task2=consumer_task,
))
surfaced_output_name = additional_input_name_for_pipeline_channel(
inner_channel)

# 1. get the oneof
# 2. find the task group that surfaced it
# 3. find the inner tasks responsible

for upstream_name in reversed(upstream_groups):
# skip the first task processed, since we don't need to add new outputs for the innermost task
if upstream_name == inner_channel.task.name:
continue
# once we've hit the outermost condition-branches group, we're done
if upstream_name == channel.condition_branches_group.name:
outputs[upstream_name][channel.name] = channel
break

# copy so we can update the inner channel for the next iteration
# use copy not deepcopy, since deepcopy will needlessly copy the entire pipeline
# this uses more memory than needed and some objects are uncopiable
outputs[upstream_name][
surfaced_output_name] = copy.copy(inner_channel)

inner_channel.name = surfaced_output_name
inner_channel.task_name = upstream_name

processed_oneofs.add(channel)

# handle dsl.Collected returned from pipeline
# TODO: consider migrating dsl.Collected returns to the pattern used by dsl.OneOf, where the OneOf constructor returns a parameter/artifact channel, which fits more cleanly into the existing compiler abstractions
for output_key, channel in pipeline_outputs_dict.items():
if isinstance(channel, for_loop.Collected):
surfaced_output_name = additional_input_name_for_pipeline_channel(
channel)

# the highest-level task group that "consumes" the
# collected output
parent_consumer = downstream_groups[0]
upstream_groups = task_name_to_parent_groups[channel.task_name][1:]
producer_task_name = upstream_groups.pop()

# process from the upstream groups from the inside out
# process upstream groups from the inside out, until getting to the pipeline level
for upstream_name in reversed(upstream_groups):
outputs[upstream_name][
surfaced_output_name] = make_new_channel_for_collected_outputs(
channel_name=channel.name,
starting_channel=channel.output,
task_name=producer_task_name,
)
new_channel = make_new_channel_for_collected_outputs(
channel_name=channel.name,
starting_channel=channel.output,
task_name=producer_task_name,
)

# on each iteration, mutate the channel being consumed so
# that it references the last parent group surfacer
Expand All @@ -511,46 +579,46 @@ def get_outputs_for_all_groups(
# for the next iteration, set the producer to the current
# surfacer (parent group)
producer_task_name = upstream_name

parent_of_current_surfacer = group_name_to_parent_groups[
upstream_name][-2]
if parent_consumer in group_name_to_children[
parent_of_current_surfacer]:
break

# handle dsl.Collected returned from pipeline
for output_key, channel in pipeline_outputs_dict.items():
if isinstance(channel, for_loop.Collected):
surfaced_output_name = additional_input_name_for_pipeline_channel(
channel)
outputs[upstream_name][surfaced_output_name] = new_channel

# after surfacing from all inner TasksGroup, change the PipelineChannel output to also return from the correct TasksGroup
pipeline_outputs_dict[
output_key] = make_new_channel_for_collected_outputs(
channel_name=surfaced_output_name,
starting_channel=channel.output,
task_name=upstream_name,
)
elif isinstance(channel, pipeline_channel.OneOfMixin):
# if the output has already been consumed by a task before it is returned, we don't need to reprocess it
if channel in processed_oneofs:
continue
for inner_channel in channel.channels:
producer_task = pipeline.tasks[inner_channel.task_name]
upstream_groups = task_name_to_parent_groups[
channel.task_name][1:]
producer_task_name = upstream_groups.pop()
# process upstream groups from the inside out, until getting to the pipeline level
for upstream_name in reversed(upstream_groups):
new_channel = make_new_channel_for_collected_outputs(
channel_name=channel.name,
starting_channel=channel.output,
task_name=producer_task_name,
)

# on each iteration, mutate the channel being consumed so
# that it references the last parent group surfacer
channel.name = surfaced_output_name
channel.task_name = upstream_name
inner_channel.task_name][1:]
surfaced_output_name = additional_input_name_for_pipeline_channel(
inner_channel)

# for the next iteration, set the producer to the current
# surfacer (parent group)
producer_task_name = upstream_name
outputs[upstream_name][surfaced_output_name] = new_channel

# after surfacing from all inner TasksGroup, change the PipelineChannel output to also return from the correct TasksGroup
pipeline_outputs_dict[
output_key] = make_new_channel_for_collected_outputs(
channel_name=surfaced_output_name,
starting_channel=channel.output,
task_name=upstream_name,
)
# 1. get the oneof
# 2. find the task group that surfaced it
# 3. find the inner tasks responsible
for upstream_name in reversed(upstream_groups):
# skip the first task processed, since we don't need to add new outputs for the innermost task
if upstream_name == inner_channel.task.name:
continue
# once we've hit the outermost condition-branches group, we're done
if upstream_name == channel.condition_branches_group.name:
outputs[upstream_name][channel.name] = channel
break

# copy so we can update the inner channel for the next iteration
# use copy not deepcopy, since deepcopy will needlessly copy the entire pipeline
# this uses more memory than needed and some objects are uncopiable
outputs[upstream_name][surfaced_output_name] = copy.copy(
inner_channel)

inner_channel.name = surfaced_output_name
inner_channel.task_name = upstream_name
return outputs, pipeline_outputs_dict


Expand Down Expand Up @@ -633,22 +701,17 @@ def get_dependencies(
"""
dependencies = collections.defaultdict(set)
for task in pipeline.tasks.values():
upstream_task_names = set()
upstream_task_names: Set[Union[pipeline_task.PipelineTask,
tasks_group.TasksGroup]] = set()
task_condition_inputs = list(condition_channels[task.name])
for channel in task.channel_inputs + task_condition_inputs:
if channel.task_name:
upstream_task_names.add(channel.task_name)
upstream_task_names |= set(task.dependent_tasks)

for upstream_task_name in upstream_task_names:
# the dependent op could be either a BaseOp or an opsgroup
if upstream_task_name in pipeline.tasks:
upstream_task = pipeline.tasks[upstream_task_name]
elif upstream_task_name in group_name_to_group:
upstream_task = group_name_to_group[upstream_task_name]
else:
raise ValueError(
f'Compiler cannot find task: {upstream_task_name}.')
all_channels = task.channel_inputs + task_condition_inputs
upstream_task_names.update(
{channel.task for channel in all_channels if channel.task})
# dependent_tasks holds the names of tasks on which .after was called; these can only be the names of PipelineTasks, not TasksGroups
upstream_task_names.update(
{pipeline.tasks[after_task] for after_task in task.dependent_tasks})

for upstream_task in upstream_task_names:

upstream_groups, downstream_groups = _get_uncommon_ancestors(
task_name_to_parent_groups=task_name_to_parent_groups,
Expand All @@ -658,7 +721,7 @@ def get_dependencies(
)

# uncommon upstream ancestor check
uncommon_upstream_groups = deepcopy(upstream_groups)
uncommon_upstream_groups = copy.deepcopy(upstream_groups)
uncommon_upstream_groups.remove(
upstream_task.name
) # because a task's `upstream_groups` contains the task's name
Expand Down
Loading

0 comments on commit 35df06d

Please sign in to comment.