Skip to content

Commit

Permalink
Add caching to local execution (#592)
Browse files Browse the repository at this point in the history
* Add initial structure of local cache using joblib

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add bogus unit tests

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add clear-cache make target

Signed-off-by: Eduardo Apolinario <[email protected]>

* Run `make requirements`

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove a few TODO

Signed-off-by: Eduardo Apolinario <[email protected]>

* Linting

Signed-off-by: Eduardo Apolinario <[email protected]>

* More linting

Signed-off-by: Eduardo Apolinario <[email protected]>

* Define LocalCache and replace uses of it in base_task.py

Signed-off-by: Eduardo Apolinario <[email protected]>

* LocalCache type hints

Signed-off-by: Eduardo Apolinario <[email protected]>

* Comment use of LocalCache in base_task.py

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove use of Optional from declaration of _memory

Signed-off-by: Eduardo Apolinario <[email protected]>

* Move comment closer to invocation of 'dispatch_execute_func'

Signed-off-by: Eduardo Apolinario <[email protected]>

* Use global counter to validate cache hits

Signed-off-by: Eduardo Apolinario <[email protected]>

* Use ~/.flyte/local-cache as the default location

Signed-off-by: Eduardo Apolinario <[email protected]>

* Force initialization in LocalCache.clear()

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add pyflyte local-cache command

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove clear-cache make target

Signed-off-by: Eduardo Apolinario <[email protected]>

* Linting.

Signed-off-by: Eduardo Apolinario <[email protected]>

* More linting

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add more tests

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add constant for cache verbosity

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add task name to cache key definition

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add tests containing complex inputs and outputs

Signed-off-by: Eduardo Apolinario <[email protected]>

* Linting

Signed-off-by: Eduardo Apolinario <[email protected]>

* Comment the test used to confirm the use of task names in cache keys

Signed-off-by: Eduardo Apolinario <[email protected]>

* More linting

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix linting

Signed-off-by: Eduardo Apolinario <[email protected]>

Co-authored-by: eduardo apolinario <[email protected]>
  • Loading branch information
eapolinario and eduardo apolinario authored Aug 17, 2021
1 parent a137472 commit 86d3368
Show file tree
Hide file tree
Showing 10 changed files with 356 additions and 19 deletions.
8 changes: 6 additions & 2 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ jinja2==3.0.1
# via
# -c requirements.txt
# pytest-flyte
joblib==1.0.1
# via
# -c requirements.txt
# flytekit
jsonschema==3.2.0
# via
# -c requirements.txt
Expand Down Expand Up @@ -274,7 +278,7 @@ requests==2.26.0
# docker-compose
# flytekit
# responses
responses==0.13.3
responses==0.13.4
# via
# -c requirements.txt
# flytekit
Expand Down Expand Up @@ -345,7 +349,7 @@ websocket-client==0.59.0
# via
# docker
# docker-compose
wheel==0.36.2
wheel==0.37.0
# via
# -c requirements.txt
# flytekit
Expand Down
14 changes: 8 additions & 6 deletions doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ black==21.7b0
# via papermill
bleach==4.0.0
# via nbconvert
boto3==1.18.15
boto3==1.18.19
# via sagemaker-training
botocore==1.21.15
botocore==1.21.19
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -96,7 +96,7 @@ git+git://github.com/flyteorg/furo@main
# via -r doc-requirements.in
gevent==21.8.0
# via sagemaker-training
greenlet==1.1.0
greenlet==1.1.1
# via gevent
grpcio==1.39.0
# via
Expand Down Expand Up @@ -131,6 +131,8 @@ jmespath==0.10.0
# via
# boto3
# botocore
joblib==1.0.1
# via flytekit
jsonschema==3.2.0
# via nbformat
jupyter-client==6.1.12
Expand Down Expand Up @@ -284,7 +286,7 @@ requests==2.26.0
# papermill
# responses
# sphinx
responses==0.13.3
responses==0.13.4
# via flytekit
retry==0.9.2
# via flytekit
Expand Down Expand Up @@ -343,7 +345,7 @@ sphinx-gallery==0.9.0
# via -r doc-requirements.in
sphinx-material==0.0.34
# via -r doc-requirements.in
sphinx-prompt==1.4.0
sphinx-prompt==1.5.0
# via -r doc-requirements.in
sphinxcontrib-applehelp==1.0.2
# via sphinx
Expand Down Expand Up @@ -409,7 +411,7 @@ webencodings==0.5.1
# via bleach
werkzeug==2.0.1
# via sagemaker-training
wheel==0.36.2
wheel==0.37.0
# via flytekit
wrapt==1.12.1
# via
Expand Down
22 changes: 22 additions & 0 deletions flytekit/clis/sdk_in_container/local_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import click

from flytekit.core.local_cache import LocalCache


@click.group("local-cache")
def local_cache():
    """
    Interact with the local cache.
    """
    # Intentionally empty: the group is only a container for the sub-commands
    # registered on it; the docstring alone is a valid function body.


# Declaring the command through the group's own decorator creates the click
# command and attaches it to `local_cache` in a single step.
@local_cache.command("clear")
def clear_local_cache():
    """
    This command will remove all stored objects from local cache.
    """
    LocalCache.clear()
2 changes: 2 additions & 0 deletions flytekit/clis/sdk_in_container/pyflyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES
from flytekit.clis.sdk_in_container.fast_register import fast_register
from flytekit.clis.sdk_in_container.launch_plan import launch_plans
from flytekit.clis.sdk_in_container.local_cache import local_cache
from flytekit.clis.sdk_in_container.package import package
from flytekit.clis.sdk_in_container.register import register
from flytekit.clis.sdk_in_container.serialize import serialize
Expand Down Expand Up @@ -110,6 +111,7 @@ def update_configuration_file(config_file_path):
main.add_command(serialize)
main.add_command(launch_plans)
main.add_command(package)
main.add_command(local_cache)

if __name__ == "__main__":
main()
21 changes: 20 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
SerializationSettings,
)
from flytekit.core.interface import Interface, transform_interface_to_typed_interface
from flytekit.core.local_cache import LocalCache
from flytekit.core.promise import (
Promise,
VoidPromise,
Expand Down Expand Up @@ -216,6 +217,14 @@ def get_input_types(self) -> Dict[str, type]:
"""
return None

def _dispatch_execute(
    self, ctx: FlyteContext, task_name: str, input_literal_map: _literal_models.LiteralMap, cache_version: str
) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
    """
    Forward the call to 'dispatch_execute' and hand back whatever it returns.

    Only 'ctx' and 'input_literal_map' are passed on. 'task_name' and 'cache_version' are never read in the
    body: they are part of the signature so that a caller wrapping this method with the local cache gets them
    included in the cache key (see the comments in '_local_execute').
    """
    outputs = self.dispatch_execute(ctx, input_literal_map)
    return outputs

def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
"""
This code is used only in the case when we want to dispatch_execute with outputs from a previous node
Expand All @@ -235,7 +244,17 @@ def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], P
)
input_literal_map = _literal_models.LiteralMap(literals=kwargs)

outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
# if metadata.cache is set, check memoized version
if self._metadata.cache:
# The cache key is composed of '(task name, input_literal_map, cache_version)', i.e. all other parameters
# passed to the call to 'dispatch_execute' are ignored
dispatch_execute_func = LocalCache.cache(self._dispatch_execute, ignore=["self", "ctx"])
else:
dispatch_execute_func = self._dispatch_execute
# The local cache uses the function signature (and an ignore list) to calculate the cache key. In other
# words, we need the cache version to be present in the function signature so that we can respect the current
# cache semantics where changing the cache version of a cached Task creates a separate entry in the cache.
outputs_literal_map = dispatch_execute_func(ctx, self.name, input_literal_map, self._metadata.cache_version)
outputs_literals = outputs_literal_map.literals

# TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
Expand Down
31 changes: 31 additions & 0 deletions flytekit/core/local_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Callable, List, Optional

from joblib import Memory

# Location in the file system where serialized objects will be stored.
# Passed as the first argument to joblib's Memory in LocalCache.initialize().
# TODO: read from config
CACHE_LOCATION = "~/.flyte/local-cache"
# Value passed as the `verbose` argument of joblib's Memory in LocalCache.initialize().
# TODO: read from config
CACHE_VERBOSITY = 5


class LocalCache(object):
    """
    Process-wide holder of a single joblib ``Memory`` used to memoize calls locally.

    All state lives on the class itself; the ``Memory`` object is built lazily the first time
    ``cache`` or ``clear`` is called (or eagerly through ``initialize``).
    """

    # Declared here for type checkers only; it gets a value in initialize().
    _memory: Memory
    # Flips to True once initialize() has run.
    _initialized: bool = False

    @staticmethod
    def initialize():
        """
        Build a new ``Memory`` from CACHE_LOCATION / CACHE_VERBOSITY and mark the cache as ready.
        """
        memory = Memory(CACHE_LOCATION, verbose=CACHE_VERBOSITY)
        LocalCache._memory = memory
        LocalCache._initialized = True

    @staticmethod
    def _ready_memory() -> Memory:
        """
        Return the shared ``Memory``, calling ``initialize`` first if that has not happened yet.
        """
        if not LocalCache._initialized:
            LocalCache.initialize()
        return LocalCache._memory

    @staticmethod
    def cache(func: Callable, ignore: Optional[List[str]] = None):
        """
        Wrap ``func`` with the shared ``Memory``.

        :param func: callable to memoize
        :param ignore: argument names forwarded as the ``ignore`` option of ``Memory.cache``
        :return: whatever ``Memory.cache`` returns for ``func``
        """
        memory = LocalCache._ready_memory()
        return memory.cache(func, ignore=ignore)

    @staticmethod
    def clear():
        """
        Call ``clear()`` on the shared ``Memory``, initializing it first when needed.
        """
        memory = LocalCache._ready_memory()
        memory.clear()
12 changes: 7 additions & 5 deletions requirements-spark2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ black==21.7b0
# via papermill
bleach==4.0.0
# via nbconvert
boto3==1.18.15
boto3==1.18.19
# via sagemaker-training
botocore==1.21.15
botocore==1.21.19
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -78,7 +78,7 @@ flyteidl==0.19.19
# via flytekit
gevent==21.8.0
# via sagemaker-training
greenlet==1.1.0
greenlet==1.1.1
# via gevent
grpcio==1.39.0
# via flytekit
Expand Down Expand Up @@ -106,6 +106,8 @@ jmespath==0.10.0
# via
# boto3
# botocore
joblib==1.0.1
# via flytekit
jsonschema==3.2.0
# via nbformat
jupyter-client==6.1.12
Expand Down Expand Up @@ -245,7 +247,7 @@ requests==2.26.0
# flytekit
# papermill
# responses
responses==0.13.3
responses==0.13.4
# via flytekit
retry==0.9.2
# via flytekit
Expand Down Expand Up @@ -322,7 +324,7 @@ webencodings==0.5.1
# via bleach
werkzeug==2.0.1
# via sagemaker-training
wheel==0.36.2
wheel==0.37.0
# via flytekit
wrapt==1.12.1
# via
Expand Down
12 changes: 7 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ black==21.7b0
# via papermill
bleach==4.0.0
# via nbconvert
boto3==1.18.15
boto3==1.18.19
# via sagemaker-training
botocore==1.21.15
botocore==1.21.19
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -78,7 +78,7 @@ flyteidl==0.19.19
# via flytekit
gevent==21.8.0
# via sagemaker-training
greenlet==1.1.0
greenlet==1.1.1
# via gevent
grpcio==1.39.0
# via flytekit
Expand Down Expand Up @@ -106,6 +106,8 @@ jmespath==0.10.0
# via
# boto3
# botocore
joblib==1.0.1
# via flytekit
jsonschema==3.2.0
# via nbformat
jupyter-client==6.1.12
Expand Down Expand Up @@ -245,7 +247,7 @@ requests==2.26.0
# flytekit
# papermill
# responses
responses==0.13.3
responses==0.13.4
# via flytekit
retry==0.9.2
# via flytekit
Expand Down Expand Up @@ -322,7 +324,7 @@ webencodings==0.5.1
# via bleach
werkzeug==2.0.1
# via sagemaker-training
wheel==0.36.2
wheel==0.37.0
# via flytekit
wrapt==1.12.1
# via
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
"docker-image-py>=0.1.10",
"singledispatchmethod; python_version < '3.8.0'",
"docstring-parser>=0.9.0",
"joblib>=1.0.0",
],
extras_require=extras_require,
scripts=[
Expand Down
Loading

0 comments on commit 86d3368

Please sign in to comment.