Skip to content

Commit

Permalink
Merge branch 'master' into pydantic-plugin-v2-with-msgpack
Browse files Browse the repository at this point in the history
  • Loading branch information
Future-Outlier committed Oct 25, 2024
2 parents 7735352 + 57f583e commit 959f02b
Show file tree
Hide file tree
Showing 17 changed files with 375 additions and 192 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ You can find the detailed contribution guide [here](https://docs.flyte.org/en/la
Please see the [contributor's guide](https://docs.flyte.org/en/latest/api/flytekit/contributing.html) for a quick summary of how this code is structured.

## 🐞 File an Issue
Refer to the [issues](https://docs.flyte.org/en/latest/community/contribute.html#file-an-issue) section in the contribution guide if you'd like to file an issue.
Refer to the [issues](https://github.com/flyteorg/flyte/issues) page on GitHub if you'd like to file an issue.

## 🔌 Flytekit Plugins
Refer to [plugins/README.md](plugins/README.md) for a list of available plugins.
Expand Down
15 changes: 1 addition & 14 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast

import typing_extensions
from flyteidl.core import tasks_pb2

from flytekit.configuration import SerializationSettings
Expand All @@ -18,7 +17,7 @@
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.type_engine import TypeEngine, TypeTransformer, is_annotated
from flytekit.core.type_engine import TypeEngine
from flytekit.core.utils import timeit
from flytekit.loggers import logger
from flytekit.models import literals as _literal_models
Expand All @@ -27,8 +26,6 @@
from flytekit.models.interface import Variable
from flytekit.models.task import Container, K8sPod, Sql, Task
from flytekit.tools.module_loader import load_object_from_module
from flytekit.types.pickle import pickle
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.utils.asyn import loop_manager

if TYPE_CHECKING:
Expand Down Expand Up @@ -77,16 +74,6 @@ def __init__(
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager) and PythonInstanceTask are supported in map tasks."
)

for k, v in actual_task.python_interface.inputs.items():
if bound_inputs and k in bound_inputs:
continue
transformer: TypeTransformer = TypeEngine.get_transformer(v)
if isinstance(transformer, FlytePickleTransformer):
if is_annotated(v):
for annotation in typing_extensions.get_args(v)[1:]:
if isinstance(annotation, pickle.BatchSize):
raise ValueError("Choosing a BatchSize for map tasks inputs is not supported.")

n_outputs = len(actual_task.python_interface.outputs)
if n_outputs > 1:
raise ValueError("Only tasks with a single output are supported in map tasks.")
Expand Down
71 changes: 20 additions & 51 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,18 @@ def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.
def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> typing.Any:
"""
If any field inside the dataclass is flyte type, we should use flyte type transformer for that field.
Since Flyte types are already serializable, this function is intended for using strings instead of directly creating Flyte files and directories in the dataclass.
An example shows the lifecycle:
@dataclass
class DC:
ff: FlyteFile
@task
def t1() -> DC:
return DC(ff="s3://path")
Lifecycle: DC(ff="s3://path") -> to_literal() -> DC(ff=FlyteFile(path="s3://path")) -> msgpack -> to_python_value() -> DC(ff=FlyteFile(path="s3://path"))
"""
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
Expand Down Expand Up @@ -1592,49 +1604,15 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
except Exception as e:
raise ValueError(f"Type of Generic List type is not supported, {e}")

@staticmethod
def is_batchable(t: Type):
"""
This function evaluates whether the provided type is batchable or not.
It returns True only if the type is either List or Annotated(List) and the List subtype is FlytePickle.
"""
from flytekit.types.pickle import FlytePickle

if is_annotated(t):
return ListTransformer.is_batchable(get_args(t)[0])
if get_origin(t) is list:
subtype = get_args(t)[0]
if subtype == FlytePickle or (hasattr(subtype, "__origin__") and subtype.__origin__ == FlytePickle):
return True
return False

async def async_to_literal(
self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType
) -> Literal:
if type(python_val) != list:
raise TypeTransformerFailedError("Expected a list")

if ListTransformer.is_batchable(python_type):
from flytekit.types.pickle.pickle import BatchSize, FlytePickle

batch_size = len(python_val) # default batch size
# parse annotated to get the number of items saved in a pickle file.
if is_annotated(python_type):
for annotation in get_args(python_type)[1:]:
if isinstance(annotation, BatchSize):
batch_size = annotation.val
break
if batch_size > 0:
lit_list = [
TypeEngine.to_literal(ctx, python_val[i : i + batch_size], FlytePickle, expected.collection_type)
for i in range(0, len(python_val), batch_size)
] # type: ignore
else:
lit_list = []
else:
t = self.get_sub_type(python_type)
lit_list = [TypeEngine.async_to_literal(ctx, x, t, expected.collection_type) for x in python_val]
lit_list = await asyncio.gather(*lit_list)
t = self.get_sub_type(python_type)
lit_list = [TypeEngine.async_to_literal(ctx, x, t, expected.collection_type) for x in python_val]
lit_list = await asyncio.gather(*lit_list)

return Literal(collection=LiteralCollection(literals=lit_list))

Expand All @@ -1653,20 +1631,11 @@ async def async_to_python_value( # type: ignore
f"is not a collection (Flyte's representation of Python lists)."
)
)
if self.is_batchable(expected_python_type):
from flytekit.types.pickle import FlytePickle

batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits]
if len(batch_list) > 0 and type(batch_list[0]) is list:
# Make it have backward compatibility. The upstream task may use old version of Flytekit that won't
# merge the elements in the list. Therefore, we should check if the batch_list[0] is the list first.
return [item for batch in batch_list for item in batch]
return batch_list
else:
st = self.get_sub_type(expected_python_type)
result = [TypeEngine.async_to_python_value(ctx, x, st) for x in lits]
result = await asyncio.gather(*result)
return result # type: ignore # should be a list, thinks its a tuple

st = self.get_sub_type(expected_python_type)
result = [TypeEngine.async_to_python_value(ctx, x, st) for x in lits]
result = await asyncio.gather(*result)
return result # type: ignore # should be a list, thinks its a tuple

def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore
if literal_type.collection_type:
Expand Down
194 changes: 194 additions & 0 deletions flytekit/extras/accelerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,197 @@ class _A100_80GB(_A100_80GB_Base):
#: .. autoclass:: _A100_80GB
#: :members:
A100_80GB = _A100_80GB()


class _V5E_Base(MultiInstanceGPUAccelerator):
    """
    Base accelerator for Google Cloud TPU v5e slices.

    ``_V5E`` derives from this class and builds each of its slice attributes by calling
    ``partitioned`` on it with a topology string.
    """

    # NOTE(review): presumably the accelerator name the backend uses to select TPU v5e
    # nodes -- confirm against the platform configuration.
    device = "tpu-v5-lite-podslice"


class _V5E(_V5E_Base):
    """
    Slices of a `Google Cloud TPU v5e <https://cloud.google.com/tpu/docs/v5e>`_.

    Every ``slice_<topology>`` attribute is created with ``_V5E_Base.partitioned(<topology>)``.
    The string literal following each attribute is its attribute docstring.
    """

    slice_1x1 = _V5E_Base.partitioned("1x1")
    """
    1x1 topology representing 1 TPU chip or 1/8 of a host.
    """
    slice_2x2 = _V5E_Base.partitioned("2x2")
    """
    2x2 topology representing 4 TPU chips or 1/2 of a host.
    """
    slice_2x4 = _V5E_Base.partitioned("2x4")
    """
    2x4 topology representing 8 TPU chips or 1 host.
    """
    slice_4x4 = _V5E_Base.partitioned("4x4")
    """
    4x4 topology representing 16 TPU chips or 2 hosts.
    """
    slice_4x8 = _V5E_Base.partitioned("4x8")
    """
    4x8 topology representing 32 TPU chips or 4 hosts.
    """
    slice_8x8 = _V5E_Base.partitioned("8x8")
    """
    8x8 topology representing 64 TPU chips or 8 hosts.
    """
    slice_8x16 = _V5E_Base.partitioned("8x16")
    """
    8x16 topology representing 128 TPU chips or 16 hosts.
    """
    slice_16x16 = _V5E_Base.partitioned("16x16")
    """
    16x16 topology representing 256 TPU chips or 32 hosts.
    """


#: Use this constant to specify that the task should run on a V5E TPU.
#: `Google V5E Cloud TPU <https://cloud.google.com/tpu/docs/v5e>`_.
#:
#: Use the pre-defined slices (exposed as instance attributes). For example, to specify a 2x4 slice, use
#: ``V5E.slice_2x4``.
#: All available partitions are listed below:
#:
#: .. autoclass:: _V5E
#:    :members:
V5E = _V5E()


class _V5P_Base(MultiInstanceGPUAccelerator):
    """
    Base accelerator for Google Cloud TPU v5p slices.

    ``_V5P`` derives from this class and builds each of its slice attributes by calling
    ``partitioned`` on it with a topology string.
    """

    # NOTE(review): presumably the accelerator name the backend uses to select TPU v5p
    # nodes -- confirm against the platform configuration.
    device = "tpu-v5p-slice"


class _V5P(_V5P_Base):
    """
    Slices of a `Google Cloud TPU v5p <https://cloud.google.com/tpu/docs/v5p>`_.

    Every ``slice_<topology>`` attribute is created with ``_V5P_Base.partitioned(<topology>)``.
    The string literal following each attribute is its attribute docstring.
    """

    slice_2x2x1 = _V5P_Base.partitioned("2x2x1")
    """
    2x2x1 topology representing 8 TPU cores, 4 chips, 1 machine.
    """

    slice_2x2x2 = _V5P_Base.partitioned("2x2x2")
    """
    2x2x2 topology representing 16 TPU cores, 8 chips, 2 machines.
    """

    slice_2x4x4 = _V5P_Base.partitioned("2x4x4")
    """
    2x4x4 topology representing 64 TPU cores, 32 chips, 8 machines.
    """

    slice_4x4x4 = _V5P_Base.partitioned("4x4x4")
    """
    4x4x4 topology representing 128 TPU cores, 64 chips, 16 machines.
    """

    slice_4x4x8 = _V5P_Base.partitioned("4x4x8")
    """
    4x4x8 topology representing 256 TPU cores, 128 chips, 32 machines. Supports Twisted Topology.
    """

    slice_4x8x8 = _V5P_Base.partitioned("4x8x8")
    """
    4x8x8 topology representing 512 TPU cores, 256 chips, 64 machines. Supports Twisted Topology.
    """

    slice_8x8x8 = _V5P_Base.partitioned("8x8x8")
    """
    8x8x8 topology representing 1024 TPU cores, 512 chips, 128 machines.
    """

    slice_8x8x16 = _V5P_Base.partitioned("8x8x16")
    """
    8x8x16 topology representing 2048 TPU cores, 1024 chips, 256 machines. Supports Twisted Topology.
    """

    slice_8x16x16 = _V5P_Base.partitioned("8x16x16")
    """
    8x16x16 topology representing 4096 TPU cores, 2048 chips, 512 machines. Supports Twisted Topology.
    """

    slice_16x16x16 = _V5P_Base.partitioned("16x16x16")
    """
    16x16x16 topology representing 8192 TPU cores, 4096 chips, 1024 machines.
    """

    slice_16x16x24 = _V5P_Base.partitioned("16x16x24")
    """
    16x16x24 topology representing 12288 TPU cores, 6144 chips, 1536 machines.
    """


#: Use this constant to specify that the task should run on a V5P TPU.
#: `Google V5P Cloud TPU <https://cloud.google.com/tpu/docs/v5p>`_.
#:
#: Use the pre-defined slices (exposed as instance attributes). For example, to specify a 2x4x4 slice, use
#: ``V5P.slice_2x4x4``.
#: All available partitions are listed below:
#:
#: .. autoclass:: _V5P
#:    :members:
V5P = _V5P()


class _V6E_Base(MultiInstanceGPUAccelerator):
    """
    Base accelerator for Google Cloud TPU v6e slices.

    ``_V6E`` derives from this class and builds each of its slice attributes by calling
    ``partitioned`` on it with a topology string.
    """

    # NOTE(review): presumably the accelerator name the backend uses to select TPU v6e
    # nodes -- confirm against the platform configuration.
    device = "tpu-v6e-slice"


class _V6E(_V6E_Base):
    """
    Slices of a `Google Cloud TPU v6e <https://cloud.google.com/tpu/docs/v6e>`_.

    Every ``slice_<topology>`` attribute is created with ``_V6E_Base.partitioned(<topology>)``.
    The string literal following each attribute is its attribute docstring.
    """

    slice_1x1 = _V6E_Base.partitioned("1x1")
    """
    1x1 topology representing 1 TPU core or 1/8 of a host.
    """

    slice_2x2 = _V6E_Base.partitioned("2x2")
    """
    2x2 topology representing 4 TPU cores or 1/2 of a host.
    """

    slice_2x4 = _V6E_Base.partitioned("2x4")
    """
    2x4 topology representing 8 TPU cores or 1 host.
    """

    slice_4x4 = _V6E_Base.partitioned("4x4")
    """
    4x4 topology representing 16 TPU cores or 2 hosts.
    """

    slice_4x8 = _V6E_Base.partitioned("4x8")
    """
    4x8 topology representing 32 TPU cores or 4 hosts.
    """

    slice_8x8 = _V6E_Base.partitioned("8x8")
    """
    8x8 topology representing 64 TPU cores or 8 hosts.
    """

    slice_8x16 = _V6E_Base.partitioned("8x16")
    """
    8x16 topology representing 128 TPU cores or 16 hosts.
    """

    slice_16x16 = _V6E_Base.partitioned("16x16")
    """
    16x16 topology representing 256 TPU cores or 32 hosts.
    """


#: Use this constant to specify that the task should run on a V6E TPU.
#: `Google V6E Cloud TPU <https://cloud.google.com/tpu/docs/v6e>`_.
#:
#: Use the pre-defined slices (exposed as instance attributes). For example, to specify a 2x4 slice, use
#: ``V6E.slice_2x4``.
#: All available partitions are listed below:
#:
#: .. autoclass:: _V6E
#:    :members:
V6E = _V6E()
3 changes: 3 additions & 0 deletions flytekit/models/core/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def from_flyte_idl(cls, p):
"""
return cls(code=p.code, message=p.message, error_uri=p.error_uri, kind=p.kind)

def _repr_html_(self) -> str:
return f"<b>{self.code}</b> <pre>{self.message}</pre>"


class TaskLog(_common.FlyteIdlEntity):
class MessageFormat(object):
Expand Down
Loading

0 comments on commit 959f02b

Please sign in to comment.