Skip to content

Commit

Permalink
New content for map task with multiple inputs (flyteorg#822)
Browse files Browse the repository at this point in the history
* New content for map task with multiple inputs

Signed-off-by: Marc Paquette <[email protected]>

* Typo

Co-authored-by: Samhita Alla <[email protected]>

Signed-off-by: Marc Paquette <[email protected]>
  • Loading branch information
marc-union authored Jun 30, 2022
1 parent 1bd7f6b commit 475af3e
Showing 1 changed file with 83 additions and 5 deletions.
88 changes: 83 additions & 5 deletions cookbook/core/control_flow/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# %%
# First, import the libraries.
import typing
from typing import List

from flytekit import Resources, map_task, task, workflow

Expand All @@ -36,19 +36,19 @@ def a_mappable_task(a: int) -> str:
# %%
# Also define a task to reduce the mapped output to a string.
@task
def coalesce(b: typing.List[str]) -> str:
def coalesce(b: List[str]) -> str:
coalesced = "".join(b)
return coalesced


# %%
# We send ``a_mappable_task`` to be repeated across a collection of inputs to the :py:func:`~flytekit:flytekit.map_task` function.
# In the example, ``a`` of type ``typing.List[int]`` is the input.
# In the example, ``a`` of type ``List[int]`` is the input.
# The task ``a_mappable_task`` is run for each element in the list.
#
# ``with_overrides`` is useful to set resources for individual map task.
@workflow
def my_map_workflow(a: typing.List[int]) -> str:
def my_map_workflow(a: List[int]) -> str:
mapped_out = map_task(a_mappable_task)(a=a).with_overrides(
requests=Resources(mem="300Mi"),
limits=Resources(mem="500Mi"),
Expand All @@ -65,5 +65,83 @@ def my_map_workflow(a: typing.List[int]) -> str:
print(f"{result}")

# %%
# By default, the map task uses the K8s Array plugin. Map tasks can also run on alternate execution backends, such as `AWS Batch <https://docs.flyte.org/en/latest/deployment/plugin_setup/aws/batch.html#deployment-plugin-setup-aws-array>`__,
# When defining a map task, avoid calling other tasks in it. Flyte
# can't accurately register tasks that call other tasks. While Flyte
# will correctly execute a task that calls other tasks, it will not be
# able to give full performance advantages. This is
# especially true for map tasks.
#
# In this example, the map task ``suboptimal_mappable_task`` would not
# give you the best performance.
@task
def upperhalf(a: int) -> int:
return a / 2 + 1

@task
def suboptimal_mappable_task(a: int) -> str:
inc = upperhalf(a=a)
stringified = str(inc)
return stringified


# %%
#
# By default, the map task uses the K8s Array plugin. Map tasks can
# also run on alternate execution backends, such as
# `AWS Batch <https://docs.flyte.org/en/latest/deployment/plugin_setup/aws/batch.html#deployment-plugin-setup-aws-array>`__,
# a provisioned service that can scale to great sizes.


# %%
# Map a Task with Multiple Inputs
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You might need to map a task with multiple inputs.
#
# For example, we have a task that takes 3 inputs.
@task
def multi_input_task(quantity: int, price: float, shipping: float) -> float:
return quantity * price * shipping

# %%
# But we only want to map this task with the ``quantity`` input
# while the other inputs stay the same. Since a map task accepts only
# one input, we can do this by creating a new task that prepares the
# map task's inputs.
#
# We start by putting the inputs in a Dataclass and
# ``dataclass_json``. We also define our helper task to prepare the map
# task's inputs.
from dataclasses import dataclass
from dataclasses_json import dataclass_json

@dataclass_json
@dataclass
class MapInput:
quantity: float
price: float
shipping: float

@task
def prepare_map_inputs(list_q: List[float], p: float, s: float) -> List[MapInput]:
return [MapInput(q, p, s) for q in list_q]

# %%
# Then we refactor ``multi_input_task``. Instead of 3 inputs, ``mappable_task``
# has a single input.
@task
def mappable_task(input: MapInput) -> float:
return input.quantity * input.price * input.shipping

# %%
# Our workflow prepares a new list of inputs for the map task.
@workflow
def multiple_workflow(list_q: List[float], p: float, s: float) -> List[float]:
prepared = prepare_map_inputs(list_q=list_q, p=p, s=s)
return map_task(mappable_task)(input=prepared)

# %%
# We can run our multi-input map task locally.
if __name__ == "__main__":
result = multiple_workflow(list_q=[1.0, 2.0, 3.0, 4.0, 5.0], p=6.0, s=7.0)
print(f"{result}")

0 comments on commit 475af3e

Please sign in to comment.