diff --git a/sdk/RELEASE.md b/sdk/RELEASE.md index 36f39cf6317..8ace4a0fa5c 100644 --- a/sdk/RELEASE.md +++ b/sdk/RELEASE.md @@ -11,6 +11,7 @@ ## Deprecations ## Bug Fixes and Other Changes +* Enable overriding caching options at submission time [\#7912](https://github.com/kubeflow/pipelines/pull/7912) ## Documentation Updates # Current Version (2.0.0-beta.1) @@ -29,7 +30,7 @@ * Include default registry context JSON in package distribution [\#7987](https://github.com/kubeflow/pipelines/pull/7987) ## Documentation Updates -# Current Version (2.0.0-beta.0) +# 2.0.0-beta.0 ## Major Features and Improvements diff --git a/sdk/python/kfp/client/client.py b/sdk/python/kfp/client/client.py index 364794ed991..c30eb7f8f46 100644 --- a/sdk/python/kfp/client/client.py +++ b/sdk/python/kfp/client/client.py @@ -649,7 +649,7 @@ def delete_experiment(self, experiment_id: str) -> dict: """ return self._experiment_api.delete_experiment(id=experiment_id) - def _extract_pipeline_yaml(self, package_file: str) -> str: + def _extract_pipeline_yaml(self, package_file: str) -> dict: def _choose_pipeline_file(file_list: List[str]) -> str: pipeline_files = [ @@ -689,9 +689,19 @@ def _choose_pipeline_file(file_list: List[str]) -> str: f'The package_file {package_file} should end with one of the ' 'following formats: [.tar.gz, .tgz, .zip, .yaml, .yml].') - def _override_caching_options(self, workflow: str, + def _override_caching_options(self, pipeline_obj: dict, enable_caching: bool) -> None: - raise NotImplementedError('enable_caching is not supported yet.') + """Overrides caching options. + + Args: + pipeline_obj (dict): Dict object parsed from the yaml file. + enable_caching (bool): Overrides options, one of 'True', 'False'. + """ + for _, task in pipeline_obj['root']['dag']['tasks'].items(): + if 'cachingOptions' in task: + task['cachingOptions']['enableCache'] = enable_caching + else: + task['cachingOptions'] = {'enableCache': enable_caching} def list_pipelines( self, diff --git a/sdk/python/kfp/client/client_test.py b/sdk/python/kfp/client/client_test.py index 6eb559acab4..0c72c7758d6 100644 --- a/sdk/python/kfp/client/client_test.py +++ b/sdk/python/kfp/client/client_test.py @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import tempfile import unittest +import yaml from absl.testing import parameterized from kfp.client import client +from kfp.compiler import Compiler +from kfp.dsl import component +from kfp.dsl import pipeline class TestValidatePipelineName(parameterized.TestCase): @@ -41,5 +47,68 @@ def test_invalid(self, name: str): client.validate_pipeline_resource_name(name) +class TestOverrideCachingOptions(parameterized.TestCase): + + def test_override_caching_from_pipeline(self): + + @component + def hello_world(text: str) -> str: + """Hello world component.""" + return text + + @pipeline(name='hello-world', description='A simple intro pipeline') + def pipeline_hello_world(text: str = 'hi there'): + """Hello world pipeline.""" + + hello_world(text=text).set_caching_options(True) + + with tempfile.TemporaryDirectory() as tempdir: + temp_filepath = os.path.join(tempdir, 'hello_world_pipeline.yaml') + Compiler().compile( + pipeline_func=pipeline_hello_world, package_path=temp_filepath) + + with open(temp_filepath, 'r') as f: + pipeline_obj = yaml.safe_load(f) + test_client = client.Client(namespace='dummy_namespace') + test_client._override_caching_options(pipeline_obj, False) + for _, task in pipeline_obj['root']['dag']['tasks'].items(): + self.assertFalse(task['cachingOptions']['enableCache']) + + def test_override_caching_of_multiple_components(self): + + @component + def hello_word(text: str) -> str: + return text + + @component + def to_lower(text: str) -> str: + return text.lower() + + @pipeline( + name='sample two-step pipeline', + description='a minimal two-step pipeline') + def pipeline_with_two_component(text: str = 'hi there'): + + component_1 = hello_word(text=text).set_caching_options(True) + component_2 = to_lower( + text=component_1.output).set_caching_options(True) + + with tempfile.TemporaryDirectory() as tempdir: + temp_filepath = os.path.join(tempdir, 'hello_world_pipeline.yaml') + Compiler().compile( + pipeline_func=pipeline_with_two_component, + package_path=temp_filepath) + + with open(temp_filepath, 'r') as f: + pipeline_obj = yaml.safe_load(f) + test_client = client.Client(namespace='dummy_namespace') + test_client._override_caching_options(pipeline_obj, False) + self.assertFalse( + pipeline_obj['root']['dag']['tasks']['hello-word'] + ['cachingOptions']['enableCache']) + self.assertFalse(pipeline_obj['root']['dag']['tasks'] + ['to-lower']['cachingOptions']['enableCache']) + + if __name__ == '__main__': unittest.main()