Skip to content

Commit

Permalink
address review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
connor-mccarthy committed Sep 8, 2023
1 parent cdd8e11 commit e750768
Show file tree
Hide file tree
Showing 8 changed files with 595 additions and 176 deletions.
6 changes: 3 additions & 3 deletions sdk/python/kfp/compiler/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def get_channels_from_condition(
operations: List[pipeline_channel.BinaryOperation],
collected_channels: list,
) -> None:
"""Append to collected_channels each pipeline channels used in each operand
of each operation in operations."""
"""Appends to collected_channels each pipeline channels used in each
operand of each operation in operations."""
for operation in operations:
for operand in [operation.left_operand, operation.right_operand]:
if isinstance(operand, pipeline_channel.PipelineChannel):
Expand Down Expand Up @@ -153,7 +153,7 @@ def _get_condition_channels_for_tasks_helper(
if isinstance(group, tasks_group._ConditionBase):
new_current_conditions_channels = list(current_conditions_channels)
get_channels_from_condition(
group.condition,
group.conditions,
new_current_conditions_channels,
)

Expand Down
4 changes: 2 additions & 2 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ def _update_task_spec_for_condition_group(
group: The condition group to update task spec for.
pipeline_task_spec: The pipeline task spec to update in place.
"""
condition = _binary_operations_to_cel_conjunctive(group.condition)
condition = _binary_operations_to_cel_conjunctive(group.conditions)
pipeline_task_spec.trigger_policy.CopyFrom(
pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy(condition=condition))

Expand Down Expand Up @@ -1260,7 +1260,7 @@ def build_spec_by_group(
condition_subgroup_channels = list(subgroup_input_channels)

compiler_utils.get_channels_from_condition(
subgroup.condition, condition_subgroup_channels)
subgroup.conditions, condition_subgroup_channels)

subgroup_component_spec = build_component_spec_for_group(
input_pipeline_channels=condition_subgroup_channels,
Expand Down
4 changes: 2 additions & 2 deletions sdk/python/kfp/dsl/pipeline_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Definition for Pipeline."""

import functools
from typing import Callable, Optional, Union
from typing import Callable, Optional

from kfp.dsl import component_factory
from kfp.dsl import pipeline_task
Expand Down Expand Up @@ -189,7 +189,7 @@ def pop_tasks_group(self):
"""Removes the current TasksGroup from the stack."""
del self.groups[-1]

def get_last_tasks_group(self) -> Union['tasks_group.TasksGroup', None]:
def get_last_tasks_group(self) -> Optional['tasks_group.TasksGroup']:
"""Gets the last TasksGroup added to the pipeline at the current level
of the pipeline definition."""
groups = self.groups[-1].groups
Expand Down
12 changes: 6 additions & 6 deletions sdk/python/kfp/dsl/tasks_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,15 @@ class _ConditionBase(TasksGroup):

def __init__(
self,
condition: List[pipeline_channel.BinaryOperation],
conditions: List[pipeline_channel.BinaryOperation],
name: Optional[str] = None,
) -> None:
super().__init__(
group_type=TasksGroupType.CONDITION,
name=name,
is_root=False,
)
self.condition: List[pipeline_channel.BinaryOperation] = condition
self.conditions: List[pipeline_channel.BinaryOperation] = conditions


class If(_ConditionBase):
Expand All @@ -184,7 +184,7 @@ def __init__(
name: Optional[str] = None,
) -> None:
super().__init__(
condition=[condition],
conditions=[condition],
name=name,
)
if isinstance(condition, bool):
Expand Down Expand Up @@ -241,7 +241,7 @@ def __init__(
if not isinstance(prev_cond, (Condition, If, Elif)):
# prefer pushing toward dsl.If rather than dsl.Condition for syntactic consistency with the if-elif-else keywords in Python
raise InvalidControlFlowException(
'dsl.Else can only be used following an upstream dsl.If or dsl.Elif.'
'dsl.Elif can only be used following an upstream dsl.If or dsl.Elif.'
)

if isinstance(condition, bool):
Expand All @@ -259,7 +259,7 @@ def __init__(
conditions.append(condition)

super().__init__(
condition=conditions,
conditions=conditions,
name=name,
)

Expand Down Expand Up @@ -305,7 +305,7 @@ def __init__(
)

super().__init__(
condition=prev_cond._negated_upstream_conditions,
conditions=prev_cond._negated_upstream_conditions,
name=name,
)

Expand Down
4 changes: 2 additions & 2 deletions sdk/python/test_data/pipelines/if_elif_else.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@dsl.component
def flip_coin() -> str:
def flip_three_sided_die() -> str:
import random
val = random.randint(0, 2)

Expand All @@ -36,7 +36,7 @@ def print_and_return(text: str) -> str:

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
flip_coin_task = flip_three_sided_die()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')
with dsl.Elif(flip_coin_task.output == 'tails'):
Expand Down
64 changes: 45 additions & 19 deletions sdk/python/test_data/pipelines/if_elif_else_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

from kfp import compiler
from kfp import dsl


@dsl.component
def int_zero_through_three() -> int:
def int_0_to_9999() -> int:
import random
return random.randint(0, 3)
return random.randint(0, 9999)


@dsl.component
def flip_coin() -> str:
import random
return 'heads' if random.randint(0, 1) == 0 else 'tails'
def is_even_or_odd(num: int) -> str:
return 'odd' if num % 2 else 'even'


@dsl.component
Expand All @@ -33,28 +34,53 @@ def print_and_return(text: str) -> str:
return text


@dsl.component
def print_strings(strings: List[str]):
print(strings)


@dsl.pipeline
def flip_coin_pipeline(confirm: bool):
int_task = int_zero_through_three()
flip_coin_task = flip_coin()
def lucky_number_pipeline(add_drumroll: bool = True,
repeat_if_lucky_number: bool = True,
trials: List[int] = [1, 2, 3]):
with dsl.ParallelFor(trials) as trial:
int_task = int_0_to_9999()
with dsl.If(add_drumroll == True):
with dsl.If(trial == 3):
print_and_return(text='Adding drumroll on last trial!')

with dsl.If(flip_coin_task.output == 'heads'):
with dsl.If(int_task.output == 0):
print_and_return(text='Got zero!')
with dsl.If(int_task.output < 5000):

with dsl.Elif(int_task.output == 1):
task = print_and_return(text='Got one!')
with dsl.If(confirm == True):
print_and_return(text='Confirmed: definitely got one.')
even_or_odd_task = is_even_or_odd(num=int_task.output)

with dsl.Elif(int_task.output == 2):
print_and_return(text='Got two!')
with dsl.If(even_or_odd_task.output == 'even'):
print_and_return(text='Got a low even number!')
with dsl.Else():
print_and_return(text='Got a low odd number!')

with dsl.Elif(int_task.output > 5000):

even_or_odd_task = is_even_or_odd(num=int_task.output)

with dsl.If(even_or_odd_task.output == 'even'):
print_and_return(text='Got a high even number!')
with dsl.Else():
print_and_return(text='Got a high odd number!')

with dsl.Else():
print_and_return(text='Got three!')
print_and_return(
text='Announcing: Got the lucky number 5000! A one in 10,000 chance.'
)
with dsl.If(repeat_if_lucky_number == True):
with dsl.ParallelFor([1, 2]) as _:
print_and_return(
text='Announcing again: Got the lucky number 5000! A one in 10,000 chance.'
)

print_strings(strings=dsl.Collected(even_or_odd_task.output))


if __name__ == '__main__':
compiler.Compiler().compile(
pipeline_func=flip_coin_pipeline,
pipeline_func=lucky_number_pipeline,
package_path=__file__.replace('.py', '.yaml'))
Loading

0 comments on commit e750768

Please sign in to comment.