diff --git a/cookbook/core/control_flow/map_task.py b/cookbook/core/control_flow/map_task.py index b822ea039f..a7ae4919a3 100644 --- a/cookbook/core/control_flow/map_task.py +++ b/cookbook/core/control_flow/map_task.py @@ -16,7 +16,7 @@ # %% # First, import the libraries. -import typing +from typing import List from flytekit import Resources, map_task, task, workflow @@ -36,19 +36,19 @@ def a_mappable_task(a: int) -> str: # %% # Also define a task to reduce the mapped output to a string. @task -def coalesce(b: typing.List[str]) -> str: +def coalesce(b: List[str]) -> str: coalesced = "".join(b) return coalesced # %% # We send ``a_mappable_task`` to be repeated across a collection of inputs to the :py:func:`~flytekit:flytekit.map_task` function. -# In the example, ``a`` of type ``typing.List[int]`` is the input. +# In the example, ``a`` of type ``List[int]`` is the input. # The task ``a_mappable_task`` is run for each element in the list. # # ``with_overrides`` is useful to set resources for individual map task. @workflow -def my_map_workflow(a: typing.List[int]) -> str: +def my_map_workflow(a: List[int]) -> str: mapped_out = map_task(a_mappable_task)(a=a).with_overrides( requests=Resources(mem="300Mi"), limits=Resources(mem="500Mi"), @@ -65,5 +65,83 @@ def my_map_workflow(a: typing.List[int]) -> str: print(f"{result}") # %% -# By default, the map task uses the K8s Array plugin. Map tasks can also run on alternate execution backends, such as `AWS Batch `__, +# When defining a map task, avoid calling other tasks in it. Flyte +# can't accurately register tasks that call other tasks. While Flyte +# will correctly execute a task that calls other tasks, it will not be +# able to give full performance advantages. This is +# especially true for map tasks. +# +# In this example, the map task ``suboptimal_mappable_task`` would not +# give you the best performance. +@task +def upperhalf(a: int) -> int: + return a / 2 + 1 + +@task +def suboptimal_mappable_task(a: int) -> str: + inc = upperhalf(a=a) + stringified = str(inc) + return stringified + + +# %% +# +# By default, the map task uses the K8s Array plugin. Map tasks can +# also run on alternate execution backends, such as +# `AWS Batch `__, # a provisioned service that can scale to great sizes. + + +# %% +# Map a Task with Multiple Inputs +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# You might need to map a task with multiple inputs. +# +# For example, we have a task that takes 3 inputs. +@task +def multi_input_task(quantity: int, price: float, shipping: float) -> float: + return quantity * price * shipping + +# %% +# But we only want to map this task with the ``quantity`` input +# while the other inputs stay the same. Since a map task accepts only +# one input, we can do this by creating a new task that prepares the +# map task's inputs. +# +# We start by putting the inputs in a Dataclass and +# ``dataclass_json``. We also define our helper task to prepare the map +# task's inputs. +from dataclasses import dataclass +from dataclasses_json import dataclass_json + +@dataclass_json +@dataclass +class MapInput: + quantity: float + price: float + shipping: float + +@task +def prepare_map_inputs(list_q: List[float], p: float, s: float) -> List[MapInput]: + return [MapInput(q, p, s) for q in list_q] + +# %% +# Then we refactor ``multi_input_task``. Instead of 3 inputs, ``mappable_task`` +# has a single input. +@task +def mappable_task(input: MapInput) -> float: + return input.quantity * input.price * input.shipping + +# %% +# Our workflow prepares a new list of inputs for the map task. +@workflow +def multiple_workflow(list_q: List[float], p: float, s: float) -> List[float]: + prepared = prepare_map_inputs(list_q=list_q, p=p, s=s) + return map_task(mappable_task)(input=prepared) + +# %% +# We can run our multi-input map task locally. +if __name__ == "__main__": + result = multiple_workflow(list_q=[1.0, 2.0, 3.0, 4.0, 5.0], p=6.0, s=7.0) + print(f"{result}")