Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Binary IDL With MessagePack #2760

Merged
merged 38 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
e3a258a
[flytekit][1][Simple Type] Binary IDL With MessagePack
Future-Outlier Sep 18, 2024
3562f0c
Add Tests
Future-Outlier Sep 18, 2024
f93b441
remove unused import
Future-Outlier Sep 18, 2024
c05a905
[flytekit][2][untyped dict] Binary IDL With MessagePack
Future-Outlier Sep 18, 2024
b1bf20c
Fix Tests
Future-Outlier Sep 18, 2024
d9dead2
[Flyte][3][Attribute Access] Binary IDL With MessagePack
Future-Outlier Sep 19, 2024
5539c1e
Merge branch 'master' into binary-idl-with-message-pack-bytes-1
Future-Outlier Sep 19, 2024
be6c024
Merge branch 'binary-idl-with-message-pack-bytes-1' into binary-idl-w…
Future-Outlier Sep 19, 2024
6b59d89
fix test_offloaded_literal
Future-Outlier Sep 19, 2024
0bd0924
Merge branch 'binary-idl-with-message-pack-bytes-2' into binary-idl-w…
Future-Outlier Sep 19, 2024
bcaf573
Add more tests
Future-Outlier Sep 19, 2024
dd5e1c9
Merge branch 'binary-idl-with-message-pack-bytes-1' into binary-idl-w…
Future-Outlier Sep 19, 2024
ef92a8a
Merge branch 'binary-idl-with-message-pack-bytes-2' into binary-idl-w…
Future-Outlier Sep 19, 2024
08f7388
add tests for more complex cases
Future-Outlier Sep 19, 2024
1bac074
turn {} to dict()
Future-Outlier Sep 19, 2024
4e25e4c
Merge branch 'binary-idl-with-message-pack-bytes-1' into binary-idl-w…
Future-Outlier Sep 19, 2024
684f31d
lint
Future-Outlier Sep 19, 2024
042aa80
[flytekit][4][dataclass, flyte types and attribute access] Binary IDL…
Future-Outlier Sep 19, 2024
1fa9dc2
fix all tests, and support flytetypes and union from binary idl
Future-Outlier Sep 20, 2024
b491ff0
self._encoder: Dict[Type, JSONEncoder]
Future-Outlier Sep 20, 2024
16bb504
fix lint
Future-Outlier Sep 20, 2024
b940bfa
better comments
Future-Outlier Sep 23, 2024
3c652a5
support enum transformer
Future-Outlier Sep 23, 2024
4e87f2c
add test_flytefile_in_dataclass_wf
Future-Outlier Sep 23, 2024
a01d57e
add tests
Future-Outlier Sep 23, 2024
ddbdf73
Test Backward Compatible
Future-Outlier Sep 23, 2024
fc8cbfb
add type transformer failed error
Future-Outlier Sep 23, 2024
908153f
Update pingsu's review advice
Future-Outlier Sep 24, 2024
f596a8f
update pingsu's review advice
Future-Outlier Sep 25, 2024
1eb5739
update dict and list test with dataclass
Future-Outlier Sep 25, 2024
d3c50c8
ruff
Future-Outlier Sep 25, 2024
fc11274
support Dict[int, int] as input in workflow, including attribute access
Future-Outlier Sep 25, 2024
35df821
Trigger CI
Future-Outlier Sep 26, 2024
f50b1ee
Merge branch 'master' into binary-idl-with-message-pack-bytes-4
Future-Outlier Sep 26, 2024
4b75f8f
Add flytekit.bin.entrypoint to __init__.py for auto copy bug
Future-Outlier Sep 26, 2024
88dcdc0
revert back
Future-Outlier Sep 26, 2024
008e6ae
Merge branch 'master' into binary-idl-with-message-pack-bytes-4
Future-Outlier Sep 29, 2024
ab36e49
add tests for union in dataclass, nested case
Future-Outlier Oct 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from flytekit.models import types as _type_models
from flytekit.models import types as type_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Primitive
from flytekit.models.literals import Binary, Literal, Primitive, Scalar
from flytekit.models.task import Resources
from flytekit.models.types import SimpleType

Expand Down Expand Up @@ -138,21 +138,43 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
break

# If the current value is a dataclass, resolve the dataclass with the remaining path
if (
len(p.attr_path) > 0
and type(curr_val.value) is _literals_models.Scalar
and type(curr_val.value.value) is _struct.Struct
):
st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)
if len(p.attr_path) > 0 and type(curr_val.value) is _literals_models.Scalar:
# We keep it for reference task local execution in the future.
if type(curr_val.value.value) is _struct.Struct:
st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)
elif type(curr_val.value.value) is Binary:
binary_idl_obj = curr_val.value.value
if binary_idl_obj.tag == "msgpack":
import msgpack

dict_obj = msgpack.loads(binary_idl_obj.value, strict_map_key=False)
v = resolve_attr_path_in_dict(dict_obj, attr_path=p.attr_path[used:])
msgpack_bytes = msgpack.dumps(v)
curr_val = Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
else:
raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_obj.tag}")

p._val = curr_val
return p


def resolve_attr_path_in_dict(d: dict, attr_path: List[Union[str, int]]) -> Any:
curr_val = d
for attr in attr_path:
try:
curr_val = curr_val[attr]
except (KeyError, IndexError, TypeError) as e:
raise FlytePromiseAttributeResolveException(
f"Failed to resolve attribute path {attr_path} in dict {curr_val}, attribute {attr} not found.\n"
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
f"Error Message: {e}"
)
return curr_val


def resolve_attr_path_in_pb_struct(st: _struct.Struct, attr_path: List[Union[str, int]]) -> _struct.Struct:
curr_val = st
for attr in attr_path:
Expand Down Expand Up @@ -211,6 +233,7 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr
self._op = op
self._lhs = None
self._rhs = None

if isinstance(lhs, Promise):
self._lhs = lhs
if lhs.is_ready:
Expand Down
126 changes: 87 additions & 39 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from typing import Dict, List, NamedTuple, Optional, Type, cast
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast

import msgpack
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from google.protobuf import json_format as _json_format
Expand All @@ -25,7 +26,8 @@
from google.protobuf.json_format import ParseDict as _ParseDict
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from mashumaro.codecs.json import JSONDecoder, JSONEncoder
from mashumaro.codecs.json import JSONDecoder
from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin

Expand All @@ -42,22 +44,21 @@
from flytekit.models import types as _type_models
from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel
from flytekit.models.core import types as _core_types
from flytekit.models.literals import (
Literal,
LiteralCollection,
LiteralMap,
Primitive,
Scalar,
Union,
Void,
)
from flytekit.models.literals import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void
from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType

T = typing.TypeVar("T")
DEFINITIONS = "definitions"
TITLE = "title"


# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True.
# This is relevant for cases like Dict[int, str].
# If strict_map_key=False is not used, the decoder will raise an error when trying to decode keys that are not strictly typed.`
def _default_msgpack_decoder(data: bytes) -> Any:
return msgpack.unpackb(data, raw=False, strict_map_key=False)


class BatchSize:
"""
This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. For example,
Expand Down Expand Up @@ -129,6 +130,8 @@ def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True):
self._t = t
self._name = name
self._type_assertions_enabled = enable_type_assertions
self._msgpack_encoder: Dict[Type, MessagePackEncoder] = dict()
self._msgpack_decoder: Dict[Type, MessagePackDecoder] = dict()

@property
def name(self):
Expand Down Expand Up @@ -221,6 +224,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"Conversion to python value expected type {expected_python_type} from literal not implemented"
)

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]:
if binary_idl_object.tag == "msgpack":
try:
decoder = self._msgpack_decoder[expected_python_type]
except KeyError:
decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder)
self._msgpack_decoder[expected_python_type] = decoder
return decoder.decode(binary_idl_object.value)
else:
raise TypeTransformerFailedError(f"Unsupported binary format `{binary_idl_object.tag}`")

def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str:
"""
Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div
Expand Down Expand Up @@ -271,6 +285,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"Cannot convert to type {expected_python_type}, only {self._type} is supported"
)

if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

try: # todo(maximsmol): this is quite ugly and each transformer should really check their Literal
res = self._from_literal_transformer(lv)
if type(res) != self._type:
Expand Down Expand Up @@ -313,21 +330,21 @@ class DataclassTransformer(TypeTransformer[object]):
"""
The Dataclass Transformer provides a type transformer for dataclasses.

The dataclass is converted to and from a JSON string by the mashumaro library
and is transported between tasks using the proto.Structpb representation.
The dataclass is converted to and from MessagePack Bytes by the mashumaro library
and is transported between tasks using the Binary IDL representation.
Also, the type declaration will try to extract the JSON Schema for the
object, if possible, and pass it with the definition.

The lifecycle of the dataclass in the Flyte type system is as follows:

1. Serialization: The dataclass transformer converts the dataclass to a JSON string.
1. Serialization: The dataclass transformer converts the dataclass to MessagePack Bytes.
(1) Handle dataclass attributes to make them serializable with mashumaro.
(2) Use the mashumaro API to serialize the dataclass to a JSON string.
(3) Use the JSON string to create a Flyte Literal.
(4) Serialize the Flyte Literal to a protobuf.
(2) Use the mashumaro API to serialize the dataclass to MessagePack Bytes.
(3) Use MessagePack Bytes to create a Flyte Literal.
(4) Serialize the Flyte Literal to a Binary IDL Object.

2. Deserialization: The dataclass transformer converts the JSON string back to a dataclass.
(1) Convert the JSON string to a dataclass using mashumaro.
2. Deserialization: The dataclass transformer converts the MessagePack Bytes back to a dataclass.
(1) Convert MessagePack Bytes to a dataclass using mashumaro.
(2) Handle dataclass attributes to ensure they are of the correct types.

For Json Schema, we use https://github.com/fuhrysteve/marshmallow-jsonschema library.
Expand Down Expand Up @@ -366,8 +383,7 @@ class Test(DataClassJsonMixin):

def __init__(self):
super().__init__("Object-Dataclass-Transformer", object)
self._encoder: Dict[Type, JSONEncoder] = {}
self._decoder: Dict[Type, JSONDecoder] = {}
self._decoder: Dict[Type, JSONDecoder] = dict()

def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
# Skip iterating all attributes in the dataclass if the type of v already matches the expected_type
Expand Down Expand Up @@ -526,8 +542,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if isinstance(python_val, dict):
json_str = json.dumps(python_val)
return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())))
msgpack_bytes = msgpack.dumps(python_val)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))

if not dataclasses.is_dataclass(python_val):
raise TypeTransformerFailedError(
Expand All @@ -542,25 +558,27 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
# We can't use hasattr(python_val, "to_json") here because we rely on mashumaro's API to customize the serialization behavior for Flyte types.
if isinstance(python_val, DataClassJSONMixin):
json_str = python_val.to_json()
dict_obj = json.loads(json_str)
msgpack_bytes = msgpack.dumps(dict_obj)
else:
# The function looks up or creates a JSONEncoder specifically designed for the object's type.
# This encoder is then used to convert a data class into a JSON string.
# The function looks up or creates a MessagePackEncoder specifically designed for the object's type.
# This encoder is then used to convert a data class into MessagePack Bytes.
try:
encoder = self._encoder[python_type]
encoder = self._msgpack_encoder[python_type]
except KeyError:
encoder = JSONEncoder(python_type)
self._encoder[python_type] = encoder
encoder = MessagePackEncoder(python_type)
self._msgpack_encoder[python_type] = encoder

try:
json_str = encoder.encode(python_val)
msgpack_bytes = encoder.encode(python_val)
except NotImplementedError:
# you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented.
raise NotImplementedError(
f"{python_type} should inherit from mashumaro.types.SerializableType"
f" and implement _serialize and _deserialize methods."
)

return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))

def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
# dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is
Expand Down Expand Up @@ -699,13 +717,34 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An

return dc

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T:
if binary_idl_object.tag == "msgpack":
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
if issubclass(expected_python_type, DataClassJSONMixin):
dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False)
json_str = json.dumps(dict_obj)
dc = expected_python_type.from_json(json_str) # type: ignore
else:
try:
decoder = self._msgpack_decoder[expected_python_type]
except KeyError:
decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder)
self._msgpack_decoder[expected_python_type] = decoder
dc = decoder.decode(binary_idl_object.value)

return self._fix_structured_dataset_type(expected_python_type, dc) # type: ignore
else:
raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}")

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
if not dataclasses.is_dataclass(expected_python_type):
raise TypeTransformerFailedError(
f"{expected_python_type} is not of type @dataclass, only Dataclasses are supported for "
"user defined datatypes in Flytekit"
)

if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

json_str = _json_format.MessageToJson(lv.scalar.generic)

# The `from_json` function is provided from mashumaro's `DataClassJSONMixin`.
Expand Down Expand Up @@ -819,6 +858,8 @@ def to_literal(
return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value))) # type: ignore

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore
return expected_python_type(lv.scalar.primitive.string_value) # type: ignore

def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:
Expand Down Expand Up @@ -1390,6 +1431,9 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
return Literal(collection=LiteralCollection(literals=lit_list))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore
if lv and lv.scalar and lv.scalar.binary is not None:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

try:
lits = lv.collection.literals
except AttributeError:
Expand Down Expand Up @@ -1609,6 +1653,9 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[typing.Any]:
expected_python_type = get_underlying_type(expected_python_type)

if lv.scalar is not None and lv.scalar.binary is not None:
return self.from_binary_idl(lv.scalar.binary, expected_python_type)

union_tag = None
union_type = None
if lv.scalar is not None and lv.scalar.union is not None:
Expand Down Expand Up @@ -1671,8 +1718,8 @@ def guess_python_type(self, literal_type: LiteralType) -> type:

class DictTransformer(TypeTransformer[dict]):
"""
Transformer that transforms a univariate dictionary Dict[str, T] to a Literal Map or
transforms a untyped dictionary to a JSON (struct/Generic)
Transformer that transforms an univariate dictionary Dict[str, T] to a Literal Map or
transforms an untyped dictionary to a Binary Scalar Literal with a Struct Literal Type.
"""

def __init__(self):
Expand All @@ -1697,17 +1744,15 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
return None, None

@staticmethod
def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal:
def dict_to_binary_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal:
"""
Creates a flyte-specific ``Literal`` value from a native python dictionary.
"""
from flytekit.types.pickle import FlytePickle

try:
return Literal(
scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())),
metadata={"format": "json"},
)
msgpack_bytes = msgpack.dumps(v)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
except TypeError as e:
if allow_pickle:
remote_path = FlytePickle.to_pickle(ctx, v)
Expand All @@ -1717,7 +1762,7 @@ def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> L
),
metadata={"format": "pickle"},
)
raise e
raise TypeTransformerFailedError(f"Cannot convert `{v}` to Flyte Literal.\n" f"Error Message: {e}")

@staticmethod
def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]:
Expand Down Expand Up @@ -1768,7 +1813,7 @@ def to_literal(
allow_pickle, base_type = DictTransformer.is_pickle(python_type)

if expected and expected.simple and expected.simple == SimpleType.STRUCT:
return self.dict_to_generic_literal(ctx, python_val, allow_pickle)
return self.dict_to_binary_literal(ctx, python_val, allow_pickle)

lit_map = {}
for k, v in python_val.items():
Expand All @@ -1785,6 +1830,9 @@ def to_literal(
return Literal(map=LiteralMap(literals=lit_map))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict:
if lv and lv.scalar and lv.scalar.binary is not None:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

if lv and lv.map and lv.map.literals is not None:
tp = self.dict_types(expected_python_type)

Expand Down
Loading
Loading