Skip to content

Commit

Permalink
Binary IDL With MessagePack (flyteorg#2760)
Browse files Browse the repository at this point in the history
* [flytekit][1][Simple Type] Binary IDL With MessagePack

Signed-off-by: Future-Outlier <[email protected]>

* Add Tests

Signed-off-by: Future-Outlier <[email protected]>

* remove unused import

Signed-off-by: Future-Outlier <[email protected]>

* [flytekit][2][untyped dict] Binary IDL With MessagePack

Signed-off-by: Future-Outlier <[email protected]>

* Fix Tests

Signed-off-by: Future-Outlier <[email protected]>

* [Flyte][3][Attribute Access] Binary IDL With MessagePack

Signed-off-by: Future-Outlier <[email protected]>

* fix test_offloaded_literal

Signed-off-by: Future-Outlier <[email protected]>

* Add more tests

Signed-off-by: Future-Outlier <[email protected]>

* add tests for more complex cases

Signed-off-by: Future-Outlier <[email protected]>

* turn {} to dict()

Signed-off-by: Future-Outlier <[email protected]>

* lint

Signed-off-by: Future-Outlier <[email protected]>

* [flytekit][4][dataclass, flyte types and attribute access] Binary IDL With MessagePack

Signed-off-by: Future-Outlier <[email protected]>

* fix all tests, and support flytetypes and union from binary idl

Signed-off-by: Future-Outlier <[email protected]>

* self._encoder: Dict[Type, JSONEncoder]

Signed-off-by: Future-Outlier <[email protected]>

* fix lint

Signed-off-by: Future-Outlier <[email protected]>

* better comments

Signed-off-by: Future-Outlier <[email protected]>

* support enum transformer

Signed-off-by: Future-Outlier <[email protected]>

* add test_flytefile_in_dataclass_wf

Signed-off-by: Future-Outlier <[email protected]>

* add tests

Signed-off-by: Future-Outlier <[email protected]>

* Test Backward Compatible

Signed-off-by: Future-Outlier <[email protected]>

* add type transformer failed error

Signed-off-by: Future-Outlier <[email protected]>

* Update pingsu's review advice

Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: pingsutw  <[email protected]>

* update pingsu's review advice

Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: pingsutw  <[email protected]>

* update dict and list test with dataclass

Signed-off-by: Future-Outlier <[email protected]>

* ruff

Signed-off-by: Future-Outlier <[email protected]>

* support Dict[int, int] as input in workflow, including attribute access

Signed-off-by: Future-Outlier <[email protected]>

* Trigger CI

Signed-off-by: Future-Outlier <[email protected]>

* Add flytekit.bin.entrypoint to __init__.py for auto copy bug

Signed-off-by: Future-Outlier <[email protected]>

* revert back

Signed-off-by: Future-Outlier <[email protected]>

* add tests for union in dataclass, nested case

Signed-off-by: Future-Outlier <[email protected]>

---------

Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: pingsutw <[email protected]>
  • Loading branch information
2 people authored and otarabai committed Oct 15, 2024
1 parent b7873cc commit 1948110
Show file tree
Hide file tree
Showing 16 changed files with 1,270 additions and 132 deletions.
3 changes: 3 additions & 0 deletions flytekit/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@

# If set this environment variable overrides the default container image and the default base image in ImageSpec.
FLYTE_INTERNAL_IMAGE_ENV_VAR = "FLYTE_INTERNAL_IMAGE"

# Binary IDL Serialization Format
MESSAGEPACK = "msgpack"
45 changes: 34 additions & 11 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from flytekit.models import types as _type_models
from flytekit.models import types as type_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Primitive
from flytekit.models.literals import Binary, Literal, Primitive, Scalar
from flytekit.models.task import Resources
from flytekit.models.types import SimpleType

Expand Down Expand Up @@ -138,21 +138,43 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
break

# If the current value is a dataclass, resolve the dataclass with the remaining path
if (
len(p.attr_path) > 0
and type(curr_val.value) is _literals_models.Scalar
and type(curr_val.value.value) is _struct.Struct
):
st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)
if len(p.attr_path) > 0 and type(curr_val.value) is _literals_models.Scalar:
# We keep it for reference task local execution in the future.
if type(curr_val.value.value) is _struct.Struct:
st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)
elif type(curr_val.value.value) is Binary:
binary_idl_obj = curr_val.value.value
if binary_idl_obj.tag == _common_constants.MESSAGEPACK:
import msgpack

dict_obj = msgpack.loads(binary_idl_obj.value, strict_map_key=False)
v = resolve_attr_path_in_dict(dict_obj, attr_path=p.attr_path[used:])
msgpack_bytes = msgpack.dumps(v)
curr_val = Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
else:
raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_obj.tag}")

p._val = curr_val
return p


def resolve_attr_path_in_dict(d: dict, attr_path: List[Union[str, int]]) -> Any:
curr_val = d
for attr in attr_path:
try:
curr_val = curr_val[attr]
except (KeyError, IndexError, TypeError) as e:
raise FlytePromiseAttributeResolveException(
f"Failed to resolve attribute path {attr_path} in dict `{curr_val}`, attribute `{attr}` not found.\n"
f"Error Message: {e}"
)
return curr_val


def resolve_attr_path_in_pb_struct(st: _struct.Struct, attr_path: List[Union[str, int]]) -> _struct.Struct:
curr_val = st
for attr in attr_path:
Expand Down Expand Up @@ -211,6 +233,7 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr
self._op = op
self._lhs = None
self._rhs = None

if isinstance(lhs, Promise):
self._lhs = lhs
if lhs.is_ready:
Expand Down
132 changes: 92 additions & 40 deletions flytekit/core/type_engine.py

Large diffs are not rendered by default.

36 changes: 35 additions & 1 deletion flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,20 @@
from uuid import UUID

import fsspec
import msgpack
from dataclasses_json import DataClassJsonMixin, config
from fsspec.utils import get_protocol
from marshmallow import fields
from mashumaro.types import SerializableType

from flytekit import BlobType
from flytekit.core.constants import MESSAGEPACK
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_batch_size
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models import types as _type_models
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.literals import Binary, Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types.file import FileExt, FlyteFile

Expand Down Expand Up @@ -504,9 +506,41 @@ def to_literal(
else:
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=source_path)))

def from_binary_idl(
self, binary_idl_object: Binary, expected_python_type: typing.Type[FlyteDirectory]
) -> FlyteDirectory:
if binary_idl_object.tag == MESSAGEPACK:
python_val = msgpack.loads(binary_idl_object.value)
path = python_val.get("path", None)

if path is None:
raise ValueError("FlyteDirectory's path should not be None")

return FlyteDirToMultipartBlobTransformer().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
blob=Blob(
metadata=BlobMetadata(
type=_core_types.BlobType(
format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
)
),
uri=path,
)
)
),
expected_python_type,
)
else:
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

def to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory]
) -> FlyteDirectory:
if lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type)

uri = lv.scalar.blob.uri

if lv.scalar.blob.metadata.type.dimensionality != BlobType.BlobDimensionality.MULTIPART:
Expand Down
37 changes: 36 additions & 1 deletion flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@
from typing import cast
from urllib.parse import unquote

import msgpack
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
from mashumaro.types import SerializableType

from flytekit.core.constants import MESSAGEPACK
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type
from flytekit.exceptions.user import FlyteAssertion
from flytekit.loggers import logger
from flytekit.models.core import types as _core_types
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.literals import Binary, Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types.pickle.pickle import FlytePickleTransformer

Expand Down Expand Up @@ -518,9 +520,42 @@ def get_additional_headers(source_path: str | os.PathLike) -> typing.Dict[str, s
return {"ContentEncoding": "gzip"}
return {}

def from_binary_idl(
self, binary_idl_object: Binary, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike]
) -> FlyteFile:
if binary_idl_object.tag == MESSAGEPACK:
python_val = msgpack.loads(binary_idl_object.value)
path = python_val.get("path", None)

if path is None:
raise ValueError("FlyteFile's path should not be None")

return FlyteFilePathTransformer().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
blob=Blob(
metadata=BlobMetadata(
type=_core_types.BlobType(
format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
),
uri=path,
)
)
),
expected_python_type,
)
else:
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

def to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike]
) -> FlyteFile:
# Handle dataclass attribute access
if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type)

try:
uri = lv.scalar.blob.uri
except AttributeError:
Expand Down
25 changes: 24 additions & 1 deletion flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
from pathlib import Path
from typing import Type

import msgpack
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
from mashumaro.types import SerializableType

from flytekit.core.constants import MESSAGEPACK
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.loggers import logger
from flytekit.models.literals import Literal, Scalar, Schema
from flytekit.models.literals import Binary, Literal, Scalar, Schema
from flytekit.models.types import LiteralType, SchemaType

T = typing.TypeVar("T")
Expand Down Expand Up @@ -439,7 +441,28 @@ def to_literal(
schema.remote_path = ctx.file_access.put_data(schema.local_path, schema.remote_path, is_multipart=True)
return Literal(scalar=Scalar(schema=Schema(schema.remote_path, self._get_schema_type(python_type))))

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[FlyteSchema]) -> FlyteSchema:
if binary_idl_object.tag == MESSAGEPACK:
python_val = msgpack.loads(binary_idl_object.value)
remote_path = python_val.get("remote_path", None)

if remote_path is None:
raise ValueError("FlyteSchema's path should not be None")

t = FlyteSchemaTransformer()
return t.to_python_value(
FlyteContextManager.current_context(),
Literal(scalar=Scalar(schema=Schema(remote_path, t._get_schema_type(expected_python_type)))),
expected_python_type,
)
else:
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[FlyteSchema]) -> FlyteSchema:
# Handle dataclass attribute access
if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type)

def downloader(x, y):
ctx.file_access.get_data(x, y, is_multipart=True)

Expand Down
38 changes: 36 additions & 2 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dataclasses import dataclass, field, is_dataclass
from typing import Dict, Generator, List, Optional, Type, Union

import msgpack
from dataclasses_json import config
from fsspec.utils import get_protocol
from marshmallow import fields
Expand All @@ -16,13 +17,14 @@
from typing_extensions import Annotated, TypeAlias, get_args, get_origin

from flytekit import lazy_module
from flytekit.core.constants import MESSAGEPACK
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.deck.renderer import Renderable
from flytekit.loggers import developer_logger, logger
from flytekit.models import literals
from flytekit.models import types as type_models
from flytekit.models.literals import Literal, Scalar, StructuredDatasetMetadata
from flytekit.models.literals import Binary, Literal, Scalar, StructuredDatasetMetadata
from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -715,6 +717,34 @@ def encode(
sd._already_uploaded = True
return lit

def from_binary_idl(
self, binary_idl_object: Binary, expected_python_type: Type[T] | StructuredDataset
) -> T | StructuredDataset:
if binary_idl_object.tag == MESSAGEPACK:
python_val = msgpack.loads(binary_idl_object.value)
uri = python_val.get("uri", None)
file_format = python_val.get("file_format", None)

if uri is None:
raise ValueError("StructuredDataset's uri and file format should not be None")

return StructuredDatasetTransformerEngine().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
structured_dataset=StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=file_format)
),
uri=uri,
)
)
),
expected_python_type,
)
else:
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

def to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset
) -> T | StructuredDataset:
Expand Down Expand Up @@ -748,6 +778,10 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
| | the running task's signature. | |
+-----------------------------+-----------------------------------------+--------------------------------------+
"""
# Handle dataclass attribute access
if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type)

# Detect annotations and extract out all the relevant information that the user might supply
expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type)

Expand Down
3 changes: 3 additions & 0 deletions plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import datetime, timedelta
from unittest import mock
import msgpack
import base64

import pytest
from flyteidl.core.execution_pb2 import TaskExecution
Expand Down Expand Up @@ -161,6 +163,7 @@ async def test_agent(mock_boto_call, mock_return_value):
if "pickle_check" in mock_return_value[0][0]:
assert "pickle_file" in outputs["result"]
else:
outputs["result"] = msgpack.loads(base64.b64decode(outputs["result"]))
assert (
outputs["result"]["EndpointConfigArn"]
== "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config"
Expand Down
5 changes: 3 additions & 2 deletions plugins/flytekit-openai/tests/openai_batch/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from datetime import timedelta
from unittest import mock
from unittest.mock import AsyncMock

import msgpack
import base64
import pytest
from flyteidl.core.execution_pb2 import TaskExecution
from flytekitplugins.openai.batch.agent import BatchEndpointMetadata
Expand Down Expand Up @@ -159,7 +160,7 @@ async def test_openai_batch_agent(mock_retrieve, mock_create, mock_context):
outputs = literal_map_string_repr(resource.outputs)
result = outputs["result"]

assert result == batch_retrieve_result.to_dict()
assert msgpack.loads(base64.b64decode(result)) == batch_retrieve_result.to_dict()

# Status: Failed
mock_retrieve.return_value = batch_retrieve_result_failure
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"marshmallow-enum",
"marshmallow-jsonschema>=0.12.0",
"mashumaro>=3.11",
"msgpack>=1.1.0",
"protobuf!=4.25.0",
"pygments",
"python-json-logger>=2.0.0",
Expand Down
Loading

0 comments on commit 1948110

Please sign in to comment.