Skip to content

Commit

Permalink
Add an environment variable and CLI option to enable or disable defau…
Browse files Browse the repository at this point in the history
…lt caching

Signed-off-by: ddalvi <[email protected]>
  • Loading branch information
DharmitD committed Aug 29, 2024
1 parent 36cf066 commit 3f7ab0d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
20 changes: 20 additions & 0 deletions sdk/python/kfp/cli/compile_.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,34 @@ def parse_parameters(parameters: Optional[str]) -> Dict:
is_flag=True,
default=False,
help='Whether to disable type checking.')
@click.option(
'--execution-caching-enabled-by-default',
type=click.Choice(['enabled', 'disabled'], case_sensitive=False),
default=None,
help='Enable task-level caching. Enabled by default, set it to disabled to disable caching.'
)
def compile_(
py: str,
output: str,
function_name: Optional[str] = None,
pipeline_parameters: Optional[str] = None,
disable_type_check: bool = False,
execution_caching_enabled_by_default: Optional[bool] = None,
) -> None:
"""Compiles a pipeline or component written in a .py file."""

env_enable_caching = os.getenv('KFP_EXECUTION_CACHING_ENABLED_BY_DEFAULT',
'enabled').lower() == 'enabled'
if execution_caching_enabled_by_default is None:
execution_caching_enabled_by_default = env_enable_caching
else:
execution_caching_enabled_by_default = execution_caching_enabled_by_default.lower(
) == 'enabled'
if execution_caching_enabled_by_default:
os.environ['KFP_EXECUTION_CACHING_ENABLED_BY_DEFAULT'] = 'enabled'
else:
os.environ['KFP_EXECUTION_CACHING_ENABLED_BY_DEFAULT'] = 'disabled'

pipeline_func = collect_pipeline_or_component_func(
python_file=py, function_name=function_name)
parsed_parameters = parse_parameters(parameters=pipeline_parameters)
Expand Down
19 changes: 11 additions & 8 deletions sdk/python/kfp/dsl/pipeline_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import functools
import inspect
import itertools
import os
import re
from typing import Any, Dict, List, Mapping, Optional, Union
import warnings
Expand Down Expand Up @@ -130,7 +131,8 @@ def __init__(
inputs=dict(args.items()),
dependent_tasks=[],
component_ref=component_spec.name,
enable_caching=True)
enable_caching=os.getenv('KFP_EXECUTION_CACHING_ENABLED_BY_DEFAULT',
'enabled').lower() == 'enabled')
self._run_after: List[str] = []

self.importer_spec = None
Expand Down Expand Up @@ -161,13 +163,14 @@ def validate_placeholder_types(
self.pipeline_spec = self.component_spec.implementation.graph

self._outputs = {
output_name: pipeline_channel.create_pipeline_channel(
name=output_name,
channel_type=output_spec.type,
task_name=self._task_spec.name,
is_artifact_list=output_spec.is_artifact_list,
) for output_name, output_spec in (
component_spec.outputs or {}).items()
output_name:
pipeline_channel.create_pipeline_channel(
name=output_name,
channel_type=output_spec.type,
task_name=self._task_spec.name,
is_artifact_list=output_spec.is_artifact_list,
) for output_name, output_spec in (
component_spec.outputs or {}).items()
}

self._inputs = args
Expand Down
9 changes: 6 additions & 3 deletions sdk/python/kfp/dsl/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import collections
import dataclasses
import itertools
import os
import re
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import uuid
Expand Down Expand Up @@ -420,7 +421,8 @@ class TaskSpec:
trigger_strategy: Optional[str] = None
iterator_items: Optional[Any] = None
iterator_item_input: Optional[str] = None
enable_caching: bool = True
enable_caching: bool = os.getenv('KFP_EXECUTION_CACHING_ENABLED_BY_DEFAULT',
'enabled').lower() == 'enabled'
display_name: Optional[str] = None
retry_policy: Optional[RetryPolicy] = None

Expand Down Expand Up @@ -637,8 +639,9 @@ def from_v1_component_spec(
]
env = {
key:
placeholders.maybe_convert_v1_yaml_placeholder_to_v2_placeholder(
command, component_dict=component_dict)
placeholders
.maybe_convert_v1_yaml_placeholder_to_v2_placeholder(
command, component_dict=component_dict)
for key, command in container.get('env', {}).items()
}
container_spec = ContainerSpecImplementation.from_container_dict({
Expand Down

0 comments on commit 3f7ab0d

Please sign in to comment.