Skip to content

Commit

Permalink
deprecate dsl.Condition
Browse files Browse the repository at this point in the history
  • Loading branch information
connor-mccarthy committed Aug 31, 2023
1 parent 047f141 commit 572915d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 17 deletions.
29 changes: 14 additions & 15 deletions sdk/python/kfp/dsl/tasks_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import copy
import enum
from typing import List, Optional, Union
import warnings

from kfp.dsl import for_loop
from kfp.dsl import pipeline_channel
Expand Down Expand Up @@ -161,7 +162,7 @@ def __init__(
self.condition: List[pipeline_channel.BinaryOperation] = condition


class Condition(_ConditionBase):
class If(_ConditionBase):
"""A class for creating a conditional control flow "if" block within a
pipeline.
Expand All @@ -173,7 +174,7 @@ class Condition(_ConditionBase):
::
task1 = my_component1(...)
with dsl.Condition(task1.output=='pizza', 'pizza-condition'):
with dsl.If(task1.output=='pizza', 'pizza-condition'):
task2 = my_component2(...)
"""

Expand All @@ -195,22 +196,20 @@ def __init__(
self._negated_upstream_conditions = [copied_condition]


class If(Condition):
"""A class for creating a conditional control flow "if" block within a
pipeline. Identical to dsl.Condition.
Args:
condition: A comparative expression that evaluates to True or False. At least one of the operands must be an output from an upstream task or a pipeline parameter.
name: The name of the condition group.
Example:
::
class Condition(If):
"""Deprecated.
task1 = my_component1(...)
with dsl.If(task1.output=='pizza', 'pizza-condition'):
task2 = my_component2(...)
Use dsl.If instead.
"""

def __enter__(self):
super().__enter__()
warnings.warn(
'dsl.Condition is deprecated. Please use dsl.If instead.',
category=DeprecationWarning,
stacklevel=2)
return self


class Elif(_ConditionBase):
"""A class for creating a conditional control flow "else if" block within a
Expand Down
24 changes: 22 additions & 2 deletions sdk/python/kfp/dsl/tasks_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized
import unittest

from kfp import dsl
from kfp.dsl import for_loop
from kfp.dsl import pipeline_context
from kfp.dsl import tasks_group


class ParallelForTest(parameterized.TestCase):
class ParallelForTest(unittest.TestCase):

def test_basic(self):
loop_items = ['pizza', 'hotdog', 'pasta']
Expand Down Expand Up @@ -58,3 +60,21 @@ def test_parallelfor_invalid_parallelism(self):
'ParallelFor parallelism must be >= 0.'):
with pipeline_context.Pipeline('pipeline') as p:
tasks_group.ParallelFor(items=loop_items, parallelism=-1)


class TestConditionDeprecated(unittest.TestCase):

def test(self):

@dsl.component
def foo() -> str:
return 'foo'

@dsl.pipeline
def my_pipeline(string: str):
with self.assertWarnsRegex(
DeprecationWarning,
'dsl\.Condition is deprecated\. Please use dsl\.If instead\.'
):
with dsl.Condition(string == 'text'):
foo()

0 comments on commit 572915d

Please sign in to comment.