diff --git a/sdk/python/kfp/dsl/tasks_group.py b/sdk/python/kfp/dsl/tasks_group.py index 689a502b6cd..78f51cb4bf9 100644 --- a/sdk/python/kfp/dsl/tasks_group.py +++ b/sdk/python/kfp/dsl/tasks_group.py @@ -16,6 +16,7 @@ import copy import enum from typing import List, Optional, Union +import warnings from kfp.dsl import for_loop from kfp.dsl import pipeline_channel @@ -161,7 +162,7 @@ def __init__( self.condition: List[pipeline_channel.BinaryOperation] = condition -class Condition(_ConditionBase): +class If(_ConditionBase): """A class for creating a conditional control flow "if" block within a pipeline. @@ -173,7 +174,7 @@ class Condition(_ConditionBase): :: task1 = my_component1(...) - with dsl.Condition(task1.output=='pizza', 'pizza-condition'): + with dsl.If(task1.output=='pizza', 'pizza-condition'): task2 = my_component2(...) """ @@ -195,22 +196,20 @@ def __init__( self._negated_upstream_conditions = [copied_condition] -class If(Condition): - """A class for creating a conditional control flow "if" block within a - pipeline. Identical to dsl.Condition. - - Args: - condition: A comparative expression that evaluates to True or False. At least one of the operands must be an output from an upstream task or a pipeline parameter. - name: The name of the condition group. - - Example: - :: +class Condition(If): + """Deprecated. - task1 = my_component1(...) - with dsl.If(task1.output=='pizza', 'pizza-condition'): - task2 = my_component2(...) + Use dsl.If instead. """ + def __enter__(self): + super().__enter__() + warnings.warn( + 'dsl.Condition is deprecated. Please use dsl.If instead.', + category=DeprecationWarning, + stacklevel=2) + return self + class Elif(_ConditionBase): """A class for creating a conditional control flow "else if" block within a diff --git a/sdk/python/kfp/dsl/tasks_group_test.py b/sdk/python/kfp/dsl/tasks_group_test.py index 09ba5cdbc34..40c68ab3725 100644 --- a/sdk/python/kfp/dsl/tasks_group_test.py +++ b/sdk/python/kfp/dsl/tasks_group_test.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from absl.testing import parameterized +import unittest + +from kfp import dsl from kfp.dsl import for_loop from kfp.dsl import pipeline_context from kfp.dsl import tasks_group -class ParallelForTest(parameterized.TestCase): +class ParallelForTest(unittest.TestCase): def test_basic(self): loop_items = ['pizza', 'hotdog', 'pasta'] @@ -58,3 +60,21 @@ def test_parallelfor_invalid_parallelism(self): 'ParallelFor parallelism must be >= 0.'): with pipeline_context.Pipeline('pipeline') as p: tasks_group.ParallelFor(items=loop_items, parallelism=-1) + + +class TestConditionDeprecated(unittest.TestCase): + + def test(self): + + @dsl.component + def foo() -> str: + return 'foo' + + @dsl.pipeline + def my_pipeline(string: str): + with self.assertWarnsRegex( + DeprecationWarning, + 'dsl\.Condition is deprecated\. Please use dsl\.If instead\.' + ): + with dsl.Condition(string == 'text'): + foo()