Caching of offloaded objects #762

Merged
merged 41 commits into from
Mar 2, 2022
Commits
06e960f
Remove flyteidl from install_requires
eapolinario Dec 2, 2021
c9047a8
Expose hash in Literal
eapolinario Dec 2, 2021
4d21957
Set hash in TypeEngine
eapolinario Dec 2, 2021
f9273db
Modify cache key calculation to take hash into account
eapolinario Dec 2, 2021
53dff4d
Opt-in PandasDataFrameTransformer
eapolinario Dec 2, 2021
5ae247c
Add unit tests
eapolinario Dec 2, 2021
9305746
Iterate using a flyteidl branch
eapolinario Dec 3, 2021
faff038
Merge remote-tracking branch 'origin' into offloaded-objects-caching
eapolinario Dec 3, 2021
ecd8b93
Regenerate requirements files
eapolinario Dec 3, 2021
44b72f6
Regenerate requirements files
eapolinario Dec 3, 2021
f46dc74
Merge remote-tracking branch 'origin' into offloaded-objects-caching
eapolinario Jan 19, 2022
22c90d3
Move _hash_overridable to StructureDatasetTransformerEngine
eapolinario Jan 21, 2022
fe7e8f7
Merge remote-tracking branch 'origin' into offloaded-objects-caching
eapolinario Jan 25, 2022
6ac4b44
Move HashMethod to flytekit.core.hash
eapolinario Jan 26, 2022
1702961
Merge remote-tracking branch 'origin' into offloaded-objects-caching
eapolinario Feb 17, 2022
4552e78
Fix `unit_test` make target
eapolinario Feb 17, 2022
c9fc044
Merge remote-tracking branch 'origin' into offloaded-objects-caching
eapolinario Feb 17, 2022
24f67a0
Split `unit_test` make target in two lines
eapolinario Feb 17, 2022
60ecf3a
Add assert to structured dataset compatibility test
eapolinario Feb 17, 2022
3aeedeb
Remove TODO
eapolinario Feb 17, 2022
c2dbb54
Regenerate plugins requirements files pointing to the right version o…
eapolinario Feb 17, 2022
0e58199
Set hash as a property of the literal
eapolinario Feb 17, 2022
5054861
Install plugins requirements in CI.
eapolinario Feb 17, 2022
9f2d06f
Add hash.setter
eapolinario Feb 18, 2022
e039836
Install flyteidl directly
eapolinario Feb 18, 2022
2da76f3
Revert "Regenerate plugins requirements files pointing to the right v…
eapolinario Feb 18, 2022
adaa448
wip - Add support for univariate lists
eapolinario Feb 18, 2022
4b5f608
Add support for lists of annotated objects
eapolinario Feb 18, 2022
ecffa04
Revamp generation of cache key (to cover case of literals collections…
eapolinario Feb 23, 2022
d4b0b49
Leave TODO for warning
eapolinario Feb 23, 2022
4d54c59
Revert "Add support for lists of annotated objects"
eapolinario Feb 23, 2022
82bbc1f
Revert "wip - Add support for univariate lists"
eapolinario Feb 23, 2022
a21631b
Remove docstring
eapolinario Mar 1, 2022
a432b1e
Merge remote-tracking branch 'origin' into offloaded-objects-caching
eapolinario Mar 1, 2022
3519149
Add flyteidl>=0.23.0
eapolinario Mar 2, 2022
09e736d
Remove mentions to branch flyteidl@add-hash-to-literal
eapolinario Mar 2, 2022
9ec103c
Bump flyteidl in plugins requirements
eapolinario Mar 2, 2022
fa4e3b2
Regenerate plugins requirements again
eapolinario Mar 2, 2022
373b3bb
Restore papermill/requirements.txt
eapolinario Mar 2, 2022
403830d
Point flytekitplugins-spark to the offloaded-objects-caching branch i…
eapolinario Mar 2, 2022
9754a0f
Set flyteidl>=0.23.0 in papermill dev-requirements
eapolinario Mar 2, 2022
3 changes: 2 additions & 1 deletion Makefile
@@ -50,7 +50,8 @@ test: lint unit_test

 .PHONY: unit_test
 unit_test:
-	FLYTE_SDK_USE_STRUCTURED_DATASET=TRUE pytest tests/flytekit/unit tests/flytekit_compatibility
+	FLYTE_SDK_USE_STRUCTURED_DATASET=FALSE pytest tests/flytekit_compatibility && \
+	FLYTE_SDK_USE_STRUCTURED_DATASET=TRUE pytest tests/flytekit/unit

 requirements-spark2.txt: export CUSTOM_COMPILE_COMMAND := make requirements-spark2.txt
 requirements-spark2.txt: requirements-spark2.in install-piptools
33 changes: 24 additions & 9 deletions dev-requirements.txt
@@ -1,5 +1,5 @@
 #
-# This file is autogenerated by pip-compile with python 3.8
+# This file is autogenerated by pip-compile with python 3.9
 # To update, run:
 #
 # make dev-requirements.txt
@@ -32,6 +32,7 @@ certifi==2021.10.8
 # requests
 cffi==1.15.0
 # via
+# -c requirements.txt
 # bcrypt
 # cryptography
 # pynacl
@@ -71,7 +72,10 @@ croniter==1.3.4
 # -c requirements.txt
 # flytekit
 cryptography==36.0.1
-# via paramiko
+# via
+# -c requirements.txt
+# paramiko
+# secretstorage
 dataclasses-json==0.5.6
 # via
 # -c requirements.txt
@@ -112,7 +116,7 @@ docstring-parser==0.13
 # flytekit
 filelock==3.6.0
 # via virtualenv
-flyteidl==0.22.3
+flyteidl==0.23.0
 # via
 # -c requirements.txt
 # flytekit
@@ -133,9 +137,9 @@ google-cloud-core==2.2.2
 # via google-cloud-bigquery
 google-crc32c==1.3.0
 # via google-resumable-media
-google-resumable-media==2.2.1
+google-resumable-media==2.3.0
 # via google-cloud-bigquery
-googleapis-common-protos==1.54.0
+googleapis-common-protos==1.55.0
 # via
 # -c requirements.txt
 # flyteidl
@@ -156,12 +160,17 @@ idna==3.3
 # via
 # -c requirements.txt
 # requests
-importlib-metadata==4.11.1
+importlib-metadata==4.11.2
 # via
 # -c requirements.txt
 # keyring
 iniconfig==1.1.1
 # via pytest
+jeepney==0.7.1
+# via
+# -c requirements.txt
+# keyring
+# secretstorage
 jinja2==3.0.3
 # via
 # -c requirements.txt
@@ -275,7 +284,9 @@ pyasn1==0.4.8
 pyasn1-modules==0.2.8
 # via google-auth
 pycparser==2.21
-# via cffi
+# via
+# -c requirements.txt
+# cffi
 pynacl==1.5.0
 # via paramiko
 pyparsing==3.0.7
@@ -307,7 +318,7 @@ python-json-logger==2.0.2
 # via
 # -c requirements.txt
 # flytekit
-python-slugify==6.1.0
+python-slugify==6.1.1
 # via
 # -c requirements.txt
 # cookiecutter
@@ -349,6 +360,10 @@ retry==0.9.2
 # flytekit
 rsa==4.8
 # via google-auth
+secretstorage==3.3.1
+# via
+# -c requirements.txt
+# keyring
 six==1.16.0
 # via
 # -c requirements.txt
@@ -399,7 +414,7 @@ urllib3==1.26.8
 # flytekit
 # requests
 # responses
-virtualenv==20.13.1
+virtualenv==20.13.2
 # via pre-commit
 websocket-client==0.59.0
 # via
25 changes: 15 additions & 10 deletions doc-requirements.txt
@@ -1,5 +1,5 @@
 #
-# This file is autogenerated by pip-compile with python 3.8
+# This file is autogenerated by pip-compile with python 3.9
 # To update, run:
 #
 # make doc-requirements.txt
@@ -10,7 +10,7 @@ alabaster==0.7.12
 # via sphinx
 arrow==1.2.2
 # via jinja2-time
-astroid==2.9.3
+astroid==2.10.0
 # via sphinx-autoapi
 babel==2.9.1
 # via sphinx
@@ -42,7 +42,9 @@ cookiecutter==1.7.3
 croniter==1.3.4
 # via flytekit
 cryptography==36.0.1
-# via -r doc-requirements.in
+# via
+# -r doc-requirements.in
+# secretstorage
 css-html-js-minify==2.5.5
 # via sphinx-material
 dataclasses-json==0.5.6
@@ -61,7 +63,7 @@ docutils==0.17.1
 # via
 # sphinx
 # sphinx-panels
-flyteidl==0.22.3
+flyteidl==0.23.0
 # via flytekit
 furo @ git+https://github.com/flyteorg/furo@main
 # via -r doc-requirements.in
@@ -75,10 +77,14 @@ idna==3.3
 # via requests
 imagesize==1.3.0
 # via sphinx
-importlib-metadata==4.11.1
+importlib-metadata==4.11.2
 # via
 # keyring
 # sphinx
+jeepney==0.7.1
+# via
+# keyring
+# secretstorage
 jinja2==3.0.3
 # via
 # cookiecutter
@@ -125,10 +131,7 @@ protobuf==3.19.4
 # googleapis-common-protos
 # protoc-gen-swagger
 protoc-gen-swagger==0.1.0
-# via
-# flyteidl
-# flytekit
-
+# via flyteidl
 py==1.11.0
 # via retry
 pyarrow==6.0.1
@@ -149,7 +152,7 @@ python-dateutil==2.8.2
 # pandas
 python-json-logger==2.0.2
 # via flytekit
-python-slugify[unidecode]==6.1.0
+python-slugify[unidecode]==6.1.1
 # via
 # cookiecutter
 # sphinx-material
@@ -174,6 +177,8 @@ responses==0.18.0
 # via flytekit
 retry==0.9.2
 # via flytekit
+secretstorage==3.3.1
+# via keyring
 six==1.16.0
 # via
 # cookiecutter
1 change: 1 addition & 0 deletions flytekit/__init__.py
@@ -169,6 +169,7 @@
 from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager
 from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins
 from flytekit.core.dynamic_workflow_task import dynamic
+from flytekit.core.hash import HashMethod
 from flytekit.core.launch_plan import LaunchPlan
 from flytekit.core.map_task import map_task
 from flytekit.core.notification import Email, PagerDuty, Slack
20 changes: 20 additions & 0 deletions flytekit/core/hash.py
@@ -1,3 +1,23 @@
+from typing import Callable, Generic, TypeVar
+
+T = TypeVar("T")
+
+
 class HashOnReferenceMixin(object):
     def __hash__(self):
         return hash(id(self))
+
+
+class HashMethod(Generic[T]):
+    """
+    Flyte-specific object used to wrap the hash function for a specific type
+    """
+
+    def __init__(self, function: Callable[[T], str]):
+        self._function = function
+
+    def calculate(self, obj: T) -> str:
+        """
+        Calculate hash for `obj`.
+        """
+        return self._function(obj)
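
A minimal usage sketch of `HashMethod`, assuming it behaves exactly as in the `flytekit/core/hash.py` diff. The class is re-declared locally so the snippet runs without flytekit installed, and `hash_rows` is a hypothetical hash function that is not part of this PR.

```python
import hashlib
from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")


class HashMethod(Generic[T]):
    """Local stand-in mirroring flytekit.core.hash.HashMethod from the diff."""

    def __init__(self, function: Callable[[T], str]):
        self._function = function

    def calculate(self, obj: T) -> str:
        return self._function(obj)


def hash_rows(rows: List[int]) -> str:
    """Hypothetical content hash: SHA-256 over the comma-joined values."""
    return hashlib.sha256(",".join(map(str, rows)).encode()).hexdigest()


method = HashMethod(hash_rows)

# Equal content gives an equal hash; different content gives a different one.
same = method.calculate([1, 2, 3]) == method.calculate([1, 2, 3])
different = method.calculate([1, 2, 3]) != method.calculate([3, 2, 1])
print(same, different)
```

The wrapper adds no behavior of its own. Its purpose appears to be giving the type engine a concrete class to look for among `Annotated` metadata.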
30 changes: 28 additions & 2 deletions flytekit/core/local_cache.py
@@ -1,16 +1,42 @@
+import base64
 from typing import Optional

+import cloudpickle
 from diskcache import Cache

-from flytekit.models.literals import LiteralMap
+from flytekit.models.literals import Literal, LiteralCollection, LiteralMap

 # Location on the filesystem where serialized objects will be stored
 # TODO: read from config
 CACHE_LOCATION = "~/.flyte/local-cache"


+def _recursive_hash_placement(literal: Literal) -> Literal:
+    if literal.collection is not None:
+        literals = [_recursive_hash_placement(literal) for literal in literal.collection.literals]
+        return Literal(collection=LiteralCollection(literals=literals))
+    elif literal.map is not None:
+        literal_map = {}
+        for key, literal in literal.map.literals.items():
+            literal_map[key] = _recursive_hash_placement(literal)
+        return Literal(map=LiteralMap(literal_map))
+
+    # Base case
+    if literal.hash is not None:
+        return Literal(hash=literal.hash)
+    else:
+        return literal
+
+
 def _calculate_cache_key(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> str:
-    return f"{task_name}-{cache_version}-{input_literal_map}"
+    # Traverse the literals and replace the literal with a new literal that only contains the hash
+    literal_map_overridden = {}
+    for key, literal in input_literal_map.literals.items():
+        literal_map_overridden[key] = _recursive_hash_placement(literal)
+
+    # Pickle the literal map and use base64 encoding to generate a representation of it
+    b64_encoded = base64.b64encode(cloudpickle.dumps(LiteralMap(literal_map_overridden)))
+    return f"{task_name}-{cache_version}-{b64_encoded}"


 class LocalTaskCache(object):
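
The effect of the new cache-key calculation can be sketched without flytekit. Everything below is an assumption made for illustration: `Lit` is a hypothetical stand-in for flytekit's `Literal`, `place_hash` and `cache_key` are simplified analogues of `_recursive_hash_placement` and `_calculate_cache_key`, the standard-library `pickle` is swapped in for `cloudpickle`, and the base64 bytes are decoded to text (the diff interpolates the bytes object directly).

```python
import base64
import pickle
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional


@dataclass
class Lit:
    """Hypothetical stand-in for flytekit's Literal (not the real model)."""

    value: Optional[str] = None
    hash: Optional[str] = None
    collection: Optional[List["Lit"]] = None
    map: Optional[Dict[str, "Lit"]] = None


def place_hash(lit: Lit) -> Lit:
    """Replace every hashed leaf with a literal that carries only its hash."""
    if lit.collection is not None:
        return Lit(collection=[place_hash(item) for item in lit.collection])
    if lit.map is not None:
        return Lit(map={k: place_hash(v) for k, v in lit.map.items()})
    # Base case: a leaf keeps only its hash when one is set.
    return Lit(hash=lit.hash) if lit.hash is not None else lit


def cache_key(task_name: str, cache_version: str, inputs: Dict[str, Lit]) -> str:
    # Convert to plain dicts so that only builtin types are pickled.
    overridden = {k: asdict(place_hash(v)) for k, v in inputs.items()}
    encoded = base64.b64encode(pickle.dumps(overridden)).decode()
    return f"{task_name}-{cache_version}-{encoded}"


# Same hash, different storage location: the keys match.
a = cache_key("t", "v1", {"df": Lit(value="s3://bucket/a", hash="abc")})
b = cache_key("t", "v1", {"df": Lit(value="s3://bucket/b", hash="abc")})
# No hash set: the literal itself (including its value) feeds the key.
c = cache_key("t", "v1", {"df": Lit(value="s3://bucket/a")})
print(a == b, a == c)
```

This is the property the PR title refers to: two offloaded objects stored at different locations hit the same cache entry as long as their user-supplied hashes agree.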
35 changes: 29 additions & 6 deletions flytekit/core/type_engine.py
@@ -10,11 +10,6 @@
 from abc import ABC, abstractmethod
 from typing import NamedTuple, Optional, Type, cast

-try:
-    from typing import Annotated, get_args, get_origin
-except ImportError:
-    from typing_extensions import Annotated, get_origin, get_args
-
 from dataclasses_json import DataClassJsonMixin, dataclass_json
 from google.protobuf import json_format as _json_format
 from google.protobuf import reflection as _proto_reflection
@@ -24,9 +19,11 @@
 from google.protobuf.struct_pb2 import Struct
 from marshmallow_enum import EnumField, LoadDumpOptions
 from marshmallow_jsonschema import JSONSchema
+from typing_extensions import Annotated, get_args, get_origin

 from flytekit.core.annotation import FlyteAnnotation
 from flytekit.core.context_manager import FlyteContext
+from flytekit.core.hash import HashMethod
 from flytekit.core.type_helpers import load_type_from_tag
 from flytekit.exceptions import user as user_exceptions
 from flytekit.loggers import logger
@@ -56,10 +53,13 @@ class TypeTransformer(typing.Generic[T]):
     Base transformer type that should be implemented for every python native type that can be handled by flytekit
     """

-    def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True):
+    def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True, hash_overridable: bool = False):
         self._t = t
         self._name = name
         self._type_assertions_enabled = enable_type_assertions
+        # `hash_overridable` indicates that the literals produced by this type transformer can set their hashes if needed.
+        # See (link to documentation where this feature is explained).
+        self._hash_overridable = hash_overridable

     @property
     def name(self):
@@ -79,6 +79,10 @@ def type_assertions_enabled(self) -> bool:
         """
         return self._type_assertions_enabled

+    @property
+    def hash_overridable(self) -> bool:
+        return self._hash_overridable
+
     def assert_type(self, t: Type[T], v: T):
         if not hasattr(t, "__origin__") and not isinstance(v, t):
             raise TypeError(f"Type of Val '{v}' is not an instance of {t}")
@@ -640,7 +644,25 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
         transformer = cls.get_transformer(python_type)
         if transformer.type_assertions_enabled:
             transformer.assert_type(python_type, python_val)
+
+        # In case the value is an annotated type we inspect the annotations and look for hash-related annotations.
+        hash = None
+        if transformer.hash_overridable and get_origin(python_type) is Annotated:
+            # We are now dealing with one of two cases:
+            # 1. The annotated type is a `HashMethod`, which indicates that we should we should produce the hash using
+            # the method indicated in the annotation.
+            # 2. The annotated type is being used for a different purpose other than calculating hash values, in which case
+            # we should just continue.
+            for annotation in get_args(python_type)[1:]:
+                if not isinstance(annotation, HashMethod):
+                    continue
+                hash = annotation.calculate(python_val)
+                break
+
         lv = transformer.to_literal(ctx, python_val, python_type, expected)
+
+        if hash is not None:
+            lv.hash = hash
         return lv

     @classmethod
@@ -852,6 +874,7 @@ def to_literal(
         for k, v in python_val.items():
             if type(k) != str:
                 raise ValueError("Flyte MapType expects all keys to be strings")
+            # TODO: log a warning for Annotated objects that contain HashMethod
             k_type, v_type = self.get_dict_types(python_type)
             lit_map[k] = TypeEngine.to_literal(ctx, v, v_type, expected.map_value_type)
         return Literal(map=LiteralMap(literals=lit_map))
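
The annotation scan added to `TypeEngine.to_literal` can be tried in isolation. This is a sketch under stated assumptions: `HashMethod` is re-declared locally (non-generic, for brevity), `find_hash` and `sha1_of_repr` are hypothetical names, `Annotated` comes from `typing` (Python 3.9+) rather than `typing_extensions` as in the diff, and the `transformer.hash_overridable` gate from the PR is omitted.

```python
import hashlib
from typing import Annotated, Any, Callable, Optional, get_args, get_origin


class HashMethod:
    """Local stand-in mirroring flytekit.core.hash.HashMethod."""

    def __init__(self, function: Callable[[Any], str]):
        self._function = function

    def calculate(self, obj: Any) -> str:
        return self._function(obj)


def find_hash(python_type: Any, python_val: Any) -> Optional[str]:
    """Return the hash from the first HashMethod annotation, or None if absent."""
    if get_origin(python_type) is not Annotated:
        return None
    # get_args returns (underlying_type, *annotations); skip the underlying type.
    for annotation in get_args(python_type)[1:]:
        if isinstance(annotation, HashMethod):
            return annotation.calculate(python_val)
    return None


def sha1_of_repr(obj: Any) -> str:
    """Hypothetical hash function used only for this illustration."""
    return hashlib.sha1(repr(obj).encode()).hexdigest()


HashedList = Annotated[list, "unrelated metadata", HashMethod(sha1_of_repr)]

with_hash = find_hash(HashedList, [1, 2, 3])
without_hash = find_hash(list, [1, 2, 3])
only_metadata = find_hash(Annotated[list, "unrelated metadata"], [1, 2, 3])
print(with_hash is not None, without_hash, only_metadata)
```

Annotations that are not `HashMethod` instances are skipped, so `Annotated` stays usable for other metadata, which matches the second case described in the diff's comment.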