Skip to content

Commit

Permalink
Override Dataclass Serialization/Deserialization Behavior for `FlyteT…
Browse files Browse the repository at this point in the history
…ypes` by `mashumaro` (#2554)

* Add to_json and from_json to Flyte type

Signed-off-by: Future-Outlier <[email protected]>

* Add to_json and from_json to Flyte type

Signed-off-by: Future-Outlier <[email protected]>

* remove call of self._serialize_flyte_type in dataclass transformer to_literal function

Signed-off-by: Future-Outlier <[email protected]>

* uncomment _serialize_flyte_type in dataclass transformer

Signed-off-by: Future-Outlier <[email protected]>

* use mashumaro SerializableType in flytefile, implemented _serialize and _deserialize methods

Signed-off-by: Future-Outlier <[email protected]>

* remove flytefile in serialize and deserialize function

Signed-off-by: Future-Outlier <[email protected]>

* support FlyteDirectory

Signed-off-by: Future-Outlier <[email protected]>

* uncomment self._serialize_flyte_type() in DataclassTransformer to_literal

Signed-off-by: Future-Outlier <[email protected]>

* lint

Signed-off-by: Future-Outlier <[email protected]>

* lint

Signed-off-by: Future-Outlier <[email protected]>

* remove a line

Signed-off-by: Future-Outlier <[email protected]>

* add print

Signed-off-by: Future-Outlier <[email protected]>

* uncomment deserialize in dataclass

Signed-off-by: Future-Outlier <[email protected]>

* remove dataclass json in the dataclass transformer

Signed-off-by: Future-Outlier <[email protected]>

* remove comments

Signed-off-by: Future-Outlier <[email protected]>

* remove prints

Signed-off-by: Future-Outlier <[email protected]>

* update notes

Signed-off-by: Future-Outlier <[email protected]>

* lint and fix test

Signed-off-by: Future-Outlier <[email protected]>

* Support structured dataset in dataclass

Signed-off-by: Future-Outlier <[email protected]>

* add back DataClassJsonMixin inheritance in test

Signed-off-by: Future-Outlier <[email protected]>

* add flytefile type hints

Signed-off-by: Future-Outlier <[email protected]>

* Improve type hints and use FlyteContextManager instead of FlyteContext

Signed-off-by: Future-Outlier <[email protected]>

* rename serialize flyte types to _convert_flyte_type_serializable and deserialize flyte types to _revert_to_flyte_type

Signed-off-by: Future-Outlier <[email protected]>

* Add comments to describe the dataclass transformer's lifecycle

Signed-off-by: Future-Outlier <[email protected]>

* rename using _make_flyte_type_serializable

Signed-off-by: Future-Outlier <[email protected]>

* Add logs to remind users to not use FlyteFile or FlyteDirectory in  dataclass

Signed-off-by: Future-Outlier <[email protected]>

* Add unit tests in test_dataclass.py

Signed-off-by: Future-Outlier <[email protected]>

* Add Try Catch in dataclass transformer to literal

Signed-off-by: Future-Outlier <[email protected]>

* support default input

Signed-off-by: Future-Outlier <[email protected]>

* lint

Signed-off-by: Future-Outlier <[email protected]>

* generate random prefix for file and dir

Signed-off-by: Future-Outlier <[email protected]>

* upload successful tests

Signed-off-by: Future-Outlier <[email protected]>

* update flyteschema behaviour

Signed-off-by: Future-Outlier <[email protected]>

* add type hints

Signed-off-by: Future-Outlier <[email protected]>

* fix flyteschema tests

Signed-off-by: Future-Outlier <[email protected]>

* update tests

Signed-off-by: Future-Outlier <[email protected]>

* add coverage.xml in .gitignore

Signed-off-by: Future-Outlier <[email protected]>

* kevin's update

Signed-off-by: Kevin Su <[email protected]>

* add  delattr(cls, "__class_getitem__")

Signed-off-by: Future-Outlier <[email protected]>

* use AttributeHider to change the behavior of hasattr(cls, __class_getitem__) for FlyteTypes

Signed-off-by: Future-Outlier <[email protected]>

* remove get_origin()

Signed-off-by: Future-Outlier <[email protected]>

* add back tests

Signed-off-by: Future-Outlier <[email protected]>

* Update pingsu's advice

Signed-off-by: Future-Outlier <[email protected]>

* test: create a variable for DCWithOptional

Signed-off-by: Future-Outlier <[email protected]>

* remove parent class from the AttributeHider

Signed-off-by: Future-Outlier <[email protected]>

---------

Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
2 people authored and fiedlerNr9 committed Jul 25, 2024
1 parent af384c4 commit 14b1e38
Show file tree
Hide file tree
Showing 13 changed files with 1,120 additions and 207 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ docs/source/_tags/
.hypothesis
.npm
/**/target
coverage.xml

# Version file is auto-generated by setuptools_scm
flytekit/_version.py
7 changes: 4 additions & 3 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import tempfile
import typing
from dataclasses import dataclass, field, fields
from typing import cast, get_args
from typing import get_args

import rich_click as click
from dataclasses_json import DataClassJsonMixin
from mashumaro.codecs.json import JSONEncoder
from rich.progress import Progress

from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, Literal
Expand Down Expand Up @@ -395,7 +395,8 @@ def to_click_option(
if type(default_val) == dict or type(default_val) == list:
default_val = json.dumps(default_val)
else:
default_val = cast(DataClassJsonMixin, default_val).to_json()
encoder = JSONEncoder(python_type)
default_val = encoder.encode(default_val)
if literal_var.type.metadata:
description_extra = f": {json.dumps(literal_var.type.metadata)}"

Expand Down
230 changes: 65 additions & 165 deletions flytekit/core/type_engine.py

Large diffs are not rendered by default.

45 changes: 44 additions & 1 deletion flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dataclasses_json import DataClassJsonMixin, config
from fsspec.utils import get_protocol
from marshmallow import fields
from mashumaro.types import SerializableType

from flytekit import BlobType
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
Expand All @@ -32,7 +33,7 @@ def noop(): ...


@dataclass
class FlyteDirectory(DataClassJsonMixin, os.PathLike, typing.Generic[T]):
class FlyteDirectory(SerializableType, DataClassJsonMixin, os.PathLike, typing.Generic[T]):
path: PathType = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore
"""
.. warning::
Expand Down Expand Up @@ -120,6 +121,36 @@ def t1(in1: FlyteDirectory["svg"]):
field in the ``BlobType``.
"""

def _serialize(self) -> typing.Dict[str, str]:
lv = FlyteDirToMultipartBlobTransformer().to_literal(
FlyteContextManager.current_context(), self, FlyteDirectory, None
)
return {"path": lv.scalar.blob.uri}

@classmethod
def _deserialize(cls, value) -> "FlyteDirectory":
path = value.get("path", None)

if path is None:
raise ValueError("FlyteDirectory's path should not be None")

return FlyteDirToMultipartBlobTransformer().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
blob=Blob(
metadata=BlobMetadata(
type=_core_types.BlobType(
format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
)
),
uri=path,
)
)
),
cls,
)

def __init__(
self,
path: typing.Union[str, os.PathLike],
Expand Down Expand Up @@ -182,6 +213,18 @@ class _SpecificFormatDirectoryClass(FlyteDirectory):
# Get the type engine to see this as kind of a generic
__origin__ = FlyteDirectory

class AttributeHider:
def __get__(self, instance, owner):
raise AttributeError(
"""We have to return false in hasattr(cls, "__class_getitem__") to make mashumaro deserialize FlyteDirectory correctly."""
)

# Set __class_getitem__ to AttributeHider to make mashumaro deserialize FlyteDirectory correctly
# https://stackoverflow.com/questions/6057130/python-deleting-a-class-attribute-in-a-subclass/6057409
# Since mashumaro will use the method __class_getitem__ and __origin__ to construct the dataclass back
# https://github.com/Fatal1ty/mashumaro/blob/e945ee4319db49da9f7b8ede614e988cc8c8956b/mashumaro/core/meta/helpers.py#L300-L303
__class_getitem__ = AttributeHider() # type: ignore

@classmethod
def extension(cls) -> str:
return item_string
Expand Down
47 changes: 45 additions & 2 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@
import typing
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import cast
from urllib.parse import unquote

from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
from mashumaro.types import SerializableType

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type
from flytekit.exceptions.user import FlyteAssertion
from flytekit.loggers import logger
from flytekit.models.core import types as _core_types
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
Expand All @@ -29,7 +32,7 @@ def noop(): ...


@dataclass
class FlyteFile(os.PathLike, typing.Generic[T], DataClassJSONMixin):
class FlyteFile(SerializableType, os.PathLike, typing.Generic[T], DataClassJSONMixin):
path: typing.Union[str, os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore
"""
Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int
Expand Down Expand Up @@ -143,6 +146,34 @@ def t2() -> flytekit_typing.FlyteFile["csv"]:
return "/tmp/local_file.csv"
"""

def _serialize(self) -> typing.Dict[str, str]:
lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, FlyteFile, None)
return {"path": lv.scalar.blob.uri}

@classmethod
def _deserialize(cls, value) -> "FlyteFile":
path = value.get("path", None)

if path is None:
raise ValueError("FlyteFile's path should not be None")

return FlyteFilePathTransformer().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
blob=Blob(
metadata=BlobMetadata(
type=_core_types.BlobType(
format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
),
uri=path,
)
)
),
cls,
)

@classmethod
def extension(cls) -> str:
return ""
Expand Down Expand Up @@ -190,6 +221,18 @@ class _SpecificFormatClass(FlyteFile):
# Get the type engine to see this as kind of a generic
__origin__ = FlyteFile

class AttributeHider:
def __get__(self, instance, owner):
raise AttributeError(
"""We have to return false in hasattr(cls, "__class_getitem__") to make mashumaro deserialize FlyteFile correctly."""
)

# Set __class_getitem__ to AttributeHider to make mashumaro deserialize FlyteFile correctly
# https://stackoverflow.com/questions/6057130/python-deleting-a-class-attribute-in-a-subclass/6057409
# Since mashumaro will use the method __class_getitem__ and __origin__ to construct the dataclass back
# https://github.com/Fatal1ty/mashumaro/blob/e945ee4319db49da9f7b8ede614e988cc8c8956b/mashumaro/core/meta/helpers.py#L300-L303
__class_getitem__ = AttributeHider() # type: ignore

@classmethod
def extension(cls) -> str:
return item_string
Expand Down Expand Up @@ -323,7 +366,7 @@ def __init__(self):
def get_format(t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> str:
if t is os.PathLike:
return ""
return typing.cast(FlyteFile, t).extension()
return cast(FlyteFile, t).extension()

def _blob_type(self, format: str) -> BlobType:
return BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE)
Expand Down
32 changes: 31 additions & 1 deletion flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
from mashumaro.types import SerializableType

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
Expand Down Expand Up @@ -177,12 +178,29 @@ def get_handler(cls, t: Type) -> SchemaHandler:


@dataclass
class FlyteSchema(DataClassJSONMixin):
class FlyteSchema(SerializableType, DataClassJSONMixin):
remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String()))
"""
This is the main schema class that users should use.
"""

def _serialize(self) -> typing.Dict[str, typing.Optional[str]]:
return {"remote_path": self.remote_path}

@classmethod
def _deserialize(cls, value) -> "FlyteSchema":
remote_path = value.get("remote_path", None)

if remote_path is None:
raise ValueError("FlyteSchema's path should not be None")

t = FlyteSchemaTransformer()
return t.to_python_value(
FlyteContextManager.current_context(),
Literal(scalar=Scalar(schema=Schema(remote_path, t._get_schema_type(cls)))),
cls,
)

@classmethod
def columns(cls) -> typing.Dict[str, typing.Type]:
return {}
Expand Down Expand Up @@ -219,6 +237,18 @@ class _TypedSchema(FlyteSchema):
# Get the type engine to see this as kind of a generic
__origin__ = FlyteSchema

class AttributeHider:
def __get__(self, instance, owner):
raise AttributeError(
"""We have to return false in hasattr(cls, "__class_getitem__") to make mashumaro deserialize FlyteSchema correctly."""
)

# Set __class_getitem__ to AttributeHider to make mashumaro deserialize FlyteSchema correctly
# https://stackoverflow.com/questions/6057130/python-deleting-a-class-attribute-in-a-subclass/6057409
# Since mashumaro will use the method __class_getitem__ and __origin__ to construct the dataclass back
# https://github.com/Fatal1ty/mashumaro/blob/e945ee4319db49da9f7b8ede614e988cc8c8956b/mashumaro/core/meta/helpers.py#L300-L303
__class_getitem__ = AttributeHider() # type: ignore

@classmethod
def columns(cls) -> typing.Dict[str, typing.Type]:
return columns
Expand Down
37 changes: 36 additions & 1 deletion flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fsspec.utils import get_protocol
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
from mashumaro.types import SerializableType
from typing_extensions import Annotated, TypeAlias, get_args, get_origin

from flytekit import lazy_module
Expand Down Expand Up @@ -45,7 +46,7 @@


@dataclass
class StructuredDataset(DataClassJSONMixin):
class StructuredDataset(SerializableType, DataClassJSONMixin):
"""
This is the user facing StructuredDataset class. Please don't confuse it with the literals.StructuredDataset
class (that is just a model, a Python class representation of the protobuf).
Expand All @@ -54,6 +55,40 @@ class (that is just a model, a Python class representation of the protobuf).
uri: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String()))
file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String()))

def _serialize(self) -> Dict[str, Optional[str]]:
lv = StructuredDatasetTransformerEngine().to_literal(
FlyteContextManager.current_context(), self, StructuredDataset, None
)
sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
return {
"uri": sd.uri,
"file_format": sd.file_format,
}

@classmethod
def _deserialize(cls, value) -> "StructuredDataset":
uri = value.get("uri", None)
file_format = value.get("file_format", None)

if uri is None:
raise ValueError("StructuredDataset's uri and file format should not be None")

return StructuredDatasetTransformerEngine().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
structured_dataset=StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=file_format)
),
uri=uri,
)
)
),
cls,
)

@classmethod
def columns(cls) -> typing.Dict[str, typing.Type]:
return {}
Expand Down
Loading

0 comments on commit 14b1e38

Please sign in to comment.