Skip to content

Commit

Permalink
add joblib and use to compute cache key instead
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor committed Aug 5, 2022
1 parent 12b5097 commit 97b454b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 7 deletions.
21 changes: 16 additions & 5 deletions flytekit/core/local_cache.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import base64
from typing import Optional

import cloudpickle
import joblib
from diskcache import Cache
from google.protobuf.struct_pb2 import Struct

from flytekit.models.literals import Literal, LiteralCollection, LiteralMap

Expand All @@ -28,15 +28,26 @@ def _recursive_hash_placement(literal: Literal) -> Literal:
return literal


class ProtoJoblibHasher(joblib.hashing.NumpyHasher):
    """A joblib hasher that normalizes protobuf ``Struct`` messages before hashing.

    ``Struct`` messages do not pickle deterministically (field ordering is not
    guaranteed), so they are rewritten into a plain dict with sorted fields to
    produce a stable, repeatable hash for cache keys.
    """

    def save(self, obj):
        if isinstance(obj, Struct):
            # Replace the Struct with a deterministic, picklable representation.
            # ``rewrite_rule`` records what was rewritten so different message
            # types cannot collide with an identically-shaped plain dict.
            obj = dict(
                rewrite_rule="google.protobuf.struct_pb2.Struct",
                cls=obj.__class__,
                obj=dict(sorted(obj.fields.items())),
            )
        # Delegate to the parent hasher so the (possibly rewritten) object is
        # actually written into the hash stream. Returning early here would
        # skip hashing the object entirely, yielding a bogus constant hash.
        joblib.hashing.NumpyHasher.save(self, obj)


def _calculate_cache_key(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> str:
    """Compute a deterministic local-cache key for a task invocation.

    :param task_name: fully-qualified name of the task being cached.
    :param cache_version: user-declared cache version; bumping it invalidates
        previously cached results.
    :param input_literal_map: the task's input literals.
    :return: a string key of the form ``"<task_name>-<cache_version>-<hash>"``.
    """
    # Traverse the literals and replace each literal with a new literal that
    # only contains the hash, so inputs that are equal-by-hash map to the
    # same cache key.
    literal_map_overridden = {
        key: _recursive_hash_placement(literal) for key, literal in input_literal_map.literals.items()
    }

    # Generate a deterministic hash of the (hash-substituted) inputs with joblib.
    hashed_inputs = ProtoJoblibHasher().hash(literal_map_overridden)
    return f"{task_name}-{cache_version}-{hashed_inputs}"


class LocalTaskCache(object):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"grpcio-status>=1.43,!=1.45.0",
"importlib-metadata",
"pyopenssl",
"joblib",
"protobuf>=3.6.1,<4",
"python-json-logger>=2.0.0",
"pytimeparse>=1.1.8,<2.0.0",
Expand Down
30 changes: 28 additions & 2 deletions tests/flytekit/unit/core/test_local_cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import typing
from dataclasses import dataclass
from typing import List
from typing import Dict, List

import pandas
from dataclasses_json import dataclass_json
Expand All @@ -10,12 +10,15 @@

from flytekit.core.base_sql_task import SQLTask
from flytekit.core.base_task import kwtypes
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.hash import HashMethod
from flytekit.core.local_cache import LocalTaskCache
from flytekit.core.local_cache import LocalTaskCache, _calculate_cache_key
from flytekit.core.task import TaskMetadata, task
from flytekit.core.testing import task_mock
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.models.literals import LiteralMap
from flytekit.types.schema import FlyteSchema

# Global counter used to validate number of calls to cache
Expand Down Expand Up @@ -385,3 +388,26 @@ def my_workflow():
# Confirm that we see a cache hit in the case of annotated dataframes.
my_workflow()
assert n_cached_task_calls == 1


def test_cache_key_repetition():
    """Hashing the same literal map repeatedly must yield exactly one cache key.

    Guards against nondeterministic serialization (e.g. unstable protobuf
    Struct field ordering) leaking into the cache key.
    """
    pt = Dict
    lt = TypeEngine.to_literal_type(pt)
    ctx = FlyteContextManager.current_context()
    kwargs = {
        "a": 0.41083513079747874,
        "b": 0.7773927872515183,
        "c": 17,
    }
    keys = set()
    # Re-convert and re-hash many times: any nondeterminism would produce
    # more than one distinct key.
    for _ in range(100):
        lit = TypeEngine.to_literal(ctx, kwargs, pt, lt)
        lm = LiteralMap(
            literals={
                "d": lit,
            }
        )
        key = _calculate_cache_key("t1", "007", lm)
        keys.add(key)

    assert len(keys) == 1

0 comments on commit 97b454b

Please sign in to comment.