Pydantic Transformer V2 #2792
```diff
@@ -2052,6 +2052,17 @@ def _check_and_covert_float(lv: Literal) -> float:
     raise TypeTransformerFailedError(f"Cannot convert literal {lv} to float")
 
 
+def _check_and_covert_int(lv: Literal) -> int:
+    if lv.scalar.primitive.integer is not None:
+        return lv.scalar.primitive.integer
+
+    if lv.scalar.primitive.float_value is not None:
+        logger.info(f"Converting literal float {lv.scalar.primitive.float_value} to int, might have precision loss.")
+        return int(lv.scalar.primitive.float_value)
+
+    raise TypeTransformerFailedError(f"Cannot convert literal {lv} to int")
+
+
 def _check_and_convert_void(lv: Literal) -> None:
     if lv.scalar.none_type is None:
         raise TypeTransformerFailedError(f"Cannot convert literal {lv} to None")
@@ -2065,7 +2076,7 @@ def _register_default_type_transformers():
             int,
             _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER),
             lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))),
-            lambda x: x.scalar.primitive.integer,
+            _check_and_covert_int,
         )
     )
```

Comment on lines +2329 to +2331:

Author: This is for cases when the input comes from the Flyte console and you use attribute access directly; then you have to convert the value. For example:

```python
class TrainConfig(BaseModel):
    lr: float = 1e-3
    batch_size: int = 32

@workflow
def wf(cfg: TrainConfig) -> TrainConfig:
    return t_args(a=cfg.lr, batch_size=cfg.batch_size)
```

Reviewer: The JavaScript issue and the attribute-access issue are orthogonal, right? This should only be a JavaScript problem. Attribute access should work, since msgpack preserves float/int even in attribute access, correct?

Author: Yes, attribute access works well. The problem is that JavaScript passes a float to Golang, and Golang passes the float to Python.

Author: Yes, but when you are accessing a simple type, you have to change this behavior. For the Pydantic Transformer, we will use:

```python
def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
    if binary_idl_object.tag == MESSAGEPACK:
        dict_obj = msgpack.loads(binary_idl_object.value)
        python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
        return python_val
```

Reviewer: So we can delete this part after the console is updated, right?

Author: If we can guarantee the console generates an integer rather than a float from the input, then we can delete it.

Reviewer: How is this going to work, though? Do we also do a version check of the backend?

Reviewer: After the console does the right thing, won't this value be coming in through the binary value instead?
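The behavior being debated above can be illustrated without flytekit. The helper below is a hypothetical, stdlib-only stand-in for the fallback in `_check_and_covert_int` (the `Literal` unwrapping and the `logger.info` call are omitted): protobuf's `Struct` stores every number as a double, so an integer typed into the console reaches Python as a float, and the fallback has to downcast it.

```python
import json

def check_and_convert_int(value):
    # Hypothetical stand-in for _check_and_covert_int's fallback: accept an
    # int as-is, downcast a console-supplied float (possible precision
    # loss), and reject anything else.
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    if isinstance(value, float):
        return int(value)
    raise TypeError(f"Cannot convert {value!r} to int")

# Plain JSON keeps 32 an int, but a protobuf Struct would hand back 32.0.
assert type(json.loads("32")) is int
assert type(json.loads("32.0")) is float
assert check_and_convert_int(32.0) == 32
assert check_and_convert_int(7) == 7
```

Once the console emits msgpack-tagged binary literals instead of a `Struct`, the int/float distinction survives the round trip and this downcast becomes unnecessary, which is what the thread above converges on.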
```diff
@@ -36,7 +36,10 @@ def is_pydantic_basemodel(python_type: typing.Type) -> bool:
         return False
     else:
         try:
-            from pydantic.v1 import BaseModel
+            from pydantic import BaseModel as BaseModelV2
+            from pydantic.v1 import BaseModel as BaseModelV1
+
+            return issubclass(python_type, BaseModelV1) or issubclass(python_type, BaseModelV2)
         except ImportError:
             from pydantic import BaseModel
```

Comment on lines +43 to +46:

Author: For backward compatibility.
```diff
@@ -10,13 +10,14 @@
     :template: file_types.rst
 
     FlyteDirectory
+    FlyteDirToMultipartBlobTransformer
     TensorboardLogs
     TFRecordsDirectory
 """
 
 import typing
 
-from .types import FlyteDirectory
+from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
 
 # The following section provides some predefined aliases for commonly used FlyteDirectory formats.
```

Comment: To import `FlyteDirToMultipartBlobTransformer`.
```diff
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 import mimetypes
 import os
 import pathlib
@@ -11,6 +12,8 @@
 
 import msgpack
 from dataclasses_json import config
+from google.protobuf import json_format as _json_format
+from google.protobuf.struct_pb2 import Struct
 from marshmallow import fields
 from mashumaro.mixins.json import DataClassJSONMixin
 from mashumaro.types import SerializableType
@@ -549,12 +552,42 @@ def from_binary_idl(
         else:
             raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
 
+    def from_generic_idl(
+        self, generic: Struct, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike]
+    ):
+        json_str = _json_format.MessageToJson(generic)
+        python_val = json.loads(json_str)
+        path = python_val.get("path", None)
+
+        if path is None:
+            raise ValueError("FlyteFile's path should not be None")
+
+        return FlyteFilePathTransformer().to_python_value(
+            FlyteContextManager.current_context(),
+            Literal(
+                scalar=Scalar(
+                    blob=Blob(
+                        metadata=BlobMetadata(
+                            type=_core_types.BlobType(
+                                format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
+                            )
+                        ),
+                        uri=path,
+                    )
+                )
+            ),
+            expected_python_type,
+        )
+
     def to_python_value(
         self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike]
     ) -> FlyteFile:
         # Handle dataclass attribute access
-        if lv.scalar and lv.scalar.binary:
-            return self.from_binary_idl(lv.scalar.binary, expected_python_type)
+        if lv.scalar:
+            if lv.scalar.binary:
+                return self.from_binary_idl(lv.scalar.binary, expected_python_type)
+            if lv.scalar.generic:
+                return self.from_generic_idl(lv.scalar.generic, expected_python_type)
 
         try:
             uri = lv.scalar.blob.uri
```

Comment on lines +655 to +659:

Author: This is for the case of input from the console, for example:

```python
class DC(BaseModel):
    ff: FlyteFile = Field(default_factory=lambda: FlyteFile("s3://my-s3-bucket/example.txt"))

@task(container_image=image)
def t_args(dc: DC) -> DC:
    with open(dc.ff, "r") as f:
        print(f.read())
    return dc

@task(container_image=image)
def t_ff(ff: FlyteFile) -> FlyteFile:
    with open(ff, "r") as f:
        print(f.read())
    return ff

@workflow
def wf(dc: DC) -> DC:
    t_ff(dc.ff)
    return t_args(dc=dc)
```
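The essence of `from_generic_idl` above, minus the flytekit `Literal` reconstruction, can be sketched with the stdlib alone. `path_from_generic` is a hypothetical helper, not part of the PR; its input stands in for the JSON that `_json_format.MessageToJson` produces from the console's protobuf `Struct`.

```python
import json

def path_from_generic(json_str: str) -> str:
    # Mirrors the generic-IDL path: the Struct has already been rendered
    # to JSON; recover the dict and pull out the "path" key.
    python_val = json.loads(json_str)
    path = python_val.get("path", None)
    if path is None:
        raise ValueError("FlyteFile's path should not be None")
    return path

assert path_from_generic('{"path": "s3://my-s3-bucket/example.txt"}') == "s3://my-s3-bucket/example.txt"
```

From that path string, the real transformer rebuilds a single-dimension `Blob` literal and delegates to the ordinary `to_python_value` flow.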
New file:

```diff
@@ -0,0 +1 @@
+# TMP
```
New file (all lines added):

```python
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset

from . import transformer
from .custom import (
    deserialize_flyte_dir,
    deserialize_flyte_file,
    deserialize_flyte_schema,
    deserialize_structured_dataset,
    serialize_flyte_dir,
    serialize_flyte_file,
    serialize_flyte_schema,
    serialize_structured_dataset,
)

setattr(FlyteFile, "serialize_flyte_file", serialize_flyte_file)
setattr(FlyteFile, "deserialize_flyte_file", deserialize_flyte_file)
setattr(FlyteDirectory, "serialize_flyte_dir", serialize_flyte_dir)
setattr(FlyteDirectory, "deserialize_flyte_dir", deserialize_flyte_dir)
setattr(FlyteSchema, "serialize_flyte_schema", serialize_flyte_schema)
setattr(FlyteSchema, "deserialize_flyte_schema", deserialize_flyte_schema)
setattr(StructuredDataset, "serialize_structured_dataset", serialize_structured_dataset)
setattr(StructuredDataset, "deserialize_structured_dataset", deserialize_structured_dataset)
```
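The `setattr` calls above follow a monkey-patching pattern: simply importing the plugin attaches pydantic serialization hooks to classes defined elsewhere in flytekit. A minimal stand-in (the `FileLike` class and `serialize_file_like` hook below are hypothetical, not flytekit's):

```python
class FileLike:
    # Stand-in for FlyteFile; only carries a path.
    def __init__(self, path: str):
        self.path = path

def serialize_file_like(self) -> dict:
    # Stand-in for serialize_flyte_file: report where the data lives.
    return {"path": self.path}

# Attach the hook after the fact, exactly as the plugin's __init__ does.
setattr(FileLike, "serialize_file_like", serialize_file_like)

ff = FileLike("s3://my-s3-bucket/example.txt")
assert ff.serialize_file_like() == {"path": "s3://my-s3-bucket/example.txt"}
```

The upside of this design is that the core `FlyteFile`/`FlyteDirectory` classes stay free of any pydantic dependency; only users who import the plugin pay for it.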
New file (all lines added):

```python
from typing import Dict

from flytekit.core.context_manager import FlyteContextManager
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Schema
from flytekit.types.directory import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from flytekit.types.file import FlyteFile, FlyteFilePathTransformer
from flytekit.types.schema import FlyteSchema, FlyteSchemaTransformer
from flytekit.types.structured import (
    StructuredDataset,
    StructuredDatasetMetadata,
    StructuredDatasetTransformerEngine,
    StructuredDatasetType,
)
from pydantic import model_serializer, model_validator


@model_serializer
def serialize_flyte_file(self) -> Dict[str, str]:
    lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
    return {"path": lv.scalar.blob.uri}


@model_validator(mode="after")
def deserialize_flyte_file(self) -> FlyteFile:
    pv = FlyteFilePathTransformer().to_python_value(
        FlyteContextManager.current_context(),
        Literal(
            scalar=Scalar(
                blob=Blob(
                    metadata=BlobMetadata(
                        type=_core_types.BlobType(
                            format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
                        )
                    ),
                    uri=self.path,
                )
            )
        ),
        type(self),
    )
    pv._remote_path = None
    return pv


@model_serializer
def serialize_flyte_dir(self) -> Dict[str, str]:
    lv = FlyteDirToMultipartBlobTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
    return {"path": lv.scalar.blob.uri}


@model_validator(mode="after")
def deserialize_flyte_dir(self) -> FlyteDirectory:
    pv = FlyteDirToMultipartBlobTransformer().to_python_value(
        FlyteContextManager.current_context(),
        Literal(
            scalar=Scalar(
                blob=Blob(
                    metadata=BlobMetadata(
                        type=_core_types.BlobType(
                            format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
                        )
                    ),
                    uri=self.path,
                )
            )
        ),
        type(self),
    )
    pv._remote_directory = None
    return pv


@model_serializer
def serialize_flyte_schema(self) -> Dict[str, str]:
    FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
    return {"remote_path": self.remote_path}


@model_validator(mode="after")
def deserialize_flyte_schema(self) -> FlyteSchema:
    # If we call to_python_value, FlyteSchemaTransformer will overwrite the
    # local_path, which would lose our data.
    # If this data is from an existing FlyteSchema, the local path will be None.
    if hasattr(self, "_local_path"):
        return self

    t = FlyteSchemaTransformer()
    return t.to_python_value(
        FlyteContextManager.current_context(),
        Literal(scalar=Scalar(schema=Schema(self.remote_path, t._get_schema_type(type(self))))),
        type(self),
    )


@model_serializer
def serialize_structured_dataset(self) -> Dict[str, str]:
    lv = StructuredDatasetTransformerEngine().to_literal(FlyteContextManager.current_context(), self, type(self), None)
    sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
    sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
    return {
        "uri": sd.uri,
        "file_format": sd.file_format,
    }


@model_validator(mode="after")
def deserialize_structured_dataset(self) -> StructuredDataset:
    # If we call to_python_value, StructuredDatasetTransformerEngine will
    # overwrite the 'dataframe', which would lose our data.
    # If this data is from an existing StructuredDataset, the dataframe will be None.
    if hasattr(self, "dataframe"):
        return self

    return StructuredDatasetTransformerEngine().to_python_value(
        FlyteContextManager.current_context(),
        Literal(
            scalar=Scalar(
                structured_dataset=StructuredDataset(
                    metadata=StructuredDatasetMetadata(
                        structured_dataset_type=StructuredDatasetType(format=self.file_format)
                    ),
                    uri=self.uri,
                )
            )
        ),
        type(self),
    )
```
New file (all lines added):

```python
import json
from typing import Type

import msgpack
from google.protobuf import json_format as _json_format

from flytekit import FlyteContext
from flytekit.core.constants import MESSAGEPACK
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.loggers import logger
from flytekit.models import types
from flytekit.models.literals import Binary, Literal, Scalar
from flytekit.models.types import LiteralType, TypeStructure
from pydantic import BaseModel


class PydanticTransformer(TypeTransformer[BaseModel]):
    def __init__(self):
        super().__init__("Pydantic Transformer", BaseModel, enable_type_assertions=False)

    def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
        schema = t.model_json_schema()
        literal_type = {}
        fields = t.__annotations__.items()

        for name, python_type in fields:
            try:
                literal_type[name] = TypeEngine.to_literal_type(python_type)
            except Exception as e:
                logger.warning(
                    "Field {} of type {} cannot be converted to a literal type. Error: {}".format(name, python_type, e)
                )

        ts = TypeStructure(tag="", dataclass_type=literal_type)

        return types.LiteralType(simple=types.SimpleType.STRUCT, metadata=schema, structure=ts)

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: BaseModel,
        python_type: Type[BaseModel],
        expected: types.LiteralType,
    ) -> Literal:
        dict_obj = python_val.model_dump()
        msgpack_bytes = msgpack.dumps(dict_obj)
        return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))

    def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
        if binary_idl_object.tag == MESSAGEPACK:
            dict_obj = msgpack.loads(binary_idl_object.value, raw=False, strict_map_key=False)
            python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
            return python_val
        else:
            raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel:
        """
        There are two kinds of literal values:
        1. protobuf Struct (from the Flyte console)
        2. binary scalar (everything else)
        Hence we have to handle both cases.
        """
        if lv and lv.scalar and lv.scalar.binary is not None:
            return self.from_binary_idl(lv.scalar.binary, expected_python_type)  # type: ignore

        json_str = _json_format.MessageToJson(lv.scalar.generic)
        dict_obj = json.loads(json_str)
        python_val = expected_python_type.model_validate(obj=dict_obj, strict=False)
        return python_val


TypeEngine.register(PydanticTransformer())
```
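The transformer's two decoding paths can be exercised end to end without flytekit or pydantic. Assuming the role of `model_validate(..., strict=False)` here is "coerce a raw dict into the declared field types", the hypothetical `validate_train_config` below plays that role, and the JSON string stands in for the output of `MessageToJson(lv.scalar.generic)`:

```python
import json

def validate_train_config(raw: dict) -> dict:
    # Hypothetical stand-in for TrainConfig.model_validate(strict=False):
    # coerce console-supplied doubles back into the declared field types.
    return {"lr": float(raw["lr"]), "batch_size": int(raw["batch_size"])}

# Console path: protobuf Struct -> JSON -> dict; every number is a float.
json_from_struct = '{"lr": 0.001, "batch_size": 32.0}'
cfg = validate_train_config(json.loads(json_from_struct))
assert cfg == {"lr": 0.001, "batch_size": 32}
assert type(cfg["batch_size"]) is int
```

The binary path needs no such coercion: msgpack round-trips ints as ints, which is why `to_literal` always emits a msgpack-tagged binary and the `Struct` branch is only a fallback for console input.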
Reviewer: Given this is the main function for the integer simple transformer, can we give it a more meaningful name? Will we delete this function after the console is updated?

Author: We have to provide an identifier to the console so that it knows it needs to generate msgpack IDL; then it can be deleted.