Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pydantic Transformer V2 #2792

Merged
merged 110 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
110 commits
Select commit Hold shift + click to select a range
333a05c
Pydantic Transformer V2
Future-Outlier Oct 8, 2024
9251508
add __init__.py
Future-Outlier Oct 8, 2024
4c46dee
add json schema
Future-Outlier Oct 8, 2024
63ab9fd
convert float to int
Future-Outlier Oct 8, 2024
357ca00
change gitsha in test script mode
Future-Outlier Oct 9, 2024
cdd1b25
change gitsha
Future-Outlier Oct 9, 2024
c89b59a
use strict map=false
Future-Outlier Oct 9, 2024
6b480da
Test flytefile console input + attr access
Future-Outlier Oct 9, 2024
15ce9ad
add conditional branch
Future-Outlier Oct 9, 2024
94ce092
better rx
Future-Outlier Oct 9, 2024
d97bc2a
Add flytedir generic -> flytedir
Future-Outlier Oct 9, 2024
4b0f008
merge async type engine
Future-Outlier Oct 10, 2024
afd4344
support enum
Future-Outlier Oct 10, 2024
773a3b6
update
Future-Outlier Oct 10, 2024
7afbb00
add tests for input from flyte console
Future-Outlier Oct 10, 2024
0258853
Add Tests for dataclass in BaseModel and pydantic.dataclass in BaseModel
Future-Outlier Oct 11, 2024
f0fa0b9
update
Future-Outlier Oct 11, 2024
f4e7581
Merge branch 'master' into pydantic-plugin-v2-with-msgpack
Future-Outlier Oct 11, 2024
dac2da4
update thomas's advice
Future-Outlier Oct 11, 2024
3531201
change tree file structure
Future-Outlier Oct 11, 2024
e34ef69
update Niel's advice
Future-Outlier Oct 11, 2024
2755c14
> to >=
Future-Outlier Oct 11, 2024
69e455d
try monodoc build again
Future-Outlier Oct 11, 2024
93e74ae
add pydantic README.md
Future-Outlier Oct 11, 2024
77ded14
revert -vvv in monodocs
Future-Outlier Oct 11, 2024
8e94aed
use model_validate_json to turn protobuf struct to python val
Future-Outlier Oct 11, 2024
f9e95f6
fix issue
Future-Outlier Oct 11, 2024
a37f8e3
handle flyte types in dict transformer from protobuf struct input (e.…
Future-Outlier Oct 14, 2024
26a848b
Add print
Future-Outlier Oct 14, 2024
029d159
expected python type
Future-Outlier Oct 14, 2024
f175218
switch call function order
Future-Outlier Oct 14, 2024
cbff823
try msgpack to handle protobug struct
Future-Outlier Oct 14, 2024
ae380fd
Better Comment in Dict Transformer
Future-Outlier Oct 14, 2024
28ef345
Propeller -> FlytePropeller
Future-Outlier Oct 14, 2024
55334e9
dict_to_flyte_types
Future-Outlier Oct 14, 2024
169ac0c
remove comments
Future-Outlier Oct 14, 2024
3613e0c
add attr for protobuf struct . dict
Future-Outlier Oct 14, 2024
ffce49d
Add Life Cycle for Flyte Types
Future-Outlier Oct 14, 2024
5c9be13
better comments for derializing flyteschema and sd
Future-Outlier Oct 14, 2024
429ccd5
nit
Future-Outlier Oct 15, 2024
8a8c6ce
add back pv._remote_path = None to flytefile and flytedir
Future-Outlier Oct 15, 2024
81b2169
experiment
Future-Outlier Oct 15, 2024
d72d72d
experiment
Future-Outlier Oct 15, 2024
6d6c112
experiment
Future-Outlier Oct 15, 2024
a860803
Add comments
Future-Outlier Oct 15, 2024
fb82dd5
update Yee's advice
Future-Outlier Oct 15, 2024
8eb45ff
code dc -> bm
Future-Outlier Oct 15, 2024
2edd542
Example dc -> bm and Example all flyte types
Future-Outlier Oct 15, 2024
989e6e0
fix union dataclass, not yet add comments
Future-Outlier Oct 15, 2024
ffd3aa2
solve conflict
Future-Outlier Oct 15, 2024
f699419
add pydantic and dataclass optional test
Future-Outlier Oct 15, 2024
86b34ab
NoneType=type(None)
Future-Outlier Oct 15, 2024
ebe61db
fix union transformer none case with Eduardo
Future-Outlier Oct 15, 2024
ac624d4
add comments for none type transformer + union transformer
Future-Outlier Oct 15, 2024
9982e8c
add TODO
Future-Outlier Oct 15, 2024
42ab3d3
Merge branch 'master' into pydantic-plugin-v2-with-msgpack
Future-Outlier Oct 17, 2024
85839e9
use deserailize = True
Future-Outlier Oct 17, 2024
f188415
add all deserialize
Future-Outlier Oct 17, 2024
90771fd
better comments
Future-Outlier Oct 17, 2024
1d8fe55
better comments
Future-Outlier Oct 17, 2024
e1b3f4f
test
Future-Outlier Oct 17, 2024
729ba27
Fix flyte directory issue by discussion with Kevin
Future-Outlier Oct 17, 2024
5bd3616
merge master
Future-Outlier Oct 18, 2024
9883c7f
add tests for providing conext when doing serialization
Future-Outlier Oct 18, 2024
79c9dc7
Merge branch 'master' into pydantic-plugin-v2-with-msgpack
Future-Outlier Oct 18, 2024
047b6ac
lint
Future-Outlier Oct 18, 2024
26a3745
test
Future-Outlier Oct 19, 2024
f07ec94
move setattr to core
Future-Outlier Oct 19, 2024
a64efa0
remove comments
Future-Outlier Oct 19, 2024
6394175
lint
Future-Outlier Oct 19, 2024
d6707a0
remove
Future-Outlier Oct 19, 2024
6ab30a0
testing
Future-Outlier Oct 19, 2024
3f39657
Merge branch 'master' into pydantic-plugin-v2-with-msgpack
Future-Outlier Oct 19, 2024
f2e165a
nit
Future-Outlier Oct 19, 2024
8e848bf
flytekit/core/type_engine.py
Future-Outlier Oct 19, 2024
7d291e3
add tests for Union
Future-Outlier Oct 21, 2024
5374a14
Trigger CI
Future-Outlier Oct 21, 2024
7d045be
remove nonetype
Future-Outlier Oct 21, 2024
8dd746a
raw=Fasle as default
Future-Outlier Oct 21, 2024
eba2bf0
Merge branch 'master' into pydantic-plugin-v2-with-msgpack
Future-Outlier Oct 22, 2024
753b240
pydantic move to core test
Future-Outlier Oct 22, 2024
a380def
move to core
Future-Outlier Oct 22, 2024
971aa47
log
Future-Outlier Oct 22, 2024
ff2d4a0
update
Future-Outlier Oct 22, 2024
6c3450b
lint
Future-Outlier Oct 22, 2024
5a1d58b
test
Future-Outlier Oct 22, 2024
4cfde94
nit
Future-Outlier Oct 22, 2024
e0d0a76
nit
Future-Outlier Oct 22, 2024
76c1d56
lint
Future-Outlier Oct 22, 2024
7972f95
move to type_engine
Future-Outlier Oct 22, 2024
6de3478
move back to init
Future-Outlier Oct 22, 2024
7c4c009
update kevin's advice
Future-Outlier Oct 23, 2024
4fc0622
wip
Future-Outlier Oct 23, 2024
edfc8ef
use decorator
Future-Outlier Oct 23, 2024
d8e4c6a
Merge branch 'master' into pydantic-plugin-v2-with-msgpack
Future-Outlier Oct 23, 2024
6c5b19f
decorator
Future-Outlier Oct 23, 2024
76ae0ef
fix syntax to support python 3.9
Future-Outlier Oct 24, 2024
dfe8762
add Eduardo's advice
Future-Outlier Oct 25, 2024
7735352
warning
Future-Outlier Oct 25, 2024
959f02b
Merge branch 'master' into pydantic-plugin-v2-with-msgpack
Future-Outlier Oct 25, 2024
85df643
Update Yee's advice
Future-Outlier Oct 28, 2024
f440dab
Show traceback by default (#2862)
pingsutw Oct 25, 2024
4f5f74e
Support Identifier in generate_console_url (#2868)
thomasjpfan Oct 25, 2024
edf995c
Support overriding node metadata for array node (#2865)
pvditt Oct 25, 2024
3dbb336
Fix Jupyter Versioning (#2866)
Mecoli1219 Oct 26, 2024
2bccd1b
improved output handling in notebooks (#2869)
kumare3 Oct 27, 2024
bf283f6
tests by Yee's suggestion
Future-Outlier Oct 28, 2024
6993f09
fix tests
Future-Outlier Oct 28, 2024
a3fef67
Merge branch 'master' into pydantic-plugin-v2-with-msgpack
Future-Outlier Oct 28, 2024
d983c07
format tests
Future-Outlier Oct 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2052,6 +2052,17 @@ def _check_and_covert_float(lv: Literal) -> float:
raise TypeTransformerFailedError(f"Cannot convert literal {lv} to float")


def _check_and_covert_int(lv: Literal) -> int:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given this is the main function for the integer simple transformer, can we give this a more meaningful name? Will we delete this function after console is updated?

Copy link
Member Author

@Future-Outlier Future-Outlier Oct 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. yes I'm going to provide a more meaningful name
  2. If we use protobuf struct as input from the console, this can't be deleted.
    We have to provide an identifier to the console, then the console knows it need to generate msgpack idl, then it can be deleted.

if lv.scalar.primitive.integer is not None:
return lv.scalar.primitive.integer

if lv.scalar.primitive.float_value is not None:
logger.info(f"Converting literal float {lv.scalar.primitive.float_value} to int, might have precision loss.")
return int(lv.scalar.primitive.float_value)
Comment on lines +2329 to +2331
Copy link
Member Author

@Future-Outlier Future-Outlier Oct 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for cases when you input from the flyte console, and you use attribute access directly, you have to convert the float to int.
Since javascript has only number, it can't tell the difference between int and float, and when goland (propeller) doing attribute access, it doesn't have the expected python type

class TrainConfig(BaseModel):
    lr: float = 1e-3
    batch_size: int = 32

@workflow
def wf(cfg: TrainConfig) -> TrainConfig:
    return t_args(a=cfg.lr, batch_size=cfg.batch_size)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the javascript issue and the attribute access issue are orthogonal right?

this should only be a javascript problem. attribute access should work since msgpack preserves float/int even in attribute access correct?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

YES, the attribute access works well, it's because javascript pass float to golang, and golang pass float to python.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should only be a javascript problem. attribute access should work since msgpack preserves float/int even in attribute access correct?

Yes, but when you are accessing a simple type, you have to change the behavior of SimpleTransformer.

For Pydantic Transformer, we will use strict=False as argument to convert it to right type.

    def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
        if binary_idl_object.tag == MESSAGEPACK:
            dict_obj = msgpack.loads(binary_idl_object.value)
            python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
            return python_val

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we can delete this part after console is updated right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we can guarantee the console can generate an integer but not float from the input, then we can delete it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how is this going to work though? Do we also do a version check of the backend?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After console does the right thing, won't this value be coming in through the binary value instead? Instead of lv.scalar.primitive.integer/float.


raise TypeTransformerFailedError(f"Cannot convert literal {lv} to int")


def _check_and_convert_void(lv: Literal) -> None:
if lv.scalar.none_type is None:
raise TypeTransformerFailedError(f"Cannot convert literal {lv} to None")
Expand All @@ -2065,7 +2076,7 @@ def _register_default_type_transformers():
int,
_type_models.LiteralType(simple=_type_models.SimpleType.INTEGER),
lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))),
lambda x: x.scalar.primitive.integer,
_check_and_covert_int,
)
)

Expand Down
5 changes: 4 additions & 1 deletion flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def is_pydantic_basemodel(python_type: typing.Type) -> bool:
return False
else:
try:
from pydantic.v1 import BaseModel
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1

return issubclass(python_type, BaseModelV1) or issubclass(python_type, BaseModelV2)
Comment on lines +43 to +46
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for backward compatible

except ImportError:
from pydantic import BaseModel

Expand Down
3 changes: 2 additions & 1 deletion flytekit/types/directory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
:template: file_types.rst

FlyteDirectory
FlyteDirToMultipartBlobTransformer
TensorboardLogs
TFRecordsDirectory
"""

import typing

from .types import FlyteDirectory
from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to import FlyteDirToMultipartBlobTransformer in the pydantic plugin, we have to import here.


# The following section provides some predefined aliases for commonly used FlyteDirectory formats.

Expand Down
3 changes: 2 additions & 1 deletion flytekit/types/file/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
:template: file_types.rst

FlyteFile
FlyteFilePathTransformer
HDF5EncodedFile
HTMLPage
JoblibSerializedFile
Expand All @@ -25,7 +26,7 @@

from typing_extensions import Annotated, get_args, get_origin

from .file import FlyteFile
from .file import FlyteFile, FlyteFilePathTransformer


class FileExt:
Expand Down
37 changes: 35 additions & 2 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import mimetypes
import os
import pathlib
Expand All @@ -11,6 +12,8 @@

import msgpack
from dataclasses_json import config
from google.protobuf import json_format as _json_format
from google.protobuf.struct_pb2 import Struct
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
from mashumaro.types import SerializableType
Expand Down Expand Up @@ -549,12 +552,42 @@ def from_binary_idl(
else:
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

def from_generic_idl(
self, generic: Struct, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike]
):
json_str = _json_format.MessageToJson(generic)
python_val = json.loads(json_str)
path = python_val.get("path", None)

if path is None:
raise ValueError("FlyteFile's path should not be None")

return FlyteFilePathTransformer().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
blob=Blob(
metadata=BlobMetadata(
type=_core_types.BlobType(
format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
),
uri=path,
)
)
),
expected_python_type,
)

def to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike]
) -> FlyteFile:
# Handle dataclass attribute access
if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type)
if lv.scalar:
if lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type)
if lv.scalar.generic:
return self.from_generic_idl(lv.scalar.generic, expected_python_type)
Comment on lines +655 to +659
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class DC(BaseModel):
    ff: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))

@task(container_image=image)
def t_args(dc: DC) -> DC:
    with open(dc.ff, "r") as f:
        print(f.read())
    return dc

@task(container_image=image)
def t_ff(ff: FlyteFile) -> FlyteFile:
    with open(ff, "r") as f:
        print(f.read())
    return ff

@workflow
def wf(dc: DC) -> DC:
    t_ff(dc.ff)
    return t_args(dc=dc)

this is for this case input from flyteconsole.


try:
uri = lv.scalar.blob.uri
Expand Down
1 change: 1 addition & 0 deletions flytekit/types/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .types import (
FlyteSchema,
FlyteSchemaTransformer,
LocalIOSchemaReader,
LocalIOSchemaWriter,
SchemaEngine,
Expand Down
11 changes: 8 additions & 3 deletions flytekit/types/structured/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
:template: custom.rst
:toctree: generated/

StructuredDataset
StructuredDatasetEncoder
StructuredDatasetDecoder
StructuredDataset
StructuredDatasetDecoder
StructuredDatasetEncoder
StructuredDatasetMetadata
StructuredDatasetTransformerEngine
StructuredDatasetType
"""

from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer
Expand All @@ -19,7 +22,9 @@
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
StructuredDatasetMetadata,
StructuredDatasetTransformerEngine,
StructuredDatasetType,
)


Expand Down
1 change: 1 addition & 0 deletions plugins/flytekit-pydantic-v2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# TMP
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset

from . import transformer
from .custom import (
deserialize_flyte_dir,
deserialize_flyte_file,
deserialize_flyte_schema,
deserialize_structured_dataset,
serialize_flyte_dir,
serialize_flyte_file,
serialize_flyte_schema,
serialize_structured_dataset,
)

setattr(FlyteFile, "serialize_flyte_file", serialize_flyte_file)
setattr(FlyteFile, "deserialize_flyte_file", deserialize_flyte_file)
setattr(FlyteDirectory, "serialize_flyte_dir", serialize_flyte_dir)
setattr(FlyteDirectory, "deserialize_flyte_dir", deserialize_flyte_dir)
setattr(FlyteSchema, "serialize_flyte_schema", serialize_flyte_schema)
setattr(FlyteSchema, "deserialize_flyte_schema", deserialize_flyte_schema)
setattr(StructuredDataset, "serialize_structured_dataset", serialize_structured_dataset)
setattr(StructuredDataset, "deserialize_structured_dataset", deserialize_structured_dataset)
130 changes: 130 additions & 0 deletions plugins/flytekit-pydantic-v2/flytekitplugins/pydantic/v2/custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from typing import Dict

from flytekit.core.context_manager import FlyteContextManager
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Schema
from flytekit.types.directory import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from flytekit.types.file import FlyteFile, FlyteFilePathTransformer
from flytekit.types.schema import FlyteSchema, FlyteSchemaTransformer
from flytekit.types.structured import (
StructuredDataset,
StructuredDatasetMetadata,
StructuredDatasetTransformerEngine,
StructuredDatasetType,
)
from pydantic import model_serializer, model_validator


@model_serializer
def serialize_flyte_file(self) -> Dict[str, str]:
lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"path": lv.scalar.blob.uri}


@model_validator(mode="after")
def deserialize_flyte_file(self) -> FlyteFile:
pv = FlyteFilePathTransformer().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
blob=Blob(
metadata=BlobMetadata(
type=_core_types.BlobType(
format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
),
uri=self.path,
)
)
),
type(self),
)
pv._remote_path = None
return pv


@model_serializer
def serialize_flyte_dir(self) -> Dict[str, str]:
lv = FlyteDirToMultipartBlobTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"path": lv.scalar.blob.uri}


@model_validator(mode="after")
def deserialize_flyte_dir(self) -> FlyteDirectory:
pv = FlyteDirToMultipartBlobTransformer().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
blob=Blob(
metadata=BlobMetadata(
type=_core_types.BlobType(
format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
)
),
uri=self.path,
)
)
),
type(self),
)
pv._remote_directory = None
return pv


@model_serializer
def serialize_flyte_schema(self) -> Dict[str, str]:
FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"remote_path": self.remote_path}


@model_validator(mode="after")
def deserialize_flyte_schema(self) -> FlyteSchema:
# If we call the method to_python_value, FlyteSchemaTransformer will overwrite the local_path,
# which will lose our data.
# If this data is from an existed FlyteSchema, local path will be None.

if hasattr(self, "_local_path"):
return self

t = FlyteSchemaTransformer()
return t.to_python_value(
FlyteContextManager.current_context(),
Literal(scalar=Scalar(schema=Schema(self.remote_path, t._get_schema_type(type(self))))),
type(self),
)


@model_serializer
def serialize_structured_dataset(self) -> Dict[str, str]:
lv = StructuredDatasetTransformerEngine().to_literal(FlyteContextManager.current_context(), self, type(self), None)
sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
return {
"uri": sd.uri,
"file_format": sd.file_format,
}


@model_validator(mode="after")
def deserialize_structured_dataset(self) -> StructuredDataset:
# If we call the method to_python_value, StructuredDatasetTransformerEngine will overwrite the 'dataframe',
# which will lose our data.
# If this data is from an existed StructuredDataset, dataframe will be None.

if hasattr(self, "dataframe"):
return self

return StructuredDatasetTransformerEngine().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
structured_dataset=StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=self.file_format)
),
uri=self.uri,
)
)
),
type(self),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import json
from typing import Type

import msgpack
from google.protobuf import json_format as _json_format

from flytekit import FlyteContext
from flytekit.core.constants import MESSAGEPACK
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.loggers import logger
from flytekit.models import types
from flytekit.models.literals import Binary, Literal, Scalar
from flytekit.models.types import LiteralType, TypeStructure
from pydantic import BaseModel


class PydanticTransformer(TypeTransformer[BaseModel]):
def __init__(self):
super().__init__("Pydantic Transformer", BaseModel, enable_type_assertions=False)

def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
schema = t.model_json_schema()
literal_type = {}
fields = t.__annotations__.items()

for name, python_type in fields:
try:
literal_type[name] = TypeEngine.to_literal_type(python_type)
except Exception as e:
logger.warning(
"Field {} of type {} cannot be converted to a literal type. Error: {}".format(name, python_type, e)
)

ts = TypeStructure(tag="", dataclass_type=literal_type)

return types.LiteralType(simple=types.SimpleType.STRUCT, metadata=schema, structure=ts)

def to_literal(
self,
ctx: FlyteContext,
python_val: BaseModel,
python_type: Type[BaseModel],
expected: types.LiteralType,
) -> Literal:
dict_obj = python_val.model_dump()
msgpack_bytes = msgpack.dumps(dict_obj)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
if binary_idl_object.tag == MESSAGEPACK:
dict_obj = msgpack.loads(binary_idl_object.value, raw=False, strict_map_key=False)
python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
return python_val
else:
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel:
"""
There will have 2 kinds of literal values:
1. protobuf Struct (From Flyte Console)
2. binary scalar (Others)
Hence we have to handle 2 kinds of cases.
"""
if lv and lv.scalar and lv.scalar.binary is not None:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

json_str = _json_format.MessageToJson(lv.scalar.generic)
dict_obj = json.loads(json_str)
python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
return python_val


TypeEngine.register(PydanticTransformer())
Loading
Loading