diff --git a/dev-requirements.txt b/dev-requirements.txt
index ce280c8698..cf89ebdbb2 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -128,6 +128,10 @@ jinja2==3.0.1
     # via
     #   -c requirements.txt
     #   pytest-flyte
+joblib==1.0.1
+    # via
+    #   -c requirements.txt
+    #   flytekit
 jsonschema==3.2.0
     # via
     #   -c requirements.txt
@@ -274,7 +278,7 @@ requests==2.26.0
     # via
     #   docker-compose
     #   flytekit
     #   responses
-responses==0.13.3
+responses==0.13.4
     # via
     #   -c requirements.txt
     #   flytekit
@@ -345,7 +349,7 @@ websocket-client==0.59.0
     # via
     #   docker
     #   docker-compose
-wheel==0.36.2
+wheel==0.37.0
     # via
     #   -c requirements.txt
     #   flytekit
diff --git a/doc-requirements.txt b/doc-requirements.txt
index 17d10fc24d..bc8ec36354 100644
--- a/doc-requirements.txt
+++ b/doc-requirements.txt
@@ -39,9 +39,9 @@ black==21.7b0
     # via papermill
 bleach==4.0.0
     # via nbconvert
-boto3==1.18.15
+boto3==1.18.19
     # via sagemaker-training
-botocore==1.21.15
+botocore==1.21.19
     # via
     #   boto3
     #   s3transfer
@@ -96,7 +96,7 @@ git+git://github.com/flyteorg/furo@main
     # via -r doc-requirements.in
 gevent==21.8.0
     # via sagemaker-training
-greenlet==1.1.0
+greenlet==1.1.1
     # via gevent
 grpcio==1.39.0
     # via
@@ -131,6 +131,8 @@ jmespath==0.10.0
     # via
     #   boto3
     #   botocore
+joblib==1.0.1
+    # via flytekit
 jsonschema==3.2.0
     # via nbformat
 jupyter-client==6.1.12
@@ -284,7 +286,7 @@ requests==2.26.0
     # via
     #   papermill
     #   responses
     #   sphinx
-responses==0.13.3
+responses==0.13.4
     # via flytekit
 retry==0.9.2
     # via flytekit
@@ -343,7 +345,7 @@ sphinx-gallery==0.9.0
     # via -r doc-requirements.in
 sphinx-material==0.0.34
     # via -r doc-requirements.in
-sphinx-prompt==1.4.0
+sphinx-prompt==1.5.0
     # via -r doc-requirements.in
 sphinxcontrib-applehelp==1.0.2
     # via sphinx
@@ -409,7 +411,7 @@ webencodings==0.5.1
     # via bleach
 werkzeug==2.0.1
     # via sagemaker-training
-wheel==0.36.2
+wheel==0.37.0
     # via flytekit
 wrapt==1.12.1
     # via
diff --git a/flytekit/clis/sdk_in_container/local_cache.py b/flytekit/clis/sdk_in_container/local_cache.py
new file mode 100644
index 0000000000..80c29ad262
--- /dev/null
+++ b/flytekit/clis/sdk_in_container/local_cache.py
@@ -0,0 +1,22 @@
+import click
+
+from flytekit.core.local_cache import LocalCache
+
+
+@click.group("local-cache")
+def local_cache():
+    """
+    Interact with the local cache.
+    """
+    pass
+
+
+@click.command("clear")
+def clear_local_cache():
+    """
+    This command removes all stored objects from the local cache.
+    """
+    LocalCache.clear()
+
+
+local_cache.add_command(clear_local_cache)
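The group is registered on the `pyflyte` entry point below, so the cache can be wiped from a shell with `pyflyte local-cache clear`. A sketch (not part of the diff) of driving the same command programmatically with click's test runner, assuming the `pyflyte` group callback is satisfied with its default configuration:

    from click.testing import CliRunner

    from flytekit.clis.sdk_in_container.pyflyte import main

    # Equivalent to running "pyflyte local-cache clear" in a shell
    result = CliRunner().invoke(main, ["local-cache", "clear"])
    assert result.exit_code == 0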
+ """ + LocalCache.clear() + + +local_cache.add_command(clear_local_cache) diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index 29f3d201a0..676b1bc675 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -7,6 +7,7 @@ from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES from flytekit.clis.sdk_in_container.fast_register import fast_register from flytekit.clis.sdk_in_container.launch_plan import launch_plans +from flytekit.clis.sdk_in_container.local_cache import local_cache from flytekit.clis.sdk_in_container.package import package from flytekit.clis.sdk_in_container.register import register from flytekit.clis.sdk_in_container.serialize import serialize @@ -110,6 +111,7 @@ def update_configuration_file(config_file_path): main.add_command(serialize) main.add_command(launch_plans) main.add_command(package) +main.add_command(local_cache) if __name__ == "__main__": main() diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index fa1726fbf2..432746f10a 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -34,6 +34,7 @@ SerializationSettings, ) from flytekit.core.interface import Interface, transform_interface_to_typed_interface +from flytekit.core.local_cache import LocalCache from flytekit.core.promise import ( Promise, VoidPromise, @@ -216,6 +217,14 @@ def get_input_types(self) -> Dict[str, type]: """ return None + def _dispatch_execute( + self, ctx: FlyteContext, task_name: str, input_literal_map: _literal_models.LiteralMap, cache_version: str + ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]: + """ + Thin wrapper around the actual call to 'dispatch_execute'. + """ + return self.dispatch_execute(ctx, input_literal_map) + def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: """ This code is used only in the case when we want to dispatch_execute with outputs from a previous node @@ -235,7 +244,17 @@ def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], P ) input_literal_map = _literal_models.LiteralMap(literals=kwargs) - outputs_literal_map = self.dispatch_execute(ctx, input_literal_map) + # if metadata.cache is set, check memoized version + if self._metadata.cache: + # The cache key is composed of '(task name, input_literal_map, cache_version)', i.e. all other parameters + # passed to the call to 'dispatch_execute' are ignored + dispatch_execute_func = LocalCache.cache(self._dispatch_execute, ignore=["self", "ctx"]) + else: + dispatch_execute_func = self._dispatch_execute + # The local cache uses the function signature (and an ignore list) to calculate the cache key. In other + # words, we need the cache version to be present in the function signature so that we can respect the current + # cache semantics where changing the cache version of a cached Task creates a separate entry in the cache. 
diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py
new file mode 100644
index 0000000000..e893a9edf2
--- /dev/null
+++ b/flytekit/core/local_cache.py
@@ -0,0 +1,31 @@
+from typing import Callable, List, Optional
+
+from joblib import Memory
+
+# Location in the file system where serialized objects will be stored
+# TODO: read from config
+CACHE_LOCATION = "~/.flyte/local-cache"
+# TODO: read from config
+CACHE_VERBOSITY = 5
+
+
+class LocalCache(object):
+    _memory: Memory
+    _initialized: bool = False
+
+    @staticmethod
+    def initialize():
+        LocalCache._memory = Memory(CACHE_LOCATION, verbose=CACHE_VERBOSITY)
+        LocalCache._initialized = True
+
+    @staticmethod
+    def cache(func: Callable, ignore: Optional[List[str]] = None):
+        if not LocalCache._initialized:
+            LocalCache.initialize()
+        return LocalCache._memory.cache(func, ignore=ignore)
+
+    @staticmethod
+    def clear():
+        if not LocalCache._initialized:
+            LocalCache.initialize()
+        LocalCache._memory.clear()
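To see the wrapper in isolation, here is a hypothetical direct use of `LocalCache` (illustration only; in flytekit it is driven through `_local_execute` above, and `add` is a stand-in for an expensive function):

    from flytekit.core.local_cache import LocalCache

    def add(ctx: object, a: int, b: int) -> int:
        return a + b  # stand-in for an expensive computation

    # Leave 'ctx' out of the key, as base_task.py does for 'self' and 'ctx'
    cached_add = LocalCache.cache(add, ignore=["ctx"])

    print(cached_add(None, 1, 2))  # computed once, persisted under CACHE_LOCATION
    print(cached_add(None, 1, 2))  # answered from disk without re-running 'add'
    LocalCache.clear()             # what 'pyflyte local-cache clear' ends up calling

The class-level `_memory` plus lazy `initialize()` means one `Memory` instance is shared per process, so the CLI path and the task-execution path operate on the same store.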
diff --git a/requirements-spark2.txt b/requirements-spark2.txt
index db2ef56b9d..d96dcbe0f4 100644
--- a/requirements-spark2.txt
+++ b/requirements-spark2.txt
@@ -29,9 +29,9 @@ black==21.7b0
     # via papermill
 bleach==4.0.0
     # via nbconvert
-boto3==1.18.15
+boto3==1.18.19
     # via sagemaker-training
-botocore==1.21.15
+botocore==1.21.19
     # via
     #   boto3
     #   s3transfer
@@ -78,7 +78,7 @@ flyteidl==0.19.19
     # via flytekit
 gevent==21.8.0
     # via sagemaker-training
-greenlet==1.1.0
+greenlet==1.1.1
     # via gevent
 grpcio==1.39.0
     # via flytekit
@@ -106,6 +106,8 @@ jmespath==0.10.0
     # via
     #   boto3
     #   botocore
+joblib==1.0.1
+    # via flytekit
 jsonschema==3.2.0
     # via nbformat
 jupyter-client==6.1.12
@@ -245,7 +247,7 @@ requests==2.26.0
     # via
     #   flytekit
     #   papermill
     #   responses
-responses==0.13.3
+responses==0.13.4
     # via flytekit
 retry==0.9.2
     # via flytekit
@@ -322,7 +324,7 @@ webencodings==0.5.1
     # via bleach
 werkzeug==2.0.1
     # via sagemaker-training
-wheel==0.36.2
+wheel==0.37.0
     # via flytekit
 wrapt==1.12.1
     # via
diff --git a/requirements.txt b/requirements.txt
index 0658567035..5fd41a1fc4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -29,9 +29,9 @@ black==21.7b0
     # via papermill
 bleach==4.0.0
     # via nbconvert
-boto3==1.18.15
+boto3==1.18.19
     # via sagemaker-training
-botocore==1.21.15
+botocore==1.21.19
     # via
     #   boto3
     #   s3transfer
@@ -78,7 +78,7 @@ flyteidl==0.19.19
     # via flytekit
 gevent==21.8.0
     # via sagemaker-training
-greenlet==1.1.0
+greenlet==1.1.1
     # via gevent
 grpcio==1.39.0
     # via flytekit
@@ -106,6 +106,8 @@ jmespath==0.10.0
     # via
     #   boto3
     #   botocore
+joblib==1.0.1
+    # via flytekit
 jsonschema==3.2.0
     # via nbformat
 jupyter-client==6.1.12
@@ -245,7 +247,7 @@ requests==2.26.0
     # via
     #   flytekit
     #   papermill
     #   responses
-responses==0.13.3
+responses==0.13.4
     # via flytekit
 retry==0.9.2
     # via flytekit
@@ -322,7 +324,7 @@ webencodings==0.5.1
     # via bleach
 werkzeug==2.0.1
     # via sagemaker-training
-wheel==0.36.2
+wheel==0.37.0
     # via flytekit
 wrapt==1.12.1
     # via
diff --git a/setup.py b/setup.py
index fb2f2565d0..ec384cbdc8 100644
--- a/setup.py
+++ b/setup.py
@@ -93,6 +93,7 @@
         "docker-image-py>=0.1.10",
         "singledispatchmethod; python_version < '3.8.0'",
         "docstring-parser>=0.9.0",
+        "joblib>=1.0.0",
     ],
     extras_require=extras_require,
     scripts=[
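With the dependency declared, local caching is switched on purely through the existing task metadata, which the new test module below exercises end to end. As a quick orientation, a user-facing sketch (decorator arguments mirror the tests that follow; `expensive` and `uncached` are hypothetical tasks):

    from flytekit import task

    @task(cache=True, cache_version="1.0")
    def expensive(n: int) -> int:
        # With cache=True, a local execution runs this body at most once per
        # distinct (task name, inputs, cache_version) combination.
        return n * n

    @task  # cache defaults to False: every local call re-executes the body
    def uncached(n: int) -> int:
        return n * n

Bumping `cache_version` (say, to "2.0") creates a separate cache entry, forcing re-execution even for inputs that were seen before.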
diff --git a/tests/flytekit/unit/core/test_local_cache.py b/tests/flytekit/unit/core/test_local_cache.py
new file mode 100644
index 0000000000..23e9c52de2
--- /dev/null
+++ b/tests/flytekit/unit/core/test_local_cache.py
@@ -0,0 +1,252 @@
+import datetime
+import typing
+from dataclasses import dataclass
+
+import pandas
+from dataclasses_json import dataclass_json
+from pytest import fixture
+
+from flytekit import SQLTask, kwtypes
+from flytekit.core.local_cache import LocalCache
+from flytekit.core.task import TaskMetadata, task
+from flytekit.core.testing import task_mock
+from flytekit.core.workflow import workflow
+from flytekit.types.schema import FlyteSchema
+
+# Global counter used to validate the number of calls to cached tasks
+n_cached_task_calls = 0
+
+
+@fixture(scope="function", autouse=True)
+def setup():
+    global n_cached_task_calls
+    n_cached_task_calls = 0
+
+    LocalCache.initialize()
+    LocalCache.clear()
+
+
+def test_to_confirm_that_cache_keys_include_function_name():
+    """
+    This test confirms that the function name is part of the cache key. It does so by defining two tasks with
+    identical inputs and metadata (i.e. cache=True and the same cache_version).
+    """
+
+    @task(cache=True, cache_version="v1")
+    def f1(n: int) -> int:
+        global n_cached_task_calls
+        n_cached_task_calls += 1
+
+        return n
+
+    @task(cache=True, cache_version="v1")
+    def f2(n: int) -> int:
+        global n_cached_task_calls
+        n_cached_task_calls += 1
+
+        return n + 1
+
+    @workflow
+    def wf(n: int) -> (int, int):
+        n_f1 = f1(n=n)
+        n_f2 = f2(n=n)
+        return n_f1, n_f2
+
+    # Despite identical inputs and cache settings, f1 and f2 do not share cache entries.
+    assert wf(n=1) == (1, 2)
+
+
+def test_single_task_workflow():
+    @task(cache=True, cache_version="v1")
+    def is_even(n: int) -> bool:
+        global n_cached_task_calls
+        n_cached_task_calls += 1
+        return n % 2 == 0
+
+    @task(cache=False)
+    def uncached_task(a: int, b: int) -> int:
+        return a + b
+
+    @workflow
+    def check_evenness(n: int) -> bool:
+        uncached_task(a=n, b=n)
+        return is_even(n=n)
+
+    assert n_cached_task_calls == 0
+    assert check_evenness(n=1) is False
+    # Confirm the task is called
+    assert n_cached_task_calls == 1
+    assert check_evenness(n=1) is False
+    # Subsequent calls of the workflow with the same parameter do not bump the counter
+    assert n_cached_task_calls == 1
+    assert check_evenness(n=1) is False
+    assert n_cached_task_calls == 1
+
+    # Run the workflow with a different parameter and confirm the counter is bumped
+    assert check_evenness(n=8) is True
+    assert n_cached_task_calls == 2
+    # Run the workflow again with the same parameter and confirm the counter is not bumped
+    assert check_evenness(n=8) is True
+    assert n_cached_task_calls == 2
+
+
+def test_shared_tasks_in_two_separate_workflows():
+    @task(cache=True, cache_version="0.0.1")
+    def is_odd(n: int) -> bool:
+        global n_cached_task_calls
+        n_cached_task_calls += 1
+        return n % 2 == 1
+
+    @workflow
+    def check_oddness_wf1(n: int) -> bool:
+        return is_odd(n=n)
+
+    @workflow
+    def check_oddness_wf2(n: int) -> bool:
+        return is_odd(n=n)
+
+    assert n_cached_task_calls == 0
+    assert check_oddness_wf1(n=42) is False
+    assert check_oddness_wf1(n=99) is True
+    assert n_cached_task_calls == 2
+
+    # The next two executions of the *_wf2 workflow are going to
+    # hit the cache for the calls to `is_odd`
+    assert check_oddness_wf2(n=42) is False
+    assert check_oddness_wf2(n=99) is True
+    assert n_cached_task_calls == 2
+
+
+# TODO add test with typing.List[str]
+
+
+def test_sql_task():
+    sql = SQLTask(
+        "my-query",
+        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
+        inputs=kwtypes(ds=datetime.datetime),
+        outputs=kwtypes(results=FlyteSchema),
+        metadata=TaskMetadata(retries=2, cache=True, cache_version="0.1"),
+    )
+
+    @task(cache=True, cache_version="0.1.2")
+    def t1() -> datetime.datetime:
+        global n_cached_task_calls
+        n_cached_task_calls += 1
+        return datetime.datetime.now()
+
+    @workflow
+    def my_wf() -> FlyteSchema:
+        dt = t1()
+        return sql(ds=dt)
+
+    with task_mock(sql) as mock:
+        mock.return_value = pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})
+        assert n_cached_task_calls == 0
+        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()
+        assert n_cached_task_calls == 1
+        # The second and third calls hit the cache
+        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()
+        assert n_cached_task_calls == 1
+        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()
+        assert n_cached_task_calls == 1
+
+
+def test_wf_custom_types():
+    @dataclass_json
+    @dataclass
+    class MyCustomType(object):
+        x: int
+        y: str
+
+    @task(cache=True, cache_version="a.b.c")
+    def t1(a: int) -> MyCustomType:
+        global n_cached_task_calls
+        n_cached_task_calls += 1
+        return MyCustomType(x=a, y="t1")
+
+    @task(cache=True, cache_version="v1")
+    def t2(a: MyCustomType, b: str) -> (MyCustomType, int):
+        global n_cached_task_calls
+        n_cached_task_calls += 1
+        return MyCustomType(x=a.x, y=f"{a.y} {b}"), 5
+
+    @workflow
+    def my_wf(a: int, b: str) -> (MyCustomType, int):
+        return t2(a=t1(a=a), b=b)
+
+    assert n_cached_task_calls == 0
+    c, v = my_wf(a=10, b="hello")
+    assert v == 5
+    assert c.x == 10
+    assert c.y == "t1 hello"
+    assert n_cached_task_calls == 2
+    c, v = my_wf(a=10, b="hello")
+    assert v == 5
+    assert c.x == 10
+    assert c.y == "t1 hello"
+    assert n_cached_task_calls == 2
+
+
+def test_wf_schema_to_df():
+    schema1 = FlyteSchema[kwtypes(x=int, y=str)]
+
+    @task(cache=True, cache_version="v0")
+    def t1() -> schema1:
+        global n_cached_task_calls
+        n_cached_task_calls += 1
+
+        s = schema1()
+        s.open().write(pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]}))
+        return s
+
+    @task(cache=True, cache_version="v1")
+    def t2(df: pandas.DataFrame) -> int:
+        global n_cached_task_calls
+        n_cached_task_calls += 1
+
+        return len(df.columns.values)
+
+    @workflow
+    def wf() -> int:
+        return t2(df=t1())
+
+    assert n_cached_task_calls == 0
+    x = wf()
+    assert x == 2
+    assert n_cached_task_calls == 2
+    # A second call does not bump the counter
+    x = wf()
+    assert x == 2
+    assert n_cached_task_calls == 2
+
+
+def test_dict_wf_with_constants():
+    @task(cache=True, cache_version="v99")
+    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
+        global n_cached_task_calls
+        n_cached_task_calls += 1
+
+        return a + 2, "world"
+
+    @task(cache=True, cache_version="v101")
+    def t2(a: typing.Dict[str, str]) -> str:
+        global n_cached_task_calls
+        n_cached_task_calls += 1
+
+        return " ".join([v for k, v in a.items()])
+
+    @workflow
+    def my_wf(a: int, b: str) -> (int, str):
+        x, y = t1(a=a)
+        d = t2(a={"key1": b, "key2": y})
+        return x, d
+
+    assert n_cached_task_calls == 0
+    x = my_wf(a=5, b="hello")
+    assert x == (7, "hello world")
+    assert n_cached_task_calls == 2
+    # A second call does not bump the counter
+    x = my_wf(a=5, b="hello")
+    assert x == (7, "hello world")
+    assert n_cached_task_calls == 2