Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add caching to local execution #592

Merged
merged 27 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6b58da3
Add initial structure of local cache using joblib
Aug 12, 2021
e8012e8
Add bogus unit tests
eapolinario Aug 12, 2021
0529f7c
Add clear-cache make target
eapolinario Aug 12, 2021
1cc9bdb
Run `make requirements`
eapolinario Aug 12, 2021
fa951a5
Remove a few TODO
eapolinario Aug 12, 2021
5b137a8
Linting
eapolinario Aug 12, 2021
ce947a2
More linting
eapolinario Aug 12, 2021
ec9fdfe
Define LocalCache and replace uses of it in base_task.py
eapolinario Aug 12, 2021
8dfdd1e
LocalCache type hints
eapolinario Aug 12, 2021
790284e
Comment use of LocalCache in base_task.py
eapolinario Aug 12, 2021
ffc4f52
Remove use of Optional from declaration of _memory
eapolinario Aug 12, 2021
9a486e1
Move comment closer to invocation of 'dispatch_execute_func'
eapolinario Aug 12, 2021
44b4858
Use global counter to validate cache hits
eapolinario Aug 13, 2021
35d4a2b
Use ~/.flyte/local-cache as the default location
eapolinario Aug 13, 2021
01ca49c
Force initialization in LocalCache.clear()
eapolinario Aug 13, 2021
ffb66a0
Add pyflyte local-cache command
eapolinario Aug 13, 2021
823dffc
Remove clear-cache make target
eapolinario Aug 13, 2021
9deff96
Linting.
eapolinario Aug 13, 2021
f0c8aed
More linting
eapolinario Aug 13, 2021
f9fd960
Add more tests
eapolinario Aug 16, 2021
a00ead1
Add constant for cache verbosity
eapolinario Aug 17, 2021
130b432
Add task name to cache key definition
eapolinario Aug 17, 2021
5c991e3
Add tests containing complex inputs and outputs
eapolinario Aug 17, 2021
e3170b0
Linting
eapolinario Aug 17, 2021
5b83e76
Comment the test used to confirm the use of task names in cache keys
eapolinario Aug 17, 2021
045cc3b
More linting
eapolinario Aug 17, 2021
9cbc70a
Fix linting
eapolinario Aug 17, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ jinja2==3.0.1
# via
# -c requirements.txt
# pytest-flyte
joblib==1.0.1
# via
# -c requirements.txt
# flytekit
jsonschema==3.2.0
# via
# -c requirements.txt
Expand Down Expand Up @@ -274,7 +278,7 @@ requests==2.26.0
# docker-compose
# flytekit
# responses
responses==0.13.3
responses==0.13.4
# via
# -c requirements.txt
# flytekit
Expand Down Expand Up @@ -345,7 +349,7 @@ websocket-client==0.59.0
# via
# docker
# docker-compose
wheel==0.36.2
wheel==0.37.0
# via
# -c requirements.txt
# flytekit
Expand Down
14 changes: 8 additions & 6 deletions doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ black==21.7b0
# via papermill
bleach==4.0.0
# via nbconvert
boto3==1.18.15
boto3==1.18.19
# via sagemaker-training
botocore==1.21.15
botocore==1.21.19
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -96,7 +96,7 @@ git+git://github.com/flyteorg/furo@main
# via -r doc-requirements.in
gevent==21.8.0
# via sagemaker-training
greenlet==1.1.0
greenlet==1.1.1
# via gevent
grpcio==1.39.0
# via
Expand Down Expand Up @@ -131,6 +131,8 @@ jmespath==0.10.0
# via
# boto3
# botocore
joblib==1.0.1
# via flytekit
jsonschema==3.2.0
# via nbformat
jupyter-client==6.1.12
Expand Down Expand Up @@ -284,7 +286,7 @@ requests==2.26.0
# papermill
# responses
# sphinx
responses==0.13.3
responses==0.13.4
# via flytekit
retry==0.9.2
# via flytekit
Expand Down Expand Up @@ -343,7 +345,7 @@ sphinx-gallery==0.9.0
# via -r doc-requirements.in
sphinx-material==0.0.34
# via -r doc-requirements.in
sphinx-prompt==1.4.0
sphinx-prompt==1.5.0
# via -r doc-requirements.in
sphinxcontrib-applehelp==1.0.2
# via sphinx
Expand Down Expand Up @@ -409,7 +411,7 @@ webencodings==0.5.1
# via bleach
werkzeug==2.0.1
# via sagemaker-training
wheel==0.36.2
wheel==0.37.0
# via flytekit
wrapt==1.12.1
# via
Expand Down
22 changes: 22 additions & 0 deletions flytekit/clis/sdk_in_container/local_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import click

from flytekit.core.local_cache import LocalCache


@click.group("local-cache")
def local_cache():
"""
Interact with the local cache.
"""
pass


@click.command("clear")
def clear_local_cache():
"""
This command will remove all stored objects from local cache.
"""
LocalCache.clear()


local_cache.add_command(clear_local_cache)
2 changes: 2 additions & 0 deletions flytekit/clis/sdk_in_container/pyflyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES
from flytekit.clis.sdk_in_container.fast_register import fast_register
from flytekit.clis.sdk_in_container.launch_plan import launch_plans
from flytekit.clis.sdk_in_container.local_cache import local_cache
from flytekit.clis.sdk_in_container.package import package
from flytekit.clis.sdk_in_container.register import register
from flytekit.clis.sdk_in_container.serialize import serialize
Expand Down Expand Up @@ -110,6 +111,7 @@ def update_configuration_file(config_file_path):
main.add_command(serialize)
main.add_command(launch_plans)
main.add_command(package)
main.add_command(local_cache)

if __name__ == "__main__":
main()
21 changes: 20 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from flytekit.core.docstring import Docstring
from flytekit.core.interface import Interface, transform_interface_to_typed_interface
from flytekit.core.local_cache import LocalCache
from flytekit.core.promise import (
Promise,
VoidPromise,
Expand Down Expand Up @@ -217,6 +218,14 @@ def get_input_types(self) -> Dict[str, type]:
"""
return None

def _dispatch_execute(
self, ctx: FlyteContext, task_name: str, input_literal_map: _literal_models.LiteralMap, cache_version: str
) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
"""
Thin wrapper around the actual call to 'dispatch_execute'.
"""
return self.dispatch_execute(ctx, input_literal_map)

def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
"""
This code is used only in the case when we want to dispatch_execute with outputs from a previous node
Expand All @@ -236,7 +245,17 @@ def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], P
)
input_literal_map = _literal_models.LiteralMap(literals=kwargs)

outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
# if metadata.cache is set, check memoized version
if self._metadata.cache:
# The cache key is composed of '(task name, input_literal_map, cache_version)', i.e. all other parameters
# passed to the call to 'dispatch_execute' are ignored
dispatch_execute_func = LocalCache.cache(self._dispatch_execute, ignore=["self", "ctx"])
else:
dispatch_execute_func = self._dispatch_execute
# The local cache uses the function signature (and an ignore list) to calculate the cache key. In other
# words, we need the cache version to be present in the function signature so that we can respect the current
# cache semantics where changing the cache version of a cached Task creates a separate entry in the cache.
outputs_literal_map = dispatch_execute_func(ctx, self.name, input_literal_map, self._metadata.cache_version)
outputs_literals = outputs_literal_map.literals

# TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
Expand Down
31 changes: 31 additions & 0 deletions flytekit/core/local_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Callable, List, Optional

from joblib import Memory

# Location in the file system where serialized objects will be stored
# TODO: read from config
CACHE_LOCATION = "~/.flyte/local-cache"
# TODO: read from config
CACHE_VERBOSITY = 5


class LocalCache(object):
_memory: Memory
_initialized: bool = False

@staticmethod
def initialize():
LocalCache._memory = Memory(CACHE_LOCATION, verbose=CACHE_VERBOSITY)
LocalCache._initialized = True

@staticmethod
def cache(func: Callable, ignore: Optional[List[str]] = None):
if not LocalCache._initialized:
LocalCache.initialize()
return LocalCache._memory.cache(func, ignore=ignore)

@staticmethod
def clear():
if not LocalCache._initialized:
LocalCache.initialize()
LocalCache._memory.clear()
12 changes: 7 additions & 5 deletions requirements-spark2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ black==21.7b0
# via papermill
bleach==4.0.0
# via nbconvert
boto3==1.18.15
boto3==1.18.19
# via sagemaker-training
botocore==1.21.15
botocore==1.21.19
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -78,7 +78,7 @@ flyteidl==0.19.19
# via flytekit
gevent==21.8.0
# via sagemaker-training
greenlet==1.1.0
greenlet==1.1.1
# via gevent
grpcio==1.39.0
# via flytekit
Expand Down Expand Up @@ -106,6 +106,8 @@ jmespath==0.10.0
# via
# boto3
# botocore
joblib==1.0.1
# via flytekit
jsonschema==3.2.0
# via nbformat
jupyter-client==6.1.12
Expand Down Expand Up @@ -245,7 +247,7 @@ requests==2.26.0
# flytekit
# papermill
# responses
responses==0.13.3
responses==0.13.4
# via flytekit
retry==0.9.2
# via flytekit
Expand Down Expand Up @@ -322,7 +324,7 @@ webencodings==0.5.1
# via bleach
werkzeug==2.0.1
# via sagemaker-training
wheel==0.36.2
wheel==0.37.0
# via flytekit
wrapt==1.12.1
# via
Expand Down
12 changes: 7 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ black==21.7b0
# via papermill
bleach==4.0.0
# via nbconvert
boto3==1.18.15
boto3==1.18.19
# via sagemaker-training
botocore==1.21.15
botocore==1.21.19
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -78,7 +78,7 @@ flyteidl==0.19.19
# via flytekit
gevent==21.8.0
# via sagemaker-training
greenlet==1.1.0
greenlet==1.1.1
# via gevent
grpcio==1.39.0
# via flytekit
Expand Down Expand Up @@ -106,6 +106,8 @@ jmespath==0.10.0
# via
# boto3
# botocore
joblib==1.0.1
# via flytekit
jsonschema==3.2.0
# via nbformat
jupyter-client==6.1.12
Expand Down Expand Up @@ -245,7 +247,7 @@ requests==2.26.0
# flytekit
# papermill
# responses
responses==0.13.3
responses==0.13.4
# via flytekit
retry==0.9.2
# via flytekit
Expand Down Expand Up @@ -322,7 +324,7 @@ webencodings==0.5.1
# via bleach
werkzeug==2.0.1
# via sagemaker-training
wheel==0.36.2
wheel==0.37.0
# via flytekit
wrapt==1.12.1
# via
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
"docker-image-py>=0.1.10",
"singledispatchmethod; python_version < '3.8.0'",
"docstring-parser>=0.9.0",
"joblib>=1.0.0",
],
extras_require=extras_require,
scripts=[
Expand Down
Loading