diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 2d7938c67c..8290438418 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -85,6 +85,7 @@ class TaskMetadata(object): cache (bool): Indicates if caching should be enabled. See :std:ref:`Caching ` cache_serialize (bool): Indicates if identical (ie. same inputs) instances of this task should be executed in serial when caching is enabled. See :std:ref:`Caching ` cache_version (str): Version to be used for the cached value + cache_ignore_input_vars (Tuple[str, ...]): Input variables that should not be included when calculating hash for cache interruptible (Optional[bool]): Indicates that this task can be interrupted and/or scheduled on nodes with lower QoS guarantees that can include pre-emption. This can reduce the monetary cost executions incur at the cost of performance penalties due to potential interruptions @@ -100,6 +101,7 @@ class TaskMetadata(object): cache: bool = False cache_serialize: bool = False cache_version: str = "" + cache_ignore_input_vars: Tuple[str, ...] = () interruptible: Optional[bool] = None deprecated: str = "" retries: int = 0 @@ -116,6 +118,10 @@ def __post_init__(self): raise ValueError("Caching is enabled ``cache=True`` but ``cache_version`` is not set.") if self.cache_serialize and not self.cache: raise ValueError("Cache serialize is enabled ``cache_serialize=True`` but ``cache`` is not enabled.") + if self.cache_ignore_input_vars and not self.cache: + raise ValueError( + f"Cache ignore input vars are specified ``cache_ignore_input_vars={self.cache_ignore_input_vars}`` but ``cache`` is not enabled." + ) @property def retry_strategy(self) -> _literal_models.RetryStrategy: @@ -139,6 +145,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: deprecated_error_message=self.deprecated, cache_serializable=self.cache_serialize, pod_template_name=self.pod_template_name, + cache_ignore_input_vars=self.cache_ignore_input_vars, ) @@ -268,18 +275,26 @@ def local_execute( # TODO: how to get a nice `native_inputs` here? logger.info( f"Checking cache for task named {self.name}, cache version {self.metadata.cache_version} " - f"and inputs: {input_literal_map}" + f", inputs: {input_literal_map}, and ignore input vars: {self.metadata.cache_ignore_input_vars}" + ) + outputs_literal_map = LocalTaskCache.get( + self.name, self.metadata.cache_version, input_literal_map, self.metadata.cache_ignore_input_vars ) - outputs_literal_map = LocalTaskCache.get(self.name, self.metadata.cache_version, input_literal_map) # The cache returns None iff the key does not exist in the cache if outputs_literal_map is None: logger.info("Cache miss, task will be executed now") outputs_literal_map = self.sandbox_execute(ctx, input_literal_map) # TODO: need `native_inputs` - LocalTaskCache.set(self.name, self.metadata.cache_version, input_literal_map, outputs_literal_map) + LocalTaskCache.set( + self.name, + self.metadata.cache_version, + input_literal_map, + self.metadata.cache_ignore_input_vars, + outputs_literal_map, + ) logger.info( f"Cache set for task named {self.name}, cache version {self.metadata.cache_version} " - f"and inputs: {input_literal_map}" + f", inputs: {input_literal_map}, and ignore input vars: {self.metadata.cache_ignore_input_vars}" ) else: logger.info("Cache hit") diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py index 1e8363396e..a52e6708c7 100644 --- a/flytekit/core/local_cache.py +++ b/flytekit/core/local_cache.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple from diskcache import Cache @@ -28,12 +28,15 @@ def _recursive_hash_placement(literal: Literal) -> Literal: return literal -def _calculate_cache_key(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> str: +def _calculate_cache_key( + task_name: str, cache_version: str, input_literal_map: LiteralMap, cache_ignore_input_vars: Tuple[str, ...] = () +) -> str: # Traverse the literals and replace the literal with a new literal that only contains the hash literal_map_overridden = {} for key, literal in input_literal_map.literals.items(): + if key in cache_ignore_input_vars: + continue literal_map_overridden[key] = _recursive_hash_placement(literal) - # Generate a stable representation of the underlying protobuf by passing `deterministic=True` to the # protobuf library. hashed_inputs = LiteralMap(literal_map_overridden).to_flyte_idl().SerializeToString(deterministic=True) @@ -61,13 +64,25 @@ def clear(): LocalTaskCache._cache.clear() @staticmethod - def get(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> Optional[LiteralMap]: + def get( + task_name: str, cache_version: str, input_literal_map: LiteralMap, cache_ignore_input_vars: Tuple[str, ...] + ) -> Optional[LiteralMap]: if not LocalTaskCache._initialized: LocalTaskCache.initialize() - return LocalTaskCache._cache.get(_calculate_cache_key(task_name, cache_version, input_literal_map)) + return LocalTaskCache._cache.get( + _calculate_cache_key(task_name, cache_version, input_literal_map, cache_ignore_input_vars) + ) @staticmethod - def set(task_name: str, cache_version: str, input_literal_map: LiteralMap, value: LiteralMap) -> None: + def set( + task_name: str, + cache_version: str, + input_literal_map: LiteralMap, + cache_ignore_input_vars: Tuple[str, ...], + value: LiteralMap, + ) -> None: if not LocalTaskCache._initialized: LocalTaskCache.initialize() - LocalTaskCache._cache.add(_calculate_cache_key(task_name, cache_version, input_literal_map), value) + LocalTaskCache._cache.add( + _calculate_cache_key(task_name, cache_version, input_literal_map, cache_ignore_input_vars), value + ) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 547abd41fa..39389bfdea 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -1,6 +1,6 @@ import datetime as _datetime from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, overload from flytekit.core.base_task import TaskMetadata, TaskResolverMixin from flytekit.core.interface import transform_function_to_interface @@ -87,6 +87,7 @@ def task( cache: bool = ..., cache_serialize: bool = ..., cache_version: str = ..., + cache_ignore_input_vars: Tuple[str, ...] = ..., retries: int = ..., interruptible: Optional[bool] = ..., deprecated: str = ..., @@ -115,6 +116,7 @@ def task( cache: bool = ..., cache_serialize: bool = ..., cache_version: str = ..., + cache_ignore_input_vars: Tuple[str, ...] = ..., retries: int = ..., interruptible: Optional[bool] = ..., deprecated: str = ..., @@ -142,6 +144,7 @@ def task( cache: bool = False, cache_serialize: bool = False, cache_version: str = "", + cache_ignore_input_vars: Tuple[str, ...] = (), retries: int = 0, interruptible: Optional[bool] = None, deprecated: str = "", @@ -200,6 +203,7 @@ def my_task(x: int, y: typing.Dict[str, str]) -> str: :param cache_version: Cache version to use. Changes to the task signature will automatically trigger a cache miss, but you can always manually update this field as well to force a cache miss. You should also manually bump this version if the function body/business logic has changed, but the signature hasn't. + :param cache_ignore_input_vars: Input variables that should not be included when calculating hash for cache. :param retries: Number of times to retry this task during a workflow execution. :param interruptible: [Optional] Boolean that indicates that this task can be interrupted and/or scheduled on nodes with lower QoS guarantees. This will directly reduce the `$`/`execution cost` associated, @@ -260,6 +264,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: cache=cache, cache_serialize=cache_serialize, cache_version=cache_version, + cache_ignore_input_vars=cache_ignore_input_vars, retries=retries, interruptible=interruptible, deprecated=deprecated, diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 48a8abfde1..8ad74ede23 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -178,6 +178,7 @@ def __init__( deprecated_error_message, cache_serializable, pod_template_name, + cache_ignore_input_vars, ): """ Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, @@ -198,6 +199,7 @@ def __init__( :param bool cache_serializable: Whether or not caching operations are executed in serial. This means only a single instance over identical inputs is executed, other concurrent executions wait for the cached results. :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. + :param cache_ignore_input_vars: Input variables that should not be included when calculating hash for cache. """ self._discoverable = discoverable self._runtime = runtime @@ -208,6 +210,7 @@ def __init__( self._deprecated_error_message = deprecated_error_message self._cache_serializable = cache_serializable self._pod_template_name = pod_template_name + self._cache_ignore_input_vars = cache_ignore_input_vars @property def discoverable(self): @@ -285,6 +288,14 @@ def pod_template_name(self): """ return self._pod_template_name + @property + def cache_ignore_input_vars(self): + """ + Input variables that should not be included when calculating hash for cache. + :rtype: tuple[Text] + """ + return self._cache_ignore_input_vars + def to_flyte_idl(self): """ :rtype: flyteidl.admin.task_pb2.TaskMetadata @@ -298,6 +309,7 @@ def to_flyte_idl(self): deprecated_error_message=self.deprecated_error_message, cache_serializable=self.cache_serializable, pod_template_name=self.pod_template_name, + cache_ignore_input_vars=self.cache_ignore_input_vars, ) if self.timeout: tm.timeout.FromTimedelta(self.timeout) @@ -319,6 +331,7 @@ def from_flyte_idl(cls, pb2_object): deprecated_error_message=pb2_object.deprecated_error_message, cache_serializable=pb2_object.cache_serializable, pod_template_name=pb2_object.pod_template_name, + cache_ignore_input_vars=pb2_object.cache_ignore_input_vars, ) diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py index 96c30b69b4..8426716ec7 100644 --- a/tests/flytekit/common/parameterizers.py +++ b/tests/flytekit/common/parameterizers.py @@ -125,8 +125,9 @@ deprecated, cache_serializable, pod_template_name, + cache_ignore_input_vars, ) - for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, cache_serializable, pod_template_name in product( + for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, cache_serializable, pod_template_name, cache_ignore_input_vars in product( [True, False], LIST_OF_RUNTIME_METADATA, [timedelta(days=i) for i in range(3)], @@ -136,6 +137,7 @@ ["deprecated"], [True, False], ["A", "B"], + [()], ) ] diff --git a/tests/flytekit/unit/core/test_local_cache.py b/tests/flytekit/unit/core/test_local_cache.py index 1569a258f4..7d6ead51ba 100644 --- a/tests/flytekit/unit/core/test_local_cache.py +++ b/tests/flytekit/unit/core/test_local_cache.py @@ -1,4 +1,5 @@ import datetime +import re import typing from dataclasses import dataclass from typing import Dict, List @@ -496,3 +497,30 @@ def test_literal_hash_placement(): assert litmap.hash == _recursive_hash_placement(litmap).hash assert litcoll.hash == _recursive_hash_placement(litcoll).hash + + +def test_cache_ignore_input_vars(): + @task(cache=True, cache_version="v1", cache_ignore_input_vars=["a"]) + def add(a: int, b: int) -> int: + return a + b + + @workflow + def add_wf(a: int, b: int) -> int: + return add(a=a, b=b) + + assert add_wf(a=10, b=5) == 15 + assert add_wf(a=20, b=5) == 15 # since a is ignored, this line will hit cache of a=10, b=5 + assert add_wf(a=20, b=8) == 28 + + +def test_set_cache_ignore_input_vars_without_set_cache(): + with pytest.raises( + ValueError, + match=re.escape( + "Cache ignore input vars are specified ``cache_ignore_input_vars=['a']`` but ``cache`` is not enabled." + ), + ): + + @task(cache_ignore_input_vars=["a"]) + def add(a: int, b: int) -> int: + return a + b diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 1a41aaab40..22f78579e7 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -100,6 +100,7 @@ def get_task_template(task_type: str) -> TaskTemplate: "This is deprecated!", True, "A", + (), ) interfaces = interface_models.TypedInterface( diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py index b4158c3852..b9685736b7 100644 --- a/tests/flytekit/unit/models/test_tasks.py +++ b/tests/flytekit/unit/models/test_tasks.py @@ -73,6 +73,7 @@ def test_task_metadata(): "This is deprecated!", True, "A", + (), ) assert obj.discoverable is True @@ -142,6 +143,7 @@ def test_task_spec(): "This is deprecated!", True, "A", + (), ) int_type = types.LiteralType(types.SimpleType.INTEGER) @@ -202,6 +204,7 @@ def test_task_template_k8s_pod_target(): "deprecated", False, "A", + (), ), interface_models.TypedInterface( # inputs diff --git a/tests/flytekit/unit/models/test_workflow_closure.py b/tests/flytekit/unit/models/test_workflow_closure.py index 64e5a57713..8181d0c256 100644 --- a/tests/flytekit/unit/models/test_workflow_closure.py +++ b/tests/flytekit/unit/models/test_workflow_closure.py @@ -42,6 +42,7 @@ def test_workflow_closure(): "This is deprecated!", True, "A", + (), ) cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")