diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py index 4afab73955..c1f42a36a4 100644 --- a/flytekit/core/local_cache.py +++ b/flytekit/core/local_cache.py @@ -1,8 +1,8 @@ -import base64 from typing import Optional -import cloudpickle +import joblib from diskcache import Cache +from google.protobuf.struct_pb2 import Struct from flytekit.models.literals import Literal, LiteralCollection, LiteralMap @@ -28,15 +28,26 @@ def _recursive_hash_placement(literal: Literal) -> Literal: return literal +class ProtoJoblibHasher(joblib.hashing.NumpyHasher): + def save(self, obj): + if isinstance(obj, Struct): + obj = dict( + rewrite_rule="google.protobuf.struct_pb2.Struct", + cls=obj.__class__, + obj=dict(sorted(obj.fields.items())), + ) + joblib.hashing.NumpyHasher.save(self, obj) + + def _calculate_cache_key(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> str: # Traverse the literals and replace the literal with a new literal that only contains the hash literal_map_overridden = {} for key, literal in input_literal_map.literals.items(): literal_map_overridden[key] = _recursive_hash_placement(literal) - # Pickle the literal map and use base64 encoding to generate a representation of it - b64_encoded = base64.b64encode(cloudpickle.dumps(LiteralMap(literal_map_overridden))) - return f"{task_name}-{cache_version}-{b64_encoded}" + # Generate a hash key of inputs with joblib + hashed_inputs = ProtoJoblibHasher().hash(literal_map_overridden) + return f"{task_name}-{cache_version}-{hashed_inputs}" class LocalTaskCache(object): diff --git a/setup.py b/setup.py index 85af8691b6..56e2dd5624 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,7 @@ "grpcio-status>=1.43,!=1.45.0", "importlib-metadata", "pyopenssl", + "joblib", "protobuf>=3.6.1,<4", "python-json-logger>=2.0.0", "pytimeparse>=1.1.8,<2.0.0", diff --git a/tests/flytekit/unit/core/test_local_cache.py b/tests/flytekit/unit/core/test_local_cache.py index fe09fac830..5a6d2d3853 100644 ---
a/tests/flytekit/unit/core/test_local_cache.py +++ b/tests/flytekit/unit/core/test_local_cache.py @@ -1,7 +1,7 @@ import datetime import typing from dataclasses import dataclass -from typing import List +from typing import Dict, List import pandas from dataclasses_json import dataclass_json @@ -10,12 +10,15 @@ from flytekit.core.base_sql_task import SQLTask from flytekit.core.base_task import kwtypes +from flytekit.core.context_manager import FlyteContextManager from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.hash import HashMethod -from flytekit.core.local_cache import LocalTaskCache +from flytekit.core.local_cache import LocalTaskCache, _calculate_cache_key from flytekit.core.task import TaskMetadata, task from flytekit.core.testing import task_mock +from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow +from flytekit.models.literals import LiteralMap from flytekit.types.schema import FlyteSchema # Global counter used to validate number of calls to cache @@ -385,3 +388,26 @@ def my_workflow(): # Confirm that we see a cache hit in the case of annotated dataframes. my_workflow() assert n_cached_task_calls == 1 + + +def test_cache_key_repetition(): + pt = Dict + lt = TypeEngine.to_literal_type(pt) + ctx = FlyteContextManager.current_context() + kwargs = { + "a": 0.41083513079747874, + "b": 0.7773927872515183, + "c": 17, + } + keys = set() + for i in range(0, 100): + lit = TypeEngine.to_literal(ctx, kwargs, Dict, lt) + lm = LiteralMap( + literals={ + "d": lit, + } + ) + key = _calculate_cache_key("t1", "007", lm) + keys.add(key) + + assert len(keys) == 1