Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add caching to local execution #592

Merged
merged 27 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6b58da3
Add initial structure of local cache using joblib
Aug 12, 2021
e8012e8
Add bogus unit tests
eapolinario Aug 12, 2021
0529f7c
Add clear-cache make target
eapolinario Aug 12, 2021
1cc9bdb
Run `make requirements`
eapolinario Aug 12, 2021
fa951a5
Remove a few TODO
eapolinario Aug 12, 2021
5b137a8
Linting
eapolinario Aug 12, 2021
ce947a2
More linting
eapolinario Aug 12, 2021
ec9fdfe
Define LocalCache and replace uses of it in base_task.py
eapolinario Aug 12, 2021
8dfdd1e
LocalCache type hints
eapolinario Aug 12, 2021
790284e
Comment use of LocalCache in base_task.py
eapolinario Aug 12, 2021
ffc4f52
Remove use of Optional from declaration of _memory
eapolinario Aug 12, 2021
9a486e1
Move comment closer to invocation of 'dispatch_execute_func'
eapolinario Aug 12, 2021
44b4858
Use global counter to validate cache hits
eapolinario Aug 13, 2021
35d4a2b
Use ~/.flyte/local-cache as the default location
eapolinario Aug 13, 2021
01ca49c
Force initialization in LocalCache.clear()
eapolinario Aug 13, 2021
ffb66a0
Add pyflyte local-cache command
eapolinario Aug 13, 2021
823dffc
Remove clear-cache make target
eapolinario Aug 13, 2021
9deff96
Linting.
eapolinario Aug 13, 2021
f0c8aed
More linting
eapolinario Aug 13, 2021
f9fd960
Add more tests
eapolinario Aug 16, 2021
a00ead1
Add constant for cache verbosity
eapolinario Aug 17, 2021
130b432
Add task name to cache key definition
eapolinario Aug 17, 2021
5c991e3
Add tests containing complex inputs and outputs
eapolinario Aug 17, 2021
e3170b0
Linting
eapolinario Aug 17, 2021
5b83e76
Comment the test used to confirm the use of task names in cache keys
eapolinario Aug 17, 2021
045cc3b
More linting
eapolinario Aug 17, 2021
9cbc70a
Fix linting
eapolinario Aug 17, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,8 @@ update_version:

grep "$(PLACEHOLDER)" "setup.py"
sed -i "s/$(PLACEHOLDER)/__version__ = \"${VERSION}\"/g" "setup.py"

# TODO
.PHONY: clear-cache
clear-cache:
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
# Use `-c` to execute the inline statements; `-m` would try to import the whole string as a module name and fail.
	python -c 'from flytekit.core.local_cache import LocalCache; LocalCache.clear()'
8 changes: 6 additions & 2 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ jinja2==3.0.1
# via
# -c requirements.txt
# pytest-flyte
joblib==1.0.1
# via
# -c requirements.txt
# flytekit
jsonschema==3.2.0
# via
# -c requirements.txt
Expand Down Expand Up @@ -274,7 +278,7 @@ requests==2.26.0
# docker-compose
# flytekit
# responses
responses==0.13.3
responses==0.13.4
# via
# -c requirements.txt
# flytekit
Expand Down Expand Up @@ -345,7 +349,7 @@ websocket-client==0.59.0
# via
# docker
# docker-compose
wheel==0.36.2
wheel==0.37.0
# via
# -c requirements.txt
# flytekit
Expand Down
14 changes: 8 additions & 6 deletions doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ black==21.7b0
# via papermill
bleach==4.0.0
# via nbconvert
boto3==1.18.15
boto3==1.18.19
# via sagemaker-training
botocore==1.21.15
botocore==1.21.19
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -96,7 +96,7 @@ git+git://github.com/flyteorg/furo@main
# via -r doc-requirements.in
gevent==21.8.0
# via sagemaker-training
greenlet==1.1.0
greenlet==1.1.1
# via gevent
grpcio==1.39.0
# via
Expand Down Expand Up @@ -131,6 +131,8 @@ jmespath==0.10.0
# via
# boto3
# botocore
joblib==1.0.1
# via flytekit
jsonschema==3.2.0
# via nbformat
jupyter-client==6.1.12
Expand Down Expand Up @@ -284,7 +286,7 @@ requests==2.26.0
# papermill
# responses
# sphinx
responses==0.13.3
responses==0.13.4
# via flytekit
retry==0.9.2
# via flytekit
Expand Down Expand Up @@ -343,7 +345,7 @@ sphinx-gallery==0.9.0
# via -r doc-requirements.in
sphinx-material==0.0.34
# via -r doc-requirements.in
sphinx-prompt==1.4.0
sphinx-prompt==1.5.0
# via -r doc-requirements.in
sphinxcontrib-applehelp==1.0.2
# via sphinx
Expand Down Expand Up @@ -409,7 +411,7 @@ webencodings==0.5.1
# via bleach
werkzeug==2.0.1
# via sagemaker-training
wheel==0.36.2
wheel==0.37.0
# via flytekit
wrapt==1.12.1
# via
Expand Down
22 changes: 21 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union

from joblib import Memory

from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.tasks.sdk_runnable import ExecutionParameters
from flytekit.core.context_manager import (
Expand Down Expand Up @@ -53,6 +55,10 @@
from flytekit.models.interface import Variable
from flytekit.models.security import SecurityContext

# TODO: move the definition of `memory` to a separate file
# Directory handed to joblib's `Memory` below as its storage location.
# NOTE(review): the path is hard-coded under /tmp. It is presumably shared by every user and project on
# the machine and is not portable to platforms without /tmp; confirm whether a per-user location is wanted.
CACHE_LOCATION = "/tmp/cache-location"
# Module-level joblib `Memory` instance; `_local_execute` uses `memory.cache(...)` to memoize tasks whose
# metadata enables caching. `verbose=0` is presumably meant to silence joblib's output (confirm in joblib docs).
memory = Memory(CACHE_LOCATION, verbose=0)


def kwtypes(**kwargs) -> Dict[str, Type]:
"""
Expand Down Expand Up @@ -217,6 +223,14 @@ def get_input_types(self) -> Dict[str, type]:
"""
return None

def _local_dispatch_execute(
    self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap, cache_version: str
) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
    """
    Delegate to ``dispatch_execute``, accepting an extra ``cache_version`` argument.

    ``cache_version`` is deliberately not used in the body. ``_local_execute`` may wrap this method with
    ``memory.cache(..., ignore=["self", "ctx"])`` and calls it with ``self._metadata.cache_version``, so the
    version travels as one of the arguments of the memoized call alongside ``input_literal_map``.
    Presumably this is what makes a change of cache version produce a different cache entry (confirm
    against joblib's argument-hashing behaviour).

    :param ctx: context forwarded unchanged to ``dispatch_execute``.
    :param input_literal_map: task inputs forwarded unchanged to ``dispatch_execute``.
    :param cache_version: version string supplied by the caller; not read here.
    :return: whatever ``dispatch_execute`` returns for the given context and inputs.
    """
    return self.dispatch_execute(ctx, input_literal_map)

def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
"""
This code is used only in the case when we want to dispatch_execute with outputs from a previous node
Expand All @@ -236,7 +250,13 @@ def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], P
)
input_literal_map = _literal_models.LiteralMap(literals=kwargs)

outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
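# When the task's metadata enables caching, call the wrapper through joblib's memoized version of it.
# `self` and `ctx` are ignored by the memoization, leaving the input literals and cache_version as its arguments.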
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
if self._metadata.cache:
dispatch_execute = memory.cache(self._local_dispatch_execute, ignore=["self", "ctx"])
else:
dispatch_execute = self._local_dispatch_execute
outputs_literal_map = dispatch_execute(ctx, input_literal_map, self._metadata.cache_version)
outputs_literals = outputs_literal_map.literals

# TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
Expand Down
12 changes: 7 additions & 5 deletions requirements-spark2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ black==21.7b0
# via papermill
bleach==4.0.0
# via nbconvert
boto3==1.18.15
boto3==1.18.19
# via sagemaker-training
botocore==1.21.15
botocore==1.21.19
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -78,7 +78,7 @@ flyteidl==0.19.19
# via flytekit
gevent==21.8.0
# via sagemaker-training
greenlet==1.1.0
greenlet==1.1.1
# via gevent
grpcio==1.39.0
# via flytekit
Expand Down Expand Up @@ -106,6 +106,8 @@ jmespath==0.10.0
# via
# boto3
# botocore
joblib==1.0.1
# via flytekit
jsonschema==3.2.0
# via nbformat
jupyter-client==6.1.12
Expand Down Expand Up @@ -245,7 +247,7 @@ requests==2.26.0
# flytekit
# papermill
# responses
responses==0.13.3
responses==0.13.4
# via flytekit
retry==0.9.2
# via flytekit
Expand Down Expand Up @@ -322,7 +324,7 @@ webencodings==0.5.1
# via bleach
werkzeug==2.0.1
# via sagemaker-training
wheel==0.36.2
wheel==0.37.0
# via flytekit
wrapt==1.12.1
# via
Expand Down
12 changes: 7 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ black==21.7b0
# via papermill
bleach==4.0.0
# via nbconvert
boto3==1.18.15
boto3==1.18.19
# via sagemaker-training
botocore==1.21.15
botocore==1.21.19
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -78,7 +78,7 @@ flyteidl==0.19.19
# via flytekit
gevent==21.8.0
# via sagemaker-training
greenlet==1.1.0
greenlet==1.1.1
# via gevent
grpcio==1.39.0
# via flytekit
Expand Down Expand Up @@ -106,6 +106,8 @@ jmespath==0.10.0
# via
# boto3
# botocore
joblib==1.0.1
# via flytekit
jsonschema==3.2.0
# via nbformat
jupyter-client==6.1.12
Expand Down Expand Up @@ -245,7 +247,7 @@ requests==2.26.0
# flytekit
# papermill
# responses
responses==0.13.3
responses==0.13.4
# via flytekit
retry==0.9.2
# via flytekit
Expand Down Expand Up @@ -322,7 +324,7 @@ webencodings==0.5.1
# via bleach
werkzeug==2.0.1
# via sagemaker-training
wheel==0.36.2
wheel==0.37.0
# via flytekit
wrapt==1.12.1
# via
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
"docker-image-py>=0.1.10",
"singledispatchmethod; python_version < '3.8.0'",
"docstring-parser>=0.9.0",
"joblib>=1.0.0",
],
extras_require=extras_require,
scripts=[
Expand Down
48 changes: 48 additions & 0 deletions tests/flytekit/unit/core/test_local_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from flytekit.core.task import task
from flytekit.core.workflow import workflow


def test_single_task_workflow():
    """Run a workflow that combines a cached task with an uncached one and check its results."""

    @task(cache=True, cache_version="v1")
    def is_even(n: int) -> bool:
        from time import sleep

        # Artificial two-second delay before answering.
        sleep(2)
        return n % 2 == 0

    @task(cache=False)
    def uncached_task(a: int, b: int) -> int:
        return a + b

    @workflow
    def check_evenness(n: int) -> bool:
        uncached_task(a=n, b=n)
        return is_even(n=n)

    # (input, expected result) pairs, executed in this order.
    for value, expected in ((1, False), (8, True)):
        assert check_evenness(n=value) is expected


def test_shared_tasks_in_two_separate_workflows():
    """Call one cached task from two different workflows and check both return the same answers."""

    @task(cache=True, cache_version="0.0.1")
    def is_even(n: int) -> bool:
        from time import sleep

        # Artificial two-second delay before answering.
        sleep(2)
        return n % 2 == 0

    @workflow
    def check_evenness_wf1(n: int) -> bool:
        return is_even(n=n)

    @workflow
    def check_evenness_wf2(n: int) -> bool:
        return is_even(n=n)

    cases = ((42, True), (99, False))

    # The first workflow runs every case first. The executions of the *_wf2 workflow
    # that follow are expected to hit the cache for the calls to `is_even`.
    for wf in (check_evenness_wf1, check_evenness_wf2):
        for value, expected in cases:
            assert wf(n=value) is expected