diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index 9f4a9310877..e01e8d5e910 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -16,10 +16,9 @@ import os import re import tempfile -import unittest from typing import Any, Dict, List, Optional +import unittest -import yaml from absl.testing import parameterized from click import testing from google.protobuf import json_format @@ -30,6 +29,7 @@ from kfp.components.types import type_utils from kfp.dsl import PipelineTaskFinalStatus from kfp.pipeline_spec import pipeline_spec_pb2 +import yaml VALID_PRODUCER_COMPONENT_SAMPLE = components.load_component_from_text(""" name: producer @@ -977,5 +977,36 @@ def test_compile_pipelines(self, file: str): self._test_compile_py_to_yaml(file) +class TestSetRetryCompilation(unittest.TestCase): + + def test_set_retry(self): + + @dsl.component + def hello_world(text: str) -> str: + """Hello world component.""" + return text + + @dsl.pipeline(name='hello-world', description='A simple intro pipeline') + def pipeline_hello_world(text: str = 'hi there'): + """Hello world pipeline.""" + + hello_world(text=text).set_retry( + num_retries=3, + backoff_duration='30s', + backoff_factor=1.0, + backoff_max_duration='3h', + ) + + with tempfile.TemporaryDirectory() as tempdir: + package_path = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=pipeline_hello_world, package_path=package_path) + pipeline_spec = pipeline_spec_from_file(package_path) + + self.assertEqual( + pipeline_spec.root.dag.tasks['hello-world'].retry_policy + .max_retry_count, 3) + + if __name__ == '__main__': unittest.main()