From 3f7ab0ddc667ec056f08f001b784a55b3daafdae Mon Sep 17 00:00:00 2001 From: ddalvi Date: Mon, 26 Aug 2024 08:41:10 -0400 Subject: [PATCH] Add an environment variable and CLI option to enable or disable default caching Signed-off-by: ddalvi --- sdk/python/kfp/cli/compile_.py | 20 ++++++++++++++++++++ sdk/python/kfp/dsl/pipeline_task.py | 19 +++++++++++-------- sdk/python/kfp/dsl/structures.py | 9 ++++++--- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/sdk/python/kfp/cli/compile_.py b/sdk/python/kfp/cli/compile_.py index 2bd3bab18c23..4f945a9bbe26 100644 --- a/sdk/python/kfp/cli/compile_.py +++ b/sdk/python/kfp/cli/compile_.py @@ -133,14 +133,34 @@ def parse_parameters(parameters: Optional[str]) -> Dict: is_flag=True, default=False, help='Whether to disable type checking.') +@click.option( + '--execution-caching-enabled-by-default', + type=click.Choice(['enabled', 'disabled'], case_sensitive=False), + default=None, + help='Enable task-level caching. Enabled by default, set it to disabled to disable caching.' +) def compile_( py: str, output: str, function_name: Optional[str] = None, pipeline_parameters: Optional[str] = None, disable_type_check: bool = False, + execution_caching_enabled_by_default: Optional[bool] = None, ) -> None: """Compiles a pipeline or component written in a .py file.""" + + env_enable_caching = os.getenv('KFP_EXECUTION_CACHING_ENABLED_BY_DEFAULT', + 'enabled').lower() == 'enabled' + if execution_caching_enabled_by_default is None: + execution_caching_enabled_by_default = env_enable_caching + else: + execution_caching_enabled_by_default = execution_caching_enabled_by_default.lower( + ) == 'enabled' + if execution_caching_enabled_by_default: + os.environ['KFP_EXECUTION_CACHING_ENABLED_BY_DEFAULT'] = 'enabled' + else: + os.environ['KFP_EXECUTION_CACHING_ENABLED_BY_DEFAULT'] = 'disabled' + pipeline_func = collect_pipeline_or_component_func( python_file=py, function_name=function_name) parsed_parameters = parse_parameters(parameters=pipeline_parameters) diff --git a/sdk/python/kfp/dsl/pipeline_task.py b/sdk/python/kfp/dsl/pipeline_task.py index 2e82d23378aa..730e0a2c3810 100644 --- a/sdk/python/kfp/dsl/pipeline_task.py +++ b/sdk/python/kfp/dsl/pipeline_task.py @@ -18,6 +18,7 @@ import functools import inspect import itertools +import os import re from typing import Any, Dict, List, Mapping, Optional, Union import warnings @@ -130,7 +131,8 @@ def __init__( inputs=dict(args.items()), dependent_tasks=[], component_ref=component_spec.name, - enable_caching=True) + enable_caching=os.getenv('KFP_EXECUTION_CACHING_ENABLED_BY_DEFAULT', + 'enabled').lower() == 'enabled') self._run_after: List[str] = [] self.importer_spec = None @@ -161,13 +163,14 @@ def validate_placeholder_types( self.pipeline_spec = self.component_spec.implementation.graph self._outputs = { - output_name: pipeline_channel.create_pipeline_channel( - name=output_name, - channel_type=output_spec.type, - task_name=self._task_spec.name, - is_artifact_list=output_spec.is_artifact_list, - ) for output_name, output_spec in ( - component_spec.outputs or {}).items() + output_name: + pipeline_channel.create_pipeline_channel( + name=output_name, + channel_type=output_spec.type, + task_name=self._task_spec.name, + is_artifact_list=output_spec.is_artifact_list, + ) for output_name, output_spec in ( + component_spec.outputs or {}).items() } self._inputs = args diff --git a/sdk/python/kfp/dsl/structures.py b/sdk/python/kfp/dsl/structures.py index 440f9a3940af..740a04808bb4 100644 --- a/sdk/python/kfp/dsl/structures.py +++ b/sdk/python/kfp/dsl/structures.py @@ -17,6 +17,7 @@ import collections import dataclasses import itertools +import os import re from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import uuid @@ -420,7 +421,8 @@ class TaskSpec: trigger_strategy: Optional[str] = None iterator_items: Optional[Any] = None iterator_item_input: Optional[str] = None - enable_caching: bool = True + enable_caching: bool = os.getenv('KFP_EXECUTION_CACHING_ENABLED_BY_DEFAULT', + 'enabled').lower() == 'enabled' display_name: Optional[str] = None retry_policy: Optional[RetryPolicy] = None @@ -637,8 +639,9 @@ def from_v1_component_spec( ] env = { key: - placeholders.maybe_convert_v1_yaml_placeholder_to_v2_placeholder( - command, component_dict=component_dict) + placeholders + .maybe_convert_v1_yaml_placeholder_to_v2_placeholder( + command, component_dict=component_dict) for key, command in container.get('env', {}).items() } container_spec = ContainerSpecImplementation.from_container_dict({