Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk/client): implements overriding caching options at submission #7912

3 changes: 2 additions & 1 deletion sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
## Deprecations

## Bug Fixes and Other Changes
* Enable overriding caching options at submission time [\#7912](https://github.com/kubeflow/pipelines/pull/7912)

## Documentation Updates
# Current Version (2.0.0-beta.1)
Expand All @@ -29,7 +30,7 @@
* Include default registry context JSON in package distribution [\#7987](https://github.com/kubeflow/pipelines/pull/7987)

## Documentation Updates
# Current Version (2.0.0-beta.0)
# 2.0.0-beta.0

## Major Features and Improvements

Expand Down
16 changes: 13 additions & 3 deletions sdk/python/kfp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def delete_experiment(self, experiment_id: str) -> dict:
"""
return self._experiment_api.delete_experiment(id=experiment_id)

def _extract_pipeline_yaml(self, package_file: str) -> str:
def _extract_pipeline_yaml(self, package_file: str) -> dict:

def _choose_pipeline_file(file_list: List[str]) -> str:
pipeline_files = [
Expand Down Expand Up @@ -689,9 +689,19 @@ def _choose_pipeline_file(file_list: List[str]) -> str:
f'The package_file {package_file} should end with one of the '
'following formats: [.tar.gz, .tgz, .zip, .yaml, .yml].')

def _override_caching_options(self, workflow: str,
def _override_caching_options(self, pipeline_obj: dict,
enable_caching: bool) -> None:
raise NotImplementedError('enable_caching is not supported yet.')
"""Overrides caching options.

Args:
pipeline_obj (dict): Dict object parsed from the yaml file.
enable_caching (bool): Overrides options, one of 'True', 'False'.
"""
for _, task in pipeline_obj['root']['dag']['tasks'].items():
if 'cachingOptions' in task:
task['cachingOptions']['enableCache'] = enable_caching
else:
task['cachingOptions'] = {'enableCache': enable_caching}
zichuan-scott-xu marked this conversation as resolved.
Show resolved Hide resolved

def list_pipelines(
self,
Expand Down
69 changes: 69 additions & 0 deletions sdk/python/kfp/client/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest
import yaml

from absl.testing import parameterized
from kfp.client import client
from kfp.compiler import Compiler
from kfp.dsl import component
from kfp.dsl import pipeline


class TestValidatePipelineName(parameterized.TestCase):
Expand All @@ -41,5 +47,68 @@ def test_invalid(self, name: str):
client.validate_pipeline_resource_name(name)


class TestOverrideCachingOptions(parameterized.TestCase):

def test_override_caching_from_pipeline(self):

@component
def hello_world(text: str) -> str:
"""Hello world component."""
return text

@pipeline(name='hello-world', description='A simple intro pipeline')
def pipeline_hello_world(text: str = 'hi there'):
"""Hello world pipeline."""

hello_world(text=text).set_caching_options(True)

with tempfile.TemporaryDirectory() as tempdir:
temp_filepath = os.path.join(tempdir, 'hello_world_pipeline.yaml')
Compiler().compile(
pipeline_func=pipeline_hello_world, package_path=temp_filepath)

with open(temp_filepath, 'r') as f:
pipeline_obj = yaml.safe_load(f)
test_client = client.Client(namespace='dummy_namespace')
test_client._override_caching_options(pipeline_obj, False)
for _, task in pipeline_obj['root']['dag']['tasks'].items():
self.assertFalse(task['cachingOptions']['enableCache'])

def test_override_caching_of_multiple_components(self):

@component
def hello_word(text: str) -> str:
return text

@component
def to_lower(text: str) -> str:
return text.lower()

@pipeline(
name='sample two-step pipeline',
description='a minimal two-step pipeline')
def pipeline_with_two_component(text: str = 'hi there'):

component_1 = hello_word(text=text).set_caching_options(True)
component_2 = to_lower(
text=component_1.output).set_caching_options(True)

with tempfile.TemporaryDirectory() as tempdir:
temp_filepath = os.path.join(tempdir, 'hello_world_pipeline.yaml')
Compiler().compile(
pipeline_func=pipeline_with_two_component,
package_path=temp_filepath)

with open(temp_filepath, 'r') as f:
pipeline_obj = yaml.safe_load(f)
test_client = client.Client(namespace='dummy_namespace')
test_client._override_caching_options(pipeline_obj, False)
self.assertFalse(
pipeline_obj['root']['dag']['tasks']['hello-word']
['cachingOptions']['enableCache'])
self.assertFalse(pipeline_obj['root']['dag']['tasks']
['to-lower']['cachingOptions']['enableCache'])


if __name__ == '__main__':
unittest.main()