feat(sdk): support collecting outputs from conditional branches using dsl.OneOf #10067

Merged
merged 3 commits on Oct 18, 2023

Member

@connor-mccarthy connor-mccarthy commented Oct 6, 2023

Description of your changes:
Supports collecting outputs from conditional branches in a pipeline using dsl.OneOf. For example:

from kfp import dsl

@dsl.component
def flip_coin() -> str:
    import random
    return 'heads' if random.randint(0, 1) == 0 else 'tails'

@dsl.component
def print_and_return(text: str) -> str:
    print(text)
    return text

@dsl.pipeline
def flip_coin_pipeline() -> str:
    flip_coin_task = flip_coin()
    with dsl.If(flip_coin_task.output == 'heads'):
        print_task_1 = print_and_return(text='Got heads!')
    with dsl.Else():
        print_task_2 = print_and_return(text='Got tails!')

    # use the output from the branch that gets executed
    oneof = dsl.OneOf(print_task_1.output, print_task_2.output)

    # consume it
    print_and_return(text=oneof)

    # return it
    return oneof

There are three future work items to consider following this PR:

  1. Migrate the dsl.Collected implementation to the dsl.OneOf-style implementation. The dsl.OneOf abstraction is much cleaner and permits more implementation alignment and code reuse between dsl.OneOf and dsl.Collected. I left TODOs for this and opted not to implement it in this CL, to keep the already large diff simpler.
  2. After 1, we can consider (but may choose not to pursue) supporting the composition of dsl.Collected into dsl.OneOf. This composability considerably increases the complexity of both the compiler code and authorable pipelines (high implementation and maintenance cost). We should weigh this against the suspected benefit. I'm leaving this note here to document the choice and the likely dependency on 1.
  3. We can consider supporting dsl.OneOf and dsl.Collected in f-strings. dsl.Collected support in f-strings follows naturally from 1 above. dsl.OneOf support in f-strings requires a new way of injecting a channel into a user string in the compiler code such that none of the member channels of a dsl.OneOf are lost. The approach I suspect will be most successful (included in a code comment) is to maintain a pipeline-level map of unique key to pipeline channel, inject that key into the user string, then extract it and look up the channel in the map. The level-of-effort consideration is similar to 2 above.
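
The key-to-channel map described in item 3 could look roughly like the sketch below. This is only an illustration of the idea, not the code comment referenced above: `Channel`, `ChannelRegistry`, `inject`, `extract`, and the `{{channel:<key>}}` placeholder format are all hypothetical names chosen for this example.

```python
import re
import uuid
from typing import Dict, List


class Channel:
    """Hypothetical stand-in for a pipeline channel (not the KFP class)."""

    def __init__(self, name: str) -> None:
        self.name = name


class ChannelRegistry:
    """Hypothetical pipeline-level map from unique key to channel."""

    # Matches the placeholder format produced by inject().
    _PATTERN = re.compile(r'\{\{channel:([0-9a-f]{32})\}\}')

    def __init__(self) -> None:
        self._channels: Dict[str, Channel] = {}

    def inject(self, channel: Channel) -> str:
        """Registers the channel and returns a placeholder for user strings."""
        key = uuid.uuid4().hex
        self._channels[key] = channel
        return f'{{{{channel:{key}}}}}'

    def extract(self, text: str) -> List[Channel]:
        """Looks up every channel whose placeholder appears in the text."""
        return [self._channels[key] for key in self._PATTERN.findall(text)]


registry = ChannelRegistry()
oneof = Channel('oneof-1')
user_string = f'Result: {registry.inject(oneof)}'
recovered = registry.extract(user_string)
```

Because the string only carries an opaque key, the object stored in the map can be arbitrarily rich (for example, a OneOf with all of its member channels), which is the property item 3 asks for.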

@connor-mccarthy
Member Author

/assign @chensun

sdk/python/kfp/compiler/compiler_test.py: three resolved review threads (two outdated)
Member

@chensun chensun left a comment

/lgtm

Only a few nitpicks and questions, nothing blocking.
Great work!

    if not isinstance(channel, expected_type):
        raise TypeError(
            f'Task outputs passed to dsl.{OneOf.__name__} must be the same type. Got two channels with different types: {readable_name_map[expected_type]} at index 0 and {readable_name_map[type(channel)]} at index {i}.'
        )
Member

nit: this function reads a bit complex, while the essential check is that all channels are of the same type. Is the complexity there to allow a specific error message? Not sure if users would understand what the index # refers to.

Member Author

I simplified the logic a bit and the error message. Please feel free to re-open if there are more changes you suggest.
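
For illustration only, a same-type check without index bookkeeping might be sketched as follows. This is not the simplified code that landed in the PR (it is not shown in this thread); `ParameterChannel`, `ArtifactChannel`, and `validate_same_type` are hypothetical names, and the sketch compares concrete types rather than using `isinstance`.

```python
from typing import Sequence


class ParameterChannel:
    """Hypothetical stand-in for a parameter output channel."""


class ArtifactChannel:
    """Hypothetical stand-in for an artifact output channel."""


def validate_same_type(channels: Sequence[object]) -> None:
    """Raises TypeError unless every channel has the same concrete type."""
    type_names = {type(channel).__name__ for channel in channels}
    if len(type_names) > 1:
        raise TypeError(
            'Task outputs passed to dsl.OneOf must be the same type. '
            f'Got: {", ".join(sorted(type_names))}.')


# Passes silently: both channels share a type.
validate_same_type([ParameterChannel(), ParameterChannel()])
```

Reporting the set of offending type names avoids the question of what an index refers to, at the cost of not pointing to the exact argument position.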

) -> None:
    # avoid circular imports
    from kfp.dsl import tasks_group
    if not isinstance(parent_group.groups[-1], tasks_group.Else):
Member

Does this mean users have to have an ending dsl.Else?
So the following is illegal?

with dsl.If(a == True):
    f = foo()
with dsl.Elif(a == False):
    b = bar()
oneof = dsl.OneOf(f.output, b.output)

Member Author

Yes, exactly.

For posterity, I'll leave a few notes to document my thinking:

  • This is aligned with the current treatment of value_from_oneof field in PipelineSpec on Vertex Pipelines, which is a backend that supports such collection.
  • If we want different behavior... We can, of course, prescribe new behavior for any backend and drop this validation in the KFP SDK. Since the current implementation is forward compatible with looser validation, my preference is to proceed with the strong validation and listen for user demands to loosen. This helps us get this feature in the hands of users faster with effectively no added cost for users or maintainers.
  • If we like this behavior, but want to permit more flexibility... It's technically possible that a user might want to use only dsl.If and dsl.Elif and ensure the conditions are mutually exclusive themselves (no dsl.Else). Since this comes with the cost of shifting errors right (to runtime, instead of compilation time as in the current implementation) and because this seems like it wouldn't be the typical case, my preference is again to start conservatively with the current approach.
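
The compile-time rule discussed in these notes can be sketched in isolation as below. This is a simplified illustration under stated assumptions, not the SDK's implementation: the `Group`, `If`, `Elif`, and `Else` classes are hypothetical stand-ins, `validate_oneof_branches` is an invented name, and the choice of `ValueError` is an assumption.

```python
from typing import List


class Group:
    """Hypothetical stand-in for a conditional task group."""


class If(Group):
    pass


class Elif(Group):
    pass


class Else(Group):
    pass


def validate_oneof_branches(sibling_groups: List[Group]) -> None:
    """Rejects a conditional chain that does not end in an Else group."""
    if not sibling_groups or not isinstance(sibling_groups[-1], Else):
        raise ValueError(
            'dsl.OneOf requires the conditional chain to end with dsl.Else '
            'so that exactly one branch is guaranteed to run.')


# An If/Elif/Else chain is accepted; an If/Elif chain would raise.
validate_oneof_branches([If(), Elif(), Else()])
```

Checking only the last sibling keeps the error at compilation time, which is the trade-off the notes above argue for: an If/Elif chain with mutually exclusive conditions would be rejected even though it could be valid at runtime.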

# the combination of OneOf and an f-string is not common, so prefer
# deferring implementation
raise NotImplementedError(
    f'dsl.{OneOf.__name__} is not yet supported in f-strings.')
Member

nit: Technically f-string is just one special syntax for string concatenation, while this limitation would apply to all formats, for example, "{} world".format(str1), "Hello" + str1, etc.

Member Author

Nice catch. Updated.
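
To illustrate the point that the limitation is not specific to f-strings, the sketch below uses a hypothetical `OneOfPlaceholder` class (not a KFP class) whose `__str__` raises. In plain Python, f-strings, `str.format`, and `%s` formatting all end up calling `__str__` on an object that does not override `__format__`, so one guard covers all three.

```python
from typing import Callable


class OneOfPlaceholder:
    """Hypothetical channel-like object that refuses string conversion."""

    def __str__(self) -> str:
        raise NotImplementedError(
            'dsl.OneOf is not yet supported in string formatting.')


def raises_not_implemented(func: Callable[[], str]) -> bool:
    """Returns True if calling func raises NotImplementedError."""
    try:
        func()
    except NotImplementedError:
        return True
    return False


oneof = OneOfPlaceholder()

via_fstring = raises_not_implemented(lambda: f'{oneof} world')
via_format = raises_not_implemented(lambda: '{} world'.format(oneof))
via_percent = raises_not_implemented(lambda: '%s world' % oneof)
```

Concatenation with `+` is a separate path: `'Hello' + obj` only reaches the object if it defines `__radd__`, and otherwise fails with a `TypeError` raised by `str` itself.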

# isinstance(<some channel>, PipelineParameterChannel/PipelineArtifactChannel)
# checks continue to behave in desirable ways
class OneOfParameter(PipelineParameterChannel, OneOfMixin):
    """OneOf that results in an parameter channel for all downstream tasks."""
Member

nit: a parameter channel
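
The code comment in this hunk relies on multiple inheritance to keep existing `isinstance` checks working. A minimal sketch of that pattern, using empty stand-in classes rather than the real KFP ones, is shown below; `OneOfArtifact` is assumed here by analogy and does not appear in this thread.

```python
class PipelineParameterChannel:
    """Stand-in for the parameter channel base class (body omitted)."""


class PipelineArtifactChannel:
    """Stand-in for the artifact channel base class (body omitted)."""


class OneOfMixin:
    """Stand-in marker for behavior shared by all OneOf channels."""


class OneOfParameter(PipelineParameterChannel, OneOfMixin):
    """OneOf that results in a parameter channel for downstream tasks."""


class OneOfArtifact(PipelineArtifactChannel, OneOfMixin):
    """Assumed artifact counterpart, included only for illustration."""


param_oneof = OneOfParameter()
artifact_oneof = OneOfArtifact()

# Existing channel-type checks keep working on OneOf instances.
is_parameter = isinstance(param_oneof, PipelineParameterChannel)
is_artifact = isinstance(artifact_oneof, PipelineArtifactChannel)
# Code that needs OneOf-specific handling can test for the mixin.
both_are_oneof = all(
    isinstance(ch, OneOfMixin) for ch in (param_oneof, artifact_oneof))
```

With this layout, code written before OneOf existed can keep branching on parameter versus artifact channels, while newer code can detect a OneOf through the mixin.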

@@ -150,6 +153,16 @@ def __init__(self) -> None:
            is_root=False,
        )

    def _get_oneof_id(self) -> int:
Member

nit: this seems to be a public method, so maybe no underscore prefix?

Member Author

Yes, makes sense.

@google-oss-prow google-oss-prow bot added the lgtm label Oct 17, 2023
@google-oss-prow google-oss-prow bot removed the lgtm label Oct 18, 2023
Member Author

@connor-mccarthy connor-mccarthy left a comment

Thanks for the review, @chensun. Addressed your comments.

Member

@chensun chensun left a comment

/lgtm
/approve

Thanks!

@google-oss-prow google-oss-prow bot added the lgtm label Oct 18, 2023
@google-oss-prow

[APPROVALNOTIFIER] This PR is APPROVED

This pull-request has been approved by: chensun

@google-oss-prow google-oss-prow bot merged commit 2d3171c into kubeflow:master Oct 18, 2023
1 check passed
connor-mccarthy added a commit to connor-mccarthy/pipelines that referenced this pull request Oct 19, 2023
stijntratsaertit pushed a commit to stijntratsaertit/kfp that referenced this pull request Feb 16, 2024
… `dsl.OneOf` (kubeflow#10067)

* support dsl.OneOf

* address review feedback

* address review feedback