diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 6721e9afff..8230cf22c8 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -356,7 +356,6 @@ jobs: - flytekit-pandera - flytekit-papermill - flytekit-polars - - flytekit-pydantic - flytekit-ray - flytekit-snowflake - flytekit-spark diff --git a/dev-requirements.in b/dev-requirements.in index 27c17ac6d0..20aba11e9d 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -30,6 +30,7 @@ torch<=1.12.1; python_version<'3.11' # pytorch 2 supports python 3.11 # pytorch 2 does not support 3.12 yet: https://github.com/pytorch/pytorch/issues/110436 torch; python_version<'3.12' +pydantic # TODO: Currently, the python-magic library causes build errors on Windows due to its dependency on DLLs for libmagic. # We have temporarily disabled this feature on Windows and are using python-magic for Mac OS and Linux instead. diff --git a/dev-requirements.txt b/dev-requirements.txt index 5fd363804e..002f5421c4 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -22,6 +22,8 @@ aiosignal==1.3.1 # via aiohttp annotated-types==0.7.0 # via pydantic +appnope==0.1.4 + # via ipykernel asn1crypto==1.5.1 # via snowflake-connector-python asttokens==2.4.1 @@ -31,7 +33,6 @@ attrs==23.2.0 # aiohttp # hypothesis # jsonlines - # visions autoflake==2.3.1 # via -r dev-requirements.in azure-core==1.30.1 @@ -73,8 +74,8 @@ cloudpickle==3.0.0 # via flytekit codespell==2.3.0 # via -r dev-requirements.in -contourpy==1.3.0 - # via matplotlib +comm==0.2.2 + # via ipykernel coverage[toml]==7.5.3 # via # -r dev-requirements.in @@ -89,12 +90,10 @@ cryptography==43.0.1 # pyjwt # pyopenssl # snowflake-connector-python -cycler==0.12.1 - # via matplotlib -dacite==1.8.1 - # via ydata-profiling dataclasses-json==0.5.9 # via flytekit +debugpy==1.8.7 + # via ipykernel decorator==5.1.1 # via # gcsfs @@ -119,8 +118,6 @@ flyteidl @ git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteid # via # -r dev-requirements.in # flytekit -fonttools==4.54.1 - # via matplotlib frozenlist==1.4.1 # via # aiohttp @@ -185,8 +182,6 @@ grpcio-status==1.62.2 # via # flytekit # google-api-core -htmlmin==0.1.12 - # via ydata-profiling hypothesis==6.103.0 # via -r dev-requirements.in icdiff==2.0.7 @@ -198,16 +193,16 @@ idna==3.7 # requests # snowflake-connector-python # yarl -imagehash==4.3.1 - # via - # visions - # ydata-profiling importlib-metadata==7.1.0 # via flytekit iniconfig==2.0.0 # via pytest -ipython==8.25.0 +ipykernel==6.29.5 # via -r dev-requirements.in +ipython==8.25.0 + # via + # -r dev-requirements.in + # ipykernel isodate==0.6.1 # via azure-storage-blob jaraco-classes==3.4.0 @@ -222,38 +217,35 @@ jaraco-functools==4.0.1 # via keyring jedi==0.19.1 # via ipython -jinja2==3.1.4 - # via ydata-profiling jmespath==1.0.1 # via botocore joblib==1.4.2 # via # -r dev-requirements.in # flytekit - # phik # scikit-learn jsonlines==4.0.0 # via flytekit jsonpickle==3.0.4 # via flytekit +jupyter-client==8.6.3 + # via + # -r dev-requirements.in + # ipykernel +jupyter-core==5.7.2 + # via + # ipykernel + # jupyter-client keyring==25.2.1 # via flytekit keyrings-alt==5.0.1 # via -r dev-requirements.in -kiwisolver==1.4.7 - # via matplotlib kubernetes==29.0.0 # via -r dev-requirements.in -llvmlite==0.43.0 - # via numba -markdown==3.7 - # via -r dev-requirements.in markdown-it-py==3.0.0 # via # flytekit # rich -markupsafe==2.1.5 - # via jinja2 marshmallow==3.21.2 # via # dataclasses-json @@ -267,14 +259,10 @@ marshmallow-jsonschema==0.13.0 # via flytekit mashumaro==3.13 # via flytekit -matplotlib==3.9.2 - # via - # phik - # seaborn - # wordcloud - # ydata-profiling matplotlib-inline==0.1.7 - # via ipython + # via + # ipykernel + # ipython mdurl==0.1.2 # via markdown-it-py mock==5.1.0 @@ -290,45 +278,29 @@ msal==1.28.0 # msal-extensions msal-extensions==1.1.0 # via azure-identity +msgpack==1.1.0 + # via flytekit multidict==6.0.5 # via # aiohttp # yarl -multimethod==1.12 - # via - # visions - # ydata-profiling mypy==1.6.1 # via -r dev-requirements.in mypy-extensions==1.0.0 # via # mypy # typing-inspect -networkx==3.3 - # via visions +nest-asyncio==1.6.0 + # via ipykernel nodeenv==1.9.0 # via pre-commit -numba==0.60.0 - # via ydata-profiling numpy==1.26.4 # via # -r dev-requirements.in - # contourpy - # imagehash - # matplotlib - # numba # pandas - # patsy - # phik # pyarrow - # pywavelets # scikit-learn # scipy - # seaborn - # statsmodels - # visions - # wordcloud - # ydata-profiling oauthlib==3.2.2 # via # kubernetes @@ -339,43 +311,25 @@ packaging==24.0 # via # docker # google-cloud-bigquery + # ipykernel # marshmallow - # matplotlib # msal-extensions - # plotly # pytest # setuptools-scm # snowflake-connector-python - # statsmodels pandas==2.2.2 - # via - # -r dev-requirements.in - # phik - # seaborn - # statsmodels - # visions - # ydata-profiling + # via -r dev-requirements.in parso==0.8.4 # via jedi -patsy==0.5.6 - # via statsmodels pexpect==4.9.0 # via ipython -phik==0.12.4 - # via ydata-profiling pillow==10.3.0 - # via - # -r dev-requirements.in - # imagehash - # matplotlib - # visions - # wordcloud + # via -r dev-requirements.in platformdirs==4.2.2 # via + # jupyter-core # snowflake-connector-python # virtualenv -plotly==5.24.1 - # via -r dev-requirements.in pluggy==1.5.0 # via pytest portalocker==2.8.2 @@ -405,6 +359,8 @@ protobuf==4.25.3 # protoc-gen-openapiv2 protoc-gen-openapiv2==0.0.1 # via flyteidl +psutil==6.1.0 + # via ipykernel ptyprocess==0.7.0 # via pexpect pure-eval==0.2.2 @@ -420,14 +376,13 @@ pyasn1-modules==0.4.0 pycparser==2.22 # via cffi pydantic==2.9.2 - # via ydata-profiling + # via -r dev-requirements.in pydantic-core==2.23.4 # via pydantic pyflakes==3.2.0 # via autoflake pygments==2.18.0 # via - # -r dev-requirements.in # flytekit # ipython # rich @@ -437,8 +392,6 @@ pyjwt[crypto]==2.8.0 # snowflake-connector-python pyopenssl==24.2.1 # via snowflake-connector-python -pyparsing==3.1.4 - # via matplotlib pytest==8.2.1 # via # -r dev-requirements.in @@ -465,8 +418,8 @@ python-dateutil==2.9.0.post0 # botocore # croniter # google-cloud-bigquery + # jupyter-client # kubernetes - # matplotlib # pandas python-json-logger==2.0.7 # via flytekit @@ -479,14 +432,15 @@ pytz==2024.1 # croniter # pandas # snowflake-connector-python -pywavelets==1.7.0 - # via imagehash pyyaml==6.0.1 # via # flytekit # kubernetes # pre-commit - # ydata-profiling +pyzmq==26.2.0 + # via + # ipykernel + # jupyter-client requests==2.32.3 # via # azure-core @@ -501,7 +455,6 @@ requests==2.32.3 # msal # requests-oauthlib # snowflake-connector-python - # ydata-profiling requests-oauthlib==2.0.0 # via # google-auth-oauthlib @@ -519,14 +472,7 @@ s3fs==2024.5.0 scikit-learn==1.5.0 # via -r dev-requirements.in scipy==1.13.1 - # via - # imagehash - # phik - # scikit-learn - # statsmodels - # ydata-profiling -seaborn==0.13.2 - # via ydata-profiling + # via scikit-learn setuptools-scm==8.1.0 # via -r dev-requirements.in six==1.16.0 @@ -535,7 +481,6 @@ six==1.16.0 # azure-core # isodate # kubernetes - # patsy # python-dateutil snowflake-connector-python==3.12.1 # via -r dev-requirements.in @@ -547,22 +492,22 @@ stack-data==0.6.3 # via ipython statsd==3.3.0 # via flytekit -statsmodels==0.14.3 - # via ydata-profiling -tenacity==9.0.0 - # via plotly threadpoolctl==3.5.0 # via scikit-learn tomlkit==0.13.2 # via snowflake-connector-python -tqdm==4.66.5 - # via ydata-profiling +tornado==6.4.1 + # via + # ipykernel + # jupyter-client traitlets==5.14.3 # via + # comm + # ipykernel # ipython + # jupyter-client + # jupyter-core # matplotlib-inline -typeguard==4.3.0 - # via ydata-profiling types-croniter==2.0.0.20240423 # via -r dev-requirements.in types-decorator==5.1.8.20240310 @@ -584,7 +529,6 @@ typing-extensions==4.12.0 # pydantic-core # rich-click # snowflake-connector-python - # typeguard # typing-inspect typing-inspect==0.9.0 # via dataclasses-json @@ -600,22 +544,16 @@ urllib3==2.2.1 # types-requests virtualenv==20.26.2 # via pre-commit -visions[type-image-path]==0.7.6 - # via ydata-profiling wcwidth==0.2.13 # via prompt-toolkit websocket-client==1.8.0 # via # docker # kubernetes -wordcloud==1.9.3 - # via ydata-profiling wrapt==1.16.0 # via aiobotocore yarl==1.9.4 # via aiohttp -ydata-profiling==4.10.0 - # via -r dev-requirements.in zipp==3.19.1 # via importlib-metadata diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 968e3153eb..a6eb70004b 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -216,6 +216,7 @@ else: from importlib.metadata import entry_points + from flytekit._version import __version__ from flytekit.configuration import Config from flytekit.core.array_node_map_task import map_task diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 8a08c7a2cd..fada71410d 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -61,7 +61,7 @@ # This is relevant for cases like Dict[int, str]. # If strict_map_key=False is not used, the decoder will raise an error when trying to decode keys that are not strictly typed.` def _default_msgpack_decoder(data: bytes) -> Any: - return msgpack.unpackb(data, raw=False, strict_map_key=False) + return msgpack.unpackb(data, strict_map_key=False) class BatchSize: @@ -216,16 +216,41 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: ) def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]: + """ + This function primarily handles deserialization for untyped dicts, dataclasses, Pydantic BaseModels, and attribute access.` + + For untyped dict, dataclass, and pydantic basemodel: + Life Cycle (Untyped Dict as example): + python val -> msgpack bytes -> binary literal scalar -> msgpack bytes -> python val + (to_literal) (from_binary_idl) + + For attribute access: + Life Cycle: + python val -> msgpack bytes -> binary literal scalar -> resolved golang value -> binary literal scalar -> msgpack bytes -> python val + (to_literal) (propeller attribute access) (from_binary_idl) + """ if binary_idl_object.tag == MESSAGEPACK: try: decoder = self._msgpack_decoder[expected_python_type] except KeyError: decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder) self._msgpack_decoder[expected_python_type] = decoder - return decoder.decode(binary_idl_object.value) + python_val = decoder.decode(binary_idl_object.value) + + return python_val else: raise TypeTransformerFailedError(f"Unsupported binary format `{binary_idl_object.tag}`") + def from_generic_idl(self, generic: Struct, expected_python_type: Type[T]) -> Optional[T]: + """ + TODO: Support all Flyte Types. + This is for dataclass attribute access from input created from the Flyte Console. + + Note: + - This can be removed in the future when the Flyte Console support generate Binary IDL Scalar as input. + """ + raise NotImplementedError(f"Conversion from generic idl to python type {expected_python_type} not implemented") + def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str: """ Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div @@ -322,6 +347,51 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp ) return self._to_literal_transformer(python_val) + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]: + if binary_idl_object.tag == MESSAGEPACK: + if expected_python_type in [datetime.date, datetime.datetime, datetime.timedelta]: + """ + MessagePack doesn't support datetime, date, and timedelta. + However, mashumaro's MessagePackEncoder and MessagePackDecoder can convert them to str and vice versa. + That's why we need to use mashumaro's MessagePackDecoder here. + """ + try: + decoder = self._msgpack_decoder[expected_python_type] + except KeyError: + decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder) + self._msgpack_decoder[expected_python_type] = decoder + python_val = decoder.decode(binary_idl_object.value) + else: + python_val = msgpack.loads(binary_idl_object.value) + """ + In the case below, when using Union Transformer + Simple Transformer, then `a` + can be converted to int, bool, str and float if we use MessagePackDecoder[expected_python_type]. + + Life Cycle: + 1 -> msgpack bytes -> (1, true, "1", 1.0) + + Example Code: + @dataclass + class DC: + a: Union[int, bool, str, float] + b: Union[int, bool, str, float] + + @task(container_image=custom_image) + def add(a: Union[int, bool, str, float], b: Union[int, bool, str, float]) -> Union[int, bool, str, float]: + return a + b + + @workflow + def wf(dc: DC) -> Union[int, bool, str, float]: + return add(dc.a, dc.b) + + wf(DC(1, 1)) + """ + assert type(python_val) == expected_python_type + + return python_val + else: + raise TypeTransformerFailedError(f"Unsupported binary format `{binary_idl_object.tag}`") + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: expected_python_type = get_underlying_type(expected_python_type) @@ -1125,6 +1195,8 @@ def lazy_import_transformers(cls): from flytekit.extras import pytorch # noqa: F401 if is_imported("sklearn"): from flytekit.extras import sklearn # noqa: F401 + if is_imported("pydantic"): + from flytekit.extras import pydantic_transformer # noqa: F401 if is_imported("pandas"): try: from flytekit.types.schema.types_pandas import PandasSchemaReader, PandasSchemaWriter # noqa: F401 @@ -1779,9 +1851,6 @@ async def async_to_python_value( ) -> Optional[typing.Any]: expected_python_type = get_underlying_type(expected_python_type) - if lv.scalar is not None and lv.scalar.binary is not None: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) - union_tag = None union_type = None if lv.scalar is not None and lv.scalar.union is not None: @@ -1810,9 +1879,15 @@ async def async_to_python_value( assert lv.scalar.union is not None # type checker if isinstance(trans, AsyncTypeTransformer): - res = await trans.async_to_python_value(ctx, lv.scalar.union.value, v) + if lv.scalar.binary: + res = await trans.async_to_python_value(ctx, lv, v) + else: + res = await trans.async_to_python_value(ctx, lv.scalar.union.value, v) else: - res = trans.to_python_value(ctx, lv.scalar.union.value, v) + if lv.scalar.binary: + res = trans.to_python_value(ctx, lv, v) + else: + res = trans.to_python_value(ctx, lv.scalar.union.value, v) if isinstance(res, asyncio.Future): res = await res @@ -2013,7 +2088,42 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p return await FlytePickle.from_pickle(uri) try: - return json.loads(_json_format.MessageToJson(lv.scalar.generic)) + """ + Handles the case where Flyte Console provides input as a protobuf struct. + When resolving an attribute like 'dc.dict_int_ff', FlytePropeller retrieves a dictionary. + Mashumaro's decoder can convert this dictionary to the expected Python object if the correct type is provided. + Since Flyte Types handle their own deserialization, the dictionary is automatically converted to the expected Python object. + + Example Code: + @dataclass + class DC: + dict_int_ff: Dict[int, FlyteFile] + + @workflow + def wf(dc: DC): + t_ff(dc.dict_int_ff) + + Life Cycle: + json str -> protobuf struct -> resolved protobuf struct -> dictionary -> expected Python object + (console user input) (console output) (propeller) (flytekit dict transformer) (mashumaro decoder) + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + - Title: Binary IDL With MessagePack + - Link: https://github.com/flyteorg/flytekit/pull/2760 + """ + + dict_obj = json.loads(_json_format.MessageToJson(lv.scalar.generic)) + msgpack_bytes = msgpack.dumps(dict_obj) + + try: + decoder = self._msgpack_decoder[expected_python_type] + except KeyError: + decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_msgpack_decoder) + self._msgpack_decoder[expected_python_type] = decoder + + return decoder.decode(msgpack_bytes) except TypeError: raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") @@ -2197,6 +2307,35 @@ def _check_and_covert_float(lv: Literal) -> float: raise TypeTransformerFailedError(f"Cannot convert literal {lv} to float") +def _handle_flyte_console_float_input_to_int(lv: Literal) -> int: + """ + Flyte Console is written by JavaScript and JavaScript has only one number type which is Number. + Sometimes it keeps track of trailing 0s and sometimes it doesn't. + We have to convert float to int back in the following example. + + Example Code: + @dataclass + class DC: + a: int + + @workflow + def wf(dc: DC): + t_int(a=dc.a) + + Life Cycle: + json str -> protobuf struct -> resolved float -> float -> int + (console user input) (console output) (propeller) (flytekit simple transformer) (_handle_flyte_console_float_input_to_int) + """ + if lv.scalar.primitive.integer is not None: + return lv.scalar.primitive.integer + + if lv.scalar.primitive.float_value is not None: + logger.info(f"Converting literal float {lv.scalar.primitive.float_value} to int, might have precision loss.") + return int(lv.scalar.primitive.float_value) + + raise TypeTransformerFailedError(f"Cannot convert literal {lv} to int") + + def _check_and_convert_void(lv: Literal) -> None: if lv.scalar.none_type is None: raise TypeTransformerFailedError(f"Cannot convert literal {lv} to None") @@ -2208,7 +2347,7 @@ def _check_and_convert_void(lv: Literal) -> None: int, _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER), lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))), - lambda x: x.scalar.primitive.integer, + _handle_flyte_console_float_input_to_int, ) FloatTransformer = SimpleTransformer( diff --git a/flytekit/extras/pydantic_transformer/__init__.py b/flytekit/extras/pydantic_transformer/__init__.py new file mode 100644 index 0000000000..3f7744fe2f --- /dev/null +++ b/flytekit/extras/pydantic_transformer/__init__.py @@ -0,0 +1,11 @@ +from flytekit.loggers import logger + +try: + # isolate the exception to the pydantic import + # model_validator and model_serializer are only available in pydantic > 2 + from pydantic import model_serializer, model_validator + + from . import transformer +except (ImportError, OSError) as e: + logger.warning(f"Meet error when importing pydantic: `{e}`") + logger.warning("Flytekit only support pydantic version > 2.") diff --git a/flytekit/extras/pydantic_transformer/decorator.py b/flytekit/extras/pydantic_transformer/decorator.py new file mode 100644 index 0000000000..9db567739a --- /dev/null +++ b/flytekit/extras/pydantic_transformer/decorator.py @@ -0,0 +1,62 @@ +import logging +from typing import Any, Callable, TypeVar, Union + +logger = logging.getLogger(__name__) + +try: + # isolate the exception to the pydantic import + # model_validator and model_serializer are only available in pydantic > 2 + from pydantic import model_serializer, model_validator + +except ImportError: + """ + It's to support the case where pydantic is not installed at all. + It looks nicer in the real Flyte File/Directory class, but we also want it to not fail. + """ + + logger.warning( + "Pydantic is not installed.\n" "Please install Pydantic version > 2 to use FlyteTypes in pydantic BaseModel." + ) + + FuncType = TypeVar("FuncType", bound=Callable[..., Any]) + + from typing_extensions import Literal as typing_literal + + def model_serializer( + __f: Union[Callable[..., Any], None] = None, + *, + mode: typing_literal["plain", "wrap"] = "plain", + when_used: typing_literal["always", "unless-none", "json", "json-unless-none"] = "always", + return_type: Any = None, + ) -> Callable[[Any], Any]: + """Placeholder decorator for Pydantic model_serializer.""" + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args, **kwargs): + raise Exception( + "Pydantic is not installed.\n" "Please install Pydantic version > 2 to use this feature." + ) + + return wrapper + + # If no function (__f) is provided, return the decorator + if __f is None: + return decorator + # If __f is provided, directly decorate the function + return decorator(__f) + + def model_validator( + *, + mode: typing_literal["wrap", "before", "after"], + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Placeholder decorator for Pydantic model_validator.""" + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args, **kwargs): + raise Exception( + "Pydantic is not installed.\n" "Please install Pydantic version > 2 to use this feature." + ) + + return wrapper + + return decorator diff --git a/flytekit/extras/pydantic_transformer/transformer.py b/flytekit/extras/pydantic_transformer/transformer.py new file mode 100644 index 0000000000..4abefcc298 --- /dev/null +++ b/flytekit/extras/pydantic_transformer/transformer.py @@ -0,0 +1,81 @@ +import json +from typing import Type + +import msgpack +from google.protobuf import json_format as _json_format +from pydantic import BaseModel + +from flytekit import FlyteContext +from flytekit.core.constants import MESSAGEPACK +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.loggers import logger +from flytekit.models import types +from flytekit.models.literals import Binary, Literal, Scalar +from flytekit.models.types import LiteralType, TypeStructure + + +class PydanticTransformer(TypeTransformer[BaseModel]): + def __init__(self): + super().__init__("Pydantic Transformer", BaseModel, enable_type_assertions=False) + + def get_literal_type(self, t: Type[BaseModel]) -> LiteralType: + schema = t.model_json_schema() + literal_type = {} + fields = t.__annotations__.items() + + for name, python_type in fields: + try: + literal_type[name] = TypeEngine.to_literal_type(python_type) + except Exception as e: + logger.warning( + "Field {} of type {} cannot be converted to a literal type. Error: {}".format(name, python_type, e) + ) + + ts = TypeStructure(tag="", dataclass_type=literal_type) + + return types.LiteralType(simple=types.SimpleType.STRUCT, metadata=schema, structure=ts) + + def to_literal( + self, + ctx: FlyteContext, + python_val: BaseModel, + python_type: Type[BaseModel], + expected: types.LiteralType, + ) -> Literal: + """ + For pydantic basemodel, we have to go through json first. + This is for handling enum in basemodel. + More details: https://github.com/flyteorg/flytekit/pull/2792 + """ + json_str = python_val.model_dump_json() + dict_obj = json.loads(json_str) + msgpack_bytes = msgpack.dumps(dict_obj) + return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK))) + + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel: + if binary_idl_object.tag == MESSAGEPACK: + dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False) + json_str = json.dumps(dict_obj) + python_val = expected_python_type.model_validate_json( + json_data=json_str, strict=False, context={"deserialize": True} + ) + return python_val + else: + raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel: + """ + There will have 2 kinds of literal values: + 1. protobuf Struct (From Flyte Console) + 2. binary scalar (Others) + Hence we have to handle 2 kinds of cases. + """ + if lv and lv.scalar and lv.scalar.binary is not None: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore + + json_str = _json_format.MessageToJson(lv.scalar.generic) + python_val = expected_python_type.model_validate_json(json_str, strict=False, context={"deserialize": True}) + return python_val + + +TypeEngine.register(PydanticTransformer()) diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 1d038b0319..8cc2cc21cf 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -2,6 +2,7 @@ import datetime import enum import importlib +import importlib.util import json import logging import os @@ -39,7 +40,10 @@ def is_pydantic_basemodel(python_type: typing.Type) -> bool: return False else: try: - from pydantic.v1 import BaseModel + from pydantic import BaseModel as BaseModelV2 + from pydantic.v1 import BaseModel as BaseModelV1 + + return issubclass(python_type, BaseModelV1) or issubclass(python_type, BaseModelV2) except ImportError: from pydantic import BaseModel @@ -374,7 +378,24 @@ def has_nested_dataclass(t: typing.Type) -> bool: return parsed_value if is_pydantic_basemodel(self._python_type): - return self._python_type.parse_raw(json.dumps(parsed_value)) # type: ignore + """ + This function supports backward compatibility for the Pydantic v1 plugin. + If the class is a Pydantic BaseModel, it attempts to parse JSON input using + the appropriate version of Pydantic (v1 or v2). + """ + try: + if importlib.util.find_spec("pydantic.v1") is not None: + from pydantic import BaseModel as BaseModelV2 + + if issubclass(self._python_type, BaseModelV2): + return self._python_type.model_validate_json( + json.dumps(parsed_value), strict=False, context={"deserialize": True} + ) + except ImportError: + pass + + # The behavior of the Pydantic v1 plugin. + return self._python_type.parse_raw(json.dumps(parsed_value)) # Ensure that the python type has `from_json` function if not hasattr(self._python_type, "from_json"): diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py index 87b494d0ae..83bb0c8fa8 100644 --- a/flytekit/types/directory/__init__.py +++ b/flytekit/types/directory/__init__.py @@ -16,7 +16,7 @@ import typing -from .types import FlyteDirectory +from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer # The following section provides some predefined aliases for commonly used FlyteDirectory formats. diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 7e22879126..b1eb13964a 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -1,28 +1,32 @@ from __future__ import annotations +import json import os import pathlib import random import typing from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Generator, Tuple +from typing import Any, Dict, Generator, Tuple from uuid import UUID import fsspec import msgpack from dataclasses_json import DataClassJsonMixin, config from fsspec.utils import get_protocol +from google.protobuf import json_format as _json_format +from google.protobuf.struct_pb2 import Struct from marshmallow import fields from mashumaro.types import SerializableType -from flytekit import BlobType from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, get_batch_size from flytekit.exceptions.user import FlyteAssertion +from flytekit.extras.pydantic_transformer.decorator import model_serializer, model_validator from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types +from flytekit.models.core.types import BlobType from flytekit.models.literals import Binary, Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType from flytekit.types.file import FileExt, FlyteFile @@ -131,12 +135,21 @@ def _serialize(self) -> typing.Dict[str, str]: @classmethod def _deserialize(cls, value) -> "FlyteDirectory": - path = value.get("path", None) + return FlyteDirToMultipartBlobTransformer().dict_to_flyte_directory(dict_obj=value, expected_python_type=cls) - if path is None: - raise ValueError("FlyteDirectory's path should not be None") + @model_serializer + def serialize_flyte_dir(self) -> Dict[str, str]: + lv = FlyteDirToMultipartBlobTransformer().to_literal( + FlyteContextManager.current_context(), self, type(self), None + ) + return {"path": lv.scalar.blob.uri} - return FlyteDirToMultipartBlobTransformer().to_python_value( + @model_validator(mode="after") + def deserialize_flyte_dir(self, info) -> FlyteDirectory: + if info.context is None or info.context.get("deserialize") is not True: + return self + + pv = FlyteDirToMultipartBlobTransformer().to_python_value( FlyteContextManager.current_context(), Literal( scalar=Scalar( @@ -146,12 +159,13 @@ def _deserialize(cls, value) -> "FlyteDirectory": format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART ) ), - uri=path, + uri=self.path, ) ) ), - cls, + type(self), ) + return pv def __init__( self, @@ -532,42 +546,106 @@ async def async_to_literal( else: return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=source_path))) + def dict_to_flyte_directory( + self, dict_obj: typing.Dict[str, str], expected_python_type: typing.Type[FlyteDirectory] + ) -> FlyteDirectory: + path = dict_obj.get("path", None) + + if path is None: + raise ValueError("FlyteDirectory's path should not be None") + + return self.to_python_value( + FlyteContextManager.current_context(), + Literal( + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata( + type=_core_types.BlobType( + format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART + ) + ), + uri=path, + ) + ) + ), + expected_python_type, + ) + def from_binary_idl( self, binary_idl_object: Binary, expected_python_type: typing.Type[FlyteDirectory] ) -> FlyteDirectory: + """ + If the input is from flytekit, the Life Cycle will be as follows: + + Life Cycle: + binary IDL -> resolved binary -> bytes -> expected Python object + (flytekit customized (propeller processing) (flytekit binary IDL) (flytekit customized + serialization) deserialization) + + Example Code: + @dataclass + class DC: + fd: FlyteDirectory + + @workflow + def wf(dc: DC): + t_fd(dc.fd) + + Note: + - The deserialization is the same as put a flyte directory in a dataclass, which will deserialize by the mashumaro's API. + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ if binary_idl_object.tag == MESSAGEPACK: python_val = msgpack.loads(binary_idl_object.value) - path = python_val.get("path", None) - - if path is None: - raise ValueError("FlyteDirectory's path should not be None") - - return FlyteDirToMultipartBlobTransformer().to_python_value( - FlyteContextManager.current_context(), - Literal( - scalar=Scalar( - blob=Blob( - metadata=BlobMetadata( - type=_core_types.BlobType( - format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART - ) - ), - uri=path, - ) - ) - ), - expected_python_type, - ) + return self.dict_to_flyte_directory(python_val, expected_python_type) else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") + def from_generic_idl(self, generic: Struct, expected_python_type: typing.Type[FlyteDirectory]) -> FlyteDirectory: + """ + If the input is from Flyte Console, the Life Cycle will be as follows: + + Life Cycle: + json str -> protobuf struct -> resolved protobuf struct -> expected Python object + (console user input) (console output) (propeller) (flytekit customized deserialization) + + Example Code: + @dataclass + class DC: + fd: FlyteDirectory + + @workflow + def wf(dc: DC): + t_fd(dc.fd) + + Note: + - The deserialization is the same as put a flyte directory in a dataclass, which will deserialize by the mashumaro's API. + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ + json_str = _json_format.MessageToJson(generic) + python_val = json.loads(json_str) + return self.dict_to_flyte_directory(python_val, expected_python_type) + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory] ) -> FlyteDirectory: - if lv.scalar.binary: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) - - uri = lv.scalar.blob.uri + # Handle dataclass attribute access + if lv.scalar: + if lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar.generic: + return self.from_generic_idl(lv.scalar.generic, expected_python_type) + + try: + uri = lv.scalar.blob.uri + except AttributeError: + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") if lv.scalar.blob.metadata.type.dimensionality != BlobType.BlobDimensionality.MULTIPART: raise TypeTransformerFailedError(f"{lv.scalar.blob.uri} is not a directory.") diff --git a/flytekit/types/file/__init__.py b/flytekit/types/file/__init__.py index 838516f33d..8d69247e10 100644 --- a/flytekit/types/file/__init__.py +++ b/flytekit/types/file/__init__.py @@ -25,7 +25,7 @@ from typing_extensions import Annotated, get_args, get_origin -from .file import FlyteFile +from .file import FlyteFile, FlyteFilePathTransformer class FileExt: diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index a087af11eb..eb0aa5544d 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -1,16 +1,19 @@ from __future__ import annotations +import json import mimetypes import os import pathlib import typing from contextlib import contextmanager from dataclasses import dataclass, field -from typing import cast +from typing import Dict, cast from urllib.parse import unquote import msgpack from dataclasses_json import config +from google.protobuf import json_format as _json_format +from google.protobuf.struct_pb2 import Struct from marshmallow import fields from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.types import SerializableType @@ -24,6 +27,7 @@ get_underlying_type, ) from flytekit.exceptions.user import FlyteAssertion +from flytekit.extras.pydantic_transformer.decorator import model_serializer, model_validator from flytekit.loggers import logger from flytekit.models.core import types as _core_types from flytekit.models.core.types import BlobType @@ -159,12 +163,19 @@ def _serialize(self) -> typing.Dict[str, str]: @classmethod def _deserialize(cls, value) -> "FlyteFile": - path = value.get("path", None) + return FlyteFilePathTransformer().dict_to_flyte_file(dict_obj=value, expected_python_type=cls) - if path is None: - raise ValueError("FlyteFile's path should not be None") + @model_serializer + def serialize_flyte_file(self) -> Dict[str, str]: + lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None) + return {"path": lv.scalar.blob.uri} + + @model_validator(mode="after") + def deserialize_flyte_file(self, info) -> "FlyteFile": + if info.context is None or info.context.get("deserialize") is not True: + return self - return FlyteFilePathTransformer().to_python_value( + pv = FlyteFilePathTransformer().to_python_value( FlyteContextManager.current_context(), Literal( scalar=Scalar( @@ -174,12 +185,13 @@ def _deserialize(cls, value) -> "FlyteFile": format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE ) ), - uri=path, + uri=self.path, ) ) ), - cls, + type(self), ) + return pv @classmethod def extension(cls) -> str: @@ -548,41 +560,103 @@ def get_additional_headers(source_path: str | os.PathLike) -> typing.Dict[str, s return {"ContentEncoding": "gzip"} return {} + def dict_to_flyte_file( + self, dict_obj: typing.Dict[str, str], expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike] + ) -> FlyteFile: + path = dict_obj.get("path", None) + + if path is None: + raise ValueError("FlyteFile's path should not be None") + + return self.to_python_value( + FlyteContextManager.current_context(), + Literal( + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata( + type=_core_types.BlobType( + format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + ) + ), + uri=path, + ) + ) + ), + expected_python_type, + ) + def from_binary_idl( self, binary_idl_object: Binary, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike] ) -> FlyteFile: + """ + If the input is from flytekit, the Life Cycle will be as follows: + + Life Cycle: + binary IDL -> resolved binary -> bytes -> expected Python object + (flytekit customized (propeller processing) (flytekit binary IDL) (flytekit customized + serialization) deserialization) + + Example Code: + @dataclass + class DC: + ff: FlyteFile + + @workflow + def wf(dc: DC): + t_ff(dc.ff) + + Note: + - The deserialization is the same as put a flyte file in a dataclass, which will deserialize by the mashumaro's API. + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ if binary_idl_object.tag == MESSAGEPACK: python_val = msgpack.loads(binary_idl_object.value) - path = python_val.get("path", None) - - if path is None: - raise ValueError("FlyteFile's path should not be None") - - return FlyteFilePathTransformer().to_python_value( - FlyteContextManager.current_context(), - Literal( - scalar=Scalar( - blob=Blob( - metadata=BlobMetadata( - type=_core_types.BlobType( - format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE - ) - ), - uri=path, - ) - ) - ), - expected_python_type, - ) + return self.dict_to_flyte_file(dict_obj=python_val, expected_python_type=expected_python_type) else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") + def from_generic_idl( + self, generic: Struct, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike] + ) -> FlyteFile: + """ + If the input is from Flyte Console, the Life Cycle will be as follows: + + Life Cycle: + json str -> protobuf struct -> resolved protobuf struct -> expected Python object + (console user input) (console output) (propeller) (flytekit customized deserialization) + + Example Code: + @dataclass + class DC: + ff: FlyteFile + + @workflow + def wf(dc: DC): + t_ff(dc.ff) + + Note: + - The deserialization is the same as put a flyte file in a dataclass, which will deserialize by the mashumaro's API. + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ + json_str = _json_format.MessageToJson(generic) + python_val = json.loads(json_str) + return self.dict_to_flyte_file(dict_obj=python_val, expected_python_type=expected_python_type) + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike] ) -> FlyteFile: # Handle dataclass attribute access - if lv.scalar and lv.scalar.binary: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar: + if lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar.generic: + return self.from_generic_idl(lv.scalar.generic, expected_python_type) try: uri = lv.scalar.blob.uri diff --git a/flytekit/types/schema/__init__.py b/flytekit/types/schema/__init__.py index 080927021a..33ee8ef72c 100644 --- a/flytekit/types/schema/__init__.py +++ b/flytekit/types/schema/__init__.py @@ -1,5 +1,6 @@ from .types import ( FlyteSchema, + FlyteSchemaTransformer, LocalIOSchemaReader, LocalIOSchemaWriter, SchemaEngine, diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 28a2c542ef..34dcc18058 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -1,16 +1,19 @@ from __future__ import annotations import datetime +import json import os import typing from abc import abstractmethod from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Type +from typing import Dict, Optional, Type import msgpack from dataclasses_json import config +from google.protobuf import json_format as _json_format +from google.protobuf.struct_pb2 import Struct from marshmallow import fields from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.types import SerializableType @@ -18,6 +21,7 @@ from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError +from flytekit.extras.pydantic_transformer.decorator import model_serializer, model_validator from flytekit.loggers import logger from flytekit.models.literals import Binary, Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType @@ -185,7 +189,7 @@ class FlyteSchema(SerializableType, DataClassJSONMixin): This is the main schema class that users should use. """ - def _serialize(self) -> typing.Dict[str, typing.Optional[str]]: + def _serialize(self) -> Dict[str, Optional[str]]: FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None) return {"remote_path": self.remote_path} @@ -203,6 +207,23 @@ def _deserialize(cls, value) -> "FlyteSchema": cls, ) + @model_serializer + def serialize_flyte_schema(self) -> Dict[str, Optional[str]]: + FlyteSchemaTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None) + return {"remote_path": self.remote_path} + + @model_validator(mode="after") + def deserialize_flyte_schema(self, info) -> FlyteSchema: + if info.context is None or info.context.get("deserialize") is not True: + return self + + t = FlyteSchemaTransformer() + return t.to_python_value( + FlyteContextManager.current_context(), + Literal(scalar=Scalar(schema=Schema(self.remote_path, t._get_schema_type(type(self))))), + type(self), + ) + @classmethod def columns(cls) -> typing.Dict[str, typing.Type]: return {} @@ -445,29 +466,89 @@ async def async_to_literal( ) return Literal(scalar=Scalar(schema=Schema(schema.remote_path, self._get_schema_type(python_type)))) + def dict_to_flyte_schema( + self, dict_obj: typing.Dict[str, str], expected_python_type: Type[FlyteSchema] + ) -> FlyteSchema: + remote_path = dict_obj.get("remote_path", None) + + if remote_path is None: + raise ValueError("FlyteSchema's path should not be None") + + t = FlyteSchemaTransformer() + return t.to_python_value( + FlyteContextManager.current_context(), + Literal(scalar=Scalar(schema=Schema(remote_path, t._get_schema_type(expected_python_type)))), + expected_python_type, + ) + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[FlyteSchema]) -> FlyteSchema: - if binary_idl_object.tag == MESSAGEPACK: - python_val = msgpack.loads(binary_idl_object.value) - remote_path = python_val.get("remote_path", None) + """ + If the input is from flytekit, the Life Cycle will be as follows: - if remote_path is None: - raise ValueError("FlyteSchema's path should not be None") + Life Cycle: + binary IDL -> resolved binary -> bytes -> expected Python object + (flytekit customized (propeller processing) (flytekit binary IDL) (flytekit customized + serialization) deserialization) - t = FlyteSchemaTransformer() - return t.to_python_value( - FlyteContextManager.current_context(), - Literal(scalar=Scalar(schema=Schema(remote_path, t._get_schema_type(expected_python_type)))), - expected_python_type, - ) + Example Code: + @dataclass + class DC: + fs: FlyteSchema + + @workflow + def wf(dc: DC): + t_fs(dc.fs) + + Note: + - The deserialization is the same as put a flyte schema in a dataclass, which will deserialize by the mashumaro's API. + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ + if binary_idl_object.tag == MESSAGEPACK: + python_val = msgpack.loads(binary_idl_object.value) + return self.dict_to_flyte_schema(dict_obj=python_val, expected_python_type=expected_python_type) else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") + def from_generic_idl(self, generic: Struct, expected_python_type: Type[FlyteSchema]) -> FlyteSchema: + """ + If the input is from Flyte Console, the Life Cycle will be as follows: + + Life Cycle: + json str -> protobuf struct -> resolved protobuf struct -> expected Python object + (console user input) (console output) (propeller) (flytekit customized deserialization) + + Example Code: + @dataclass + class DC: + fs: FlyteSchema + + @workflow + def wf(dc: DC): + t_fs(dc.fs) + + Note: + - The deserialization is the same as put a flyte schema in a dataclass, which will deserialize by the mashumaro's API. + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ + json_str = _json_format.MessageToJson(generic) + python_val = json.loads(json_str) + return self.dict_to_flyte_schema(dict_obj=python_val, expected_python_type=expected_python_type) + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[FlyteSchema] ) -> FlyteSchema: # Handle dataclass attribute access - if lv.scalar and lv.scalar.binary: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar: + if lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar.generic: + return self.from_generic_idl(lv.scalar.generic, expected_python_type) def downloader(x, y): ctx.file_access.get_data(x, y, is_multipart=True) diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 05d1fa86e3..254ff16721 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -7,9 +7,9 @@ :template: custom.rst :toctree: generated/ - StructuredDataset - StructuredDatasetEncoder - StructuredDatasetDecoder + StructuredDataset + StructuredDatasetDecoder + StructuredDatasetEncoder """ from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer @@ -19,7 +19,9 @@ StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, + StructuredDatasetMetadata, StructuredDatasetTransformerEngine, + StructuredDatasetType, ) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 57c028e71c..39843668cb 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -2,6 +2,7 @@ import _datetime import collections +import json import types import typing from abc import ABC, abstractmethod @@ -11,6 +12,8 @@ import msgpack from dataclasses_json import config from fsspec.utils import get_protocol +from google.protobuf import json_format as _json_format +from google.protobuf.struct_pb2 import Struct from marshmallow import fields from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.types import SerializableType @@ -21,6 +24,7 @@ from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError from flytekit.deck.renderer import Renderable +from flytekit.extras.pydantic_transformer.decorator import model_serializer, model_validator from flytekit.loggers import developer_logger, logger from flytekit.models import literals from flytekit.models import types as type_models @@ -91,6 +95,38 @@ def _deserialize(cls, value) -> "StructuredDataset": cls, ) + @model_serializer + def serialize_structured_dataset(self) -> Dict[str, Optional[str]]: + lv = StructuredDatasetTransformerEngine().to_literal( + FlyteContextManager.current_context(), self, type(self), None + ) + sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri) + sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format + return { + "uri": sd.uri, + "file_format": sd.file_format, + } + + @model_validator(mode="after") + def deserialize_structured_dataset(self, info) -> StructuredDataset: + if info.context is None or info.context.get("deserialize") is not True: + return self + + return StructuredDatasetTransformerEngine().to_python_value( + FlyteContextManager.current_context(), + Literal( + scalar=Scalar( + structured_dataset=StructuredDataset( + metadata=StructuredDatasetMetadata( + structured_dataset_type=StructuredDatasetType(format=self.file_format) + ), + uri=self.uri, + ) + ) + ), + type(self), + ) + @classmethod def columns(cls) -> typing.Dict[str, typing.Type]: return {} @@ -724,34 +760,93 @@ def encode( sd._already_uploaded = True return lit + def dict_to_structured_dataset( + self, dict_obj: typing.Dict[str, str], expected_python_type: Type[T] | StructuredDataset + ) -> T | StructuredDataset: + uri = dict_obj.get("uri", None) + file_format = dict_obj.get("file_format", None) + + if uri is None: + raise ValueError("StructuredDataset's uri and file format should not be None") + + return StructuredDatasetTransformerEngine().to_python_value( + FlyteContextManager.current_context(), + Literal( + scalar=Scalar( + structured_dataset=StructuredDataset( + metadata=StructuredDatasetMetadata( + structured_dataset_type=StructuredDatasetType(format=file_format) + ), + uri=uri, + ) + ) + ), + expected_python_type, + ) + def from_binary_idl( self, binary_idl_object: Binary, expected_python_type: Type[T] | StructuredDataset ) -> T | StructuredDataset: + """ + If the input is from flytekit, the Life Cycle will be as follows: + + Life Cycle: + binary IDL -> resolved binary -> bytes -> expected Python object + (flytekit customized (propeller processing) (flytekit binary IDL) (flytekit customized + serialization) deserialization) + + Example Code: + @dataclass + class DC: + sd: StructuredDataset + + @workflow + def wf(dc: DC): + t_sd(dc.sd) + + Note: + - The deserialization is the same as put a structured dataset in a dataclass, which will deserialize by the mashumaro's API. + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ if binary_idl_object.tag == MESSAGEPACK: python_val = msgpack.loads(binary_idl_object.value) - uri = python_val.get("uri", None) - file_format = python_val.get("file_format", None) - - if uri is None: - raise ValueError("StructuredDataset's uri and file format should not be None") - - return StructuredDatasetTransformerEngine().to_python_value( - FlyteContextManager.current_context(), - Literal( - scalar=Scalar( - structured_dataset=StructuredDataset( - metadata=StructuredDatasetMetadata( - structured_dataset_type=StructuredDatasetType(format=file_format) - ), - uri=uri, - ) - ) - ), - expected_python_type, - ) + return self.dict_to_structured_dataset(dict_obj=python_val, expected_python_type=expected_python_type) else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") + def from_generic_idl( + self, generic: Struct, expected_python_type: Type[T] | StructuredDataset + ) -> T | StructuredDataset: + """ + If the input is from Flyte Console, the Life Cycle will be as follows: + + Life Cycle: + json str -> protobuf struct -> resolved protobuf struct -> expected Python object + (console user input) (console output) (propeller) (flytekit customized deserialization) + + Example Code: + @dataclass + class DC: + sd: StructuredDataset + + @workflow + def wf(dc: DC): + t_sd(dc.sd) + + Note: + - The deserialization is the same as put a structured dataset in a dataclass, which will deserialize by the mashumaro's API. + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + """ + json_str = _json_format.MessageToJson(generic) + python_val = json.loads(json_str) + return self.dict_to_structured_dataset(dict_obj=python_val, expected_python_type=expected_python_type) + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset ) -> T | StructuredDataset: @@ -786,8 +881,11 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ... +-----------------------------+-----------------------------------------+--------------------------------------+ """ # Handle dataclass attribute access - if lv.scalar and lv.scalar.binary: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar: + if lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv.scalar.generic: + return self.from_generic_idl(lv.scalar.generic, expected_python_type) # Detect annotations and extract out all the relevant information that the user might supply expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type) diff --git a/plugins/flytekit-pydantic/README.md b/plugins/flytekit-pydantic/README.md index 3f42c9cd21..8fef623d03 100644 --- a/plugins/flytekit-pydantic/README.md +++ b/plugins/flytekit-pydantic/README.md @@ -1,5 +1,10 @@ # Flytekit Pydantic Plugin +## Warning +This plugin is deprecated and will be removed in the future. +Please directly install `pydantic` and use `BaseModel` in your Flyte tasks. + +## Introduction Pydantic is a data validation and settings management library that uses Python type annotations to enforce type hints at runtime and provide user-friendly errors when data is invalid. Pydantic models are classes that inherit from `pydantic.BaseModel` and are used to define the structure and validation of data using Python type annotations. The plugin adds type support for pydantic models. diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py index 23e7e341bd..491bd8c9c4 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py @@ -1,4 +1,11 @@ +from flytekit.loggers import logger + from .basemodel_transformer import BaseModelTransformer from .deserialization import set_validators_on_supported_flyte_types as _set_validators_on_supported_flyte_types _set_validators_on_supported_flyte_types() # enables you to use flytekit.types in pydantic model +logger.warning( + "The Flytekit Pydantic V1 plugin is deprecated.\n" + "Please uninstall `flytekitplugins-pydantic` and install Pydantic directly.\n" + "You can now use Pydantic V2 BaseModels in Flytekit tasks." +) diff --git a/plugins/flytekit-pydantic/setup.py b/plugins/flytekit-pydantic/setup.py index 63e2c941e7..7001506a70 100644 --- a/plugins/flytekit-pydantic/setup.py +++ b/plugins/flytekit-pydantic/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.7.0b0", "pydantic"] +plugin_requires = ["flytekit>=1.7.0b0", "pydantic<2"] __version__ = "0.0.0+develop" diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index aa7e7dca4f..be61388fa5 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -22,9 +22,14 @@ from flytekit.models.core.types import BlobType from flytekit.models.literals import LiteralMap from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer - +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct +from flytekit.models.literals import Literal, Scalar +import json # Fixture that ensures a dummy local file + + @pytest.fixture def local_dummy_directory(): temp_dir = tempfile.TemporaryDirectory() @@ -51,8 +56,13 @@ def test_engine(): def test_transformer_to_literal_local(): - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() - fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) + random_dir = context_manager.FlyteContext.current_context( + ).file_access.get_random_local_directory() + fs = FileAccessProvider( + local_sandbox_dir=random_dir, + raw_output_prefix=os.path.join( + random_dir, + "raw")) ctx = context_manager.FlyteContext.current_context() with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: p = tempfile.mkdtemp(prefix="temp_example_") @@ -68,7 +78,13 @@ def test_transformer_to_literal_local(): assert literal.scalar.blob.uri.startswith(random_dir) # Create a FlyteDirectory where remote_directory is False - literal = tf.to_literal(ctx, FlyteDirectory(p, remote_directory=False), FlyteDirectory, lt) + literal = tf.to_literal( + ctx, + FlyteDirectory( + p, + remote_directory=False), + FlyteDirectory, + lt) assert literal.scalar.blob.uri.startswith(p) # Create a director with one file in it @@ -91,8 +107,13 @@ def test_transformer_to_literal_local(): def test_transformer_to_literal_local_path(): - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() - fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) + random_dir = context_manager.FlyteContext.current_context( + ).file_access.get_random_local_directory() + fs = FileAccessProvider( + local_sandbox_dir=random_dir, + raw_output_prefix=os.path.join( + random_dir, + "raw")) ctx = context_manager.FlyteContext.current_context() with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: tf = FlyteDirToMultipartBlobTransformer() @@ -111,12 +132,18 @@ def test_transformer_to_literal_local_path(): def test_transformer_to_literal_remote(): - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() - fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) + random_dir = context_manager.FlyteContext.current_context( + ).file_access.get_random_local_directory() + fs = FileAccessProvider( + local_sandbox_dir=random_dir, + raw_output_prefix=os.path.join( + random_dir, + "raw")) ctx = context_manager.FlyteContext.current_context() with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: # Use a separate directory that we know won't be the same as anything generated by flytekit itself, lest we - # accidentally try to cp -R /some/folder /some/folder/sub which causes exceptions obviously. + # accidentally try to cp -R /some/folder /some/folder/sub which causes + # exceptions obviously. p = "/tmp/flyte/test_fd_transformer" # Create an empty directory and call to literal on it if os.path.exists(p): @@ -127,7 +154,8 @@ def test_transformer_to_literal_remote(): lt = tf.get_literal_type(FlyteDirectory) # Remote directories should be copied as is. - literal = tf.to_literal(ctx, FlyteDirectory("s3://anything"), FlyteDirectory, lt) + literal = tf.to_literal(ctx, FlyteDirectory( + "s3://anything"), FlyteDirectory, lt) assert literal.scalar.blob.uri == "s3://anything" @@ -190,16 +218,23 @@ def dyn(in1: FlyteDirectory): project="test_proj", domain="test_domain", version="abc", - image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + image_config=ImageConfig( + Image( + name="name", + fqn="image", + tag="name")), env={}, ) ) ) as ctx: with context_manager.FlyteContextManager.with_context( - ctx.with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)) + ctx.with_execution_state( + ctx.execution_state.with_params( + mode=ExecutionState.Mode.TASK_EXECUTION)) ) as ctx: lit = TypeEngine.to_literal( - ctx, fd, FlyteDirectory, BlobType("", dimensionality=BlobType.BlobDimensionality.MULTIPART) + ctx, fd, FlyteDirectory, BlobType( + "", dimensionality=BlobType.BlobDimensionality.MULTIPART) ) lm = LiteralMap(literals={"in1": lit}) wf = dyn.dispatch_execute(ctx, lm) @@ -235,7 +270,8 @@ def wf1() -> FlyteDirectory: with open(os.path.join(wf_out.path, "file"), "r") as fh: assert fh.read() == "Hello world" - # Remove the file, then call download again, it should not because _downloaded was already set. + # Remove the file, then call download again, it should not because + # _downloaded was already set. shutil.rmtree(wf_out) wf_out.download() assert not os.path.exists(wf_out.path) @@ -247,7 +283,8 @@ def test_fd_with_local_remote(local_dummy_directory): @task def t1() -> FlyteDirectory: - return FlyteDirectory(local_dummy_directory, remote_directory=temp_dir.name) + return FlyteDirectory(local_dummy_directory, + remote_directory=temp_dir.name) # TODO: Remove this - only here to trigger type engine @workflow @@ -347,14 +384,18 @@ def test_manual_creation_sandbox(local_dummy_directory): assert os.path.exists(fd_new.path) assert os.path.isdir(fd_new.path) + def test_flytefile_in_dataclass(local_dummy_directory): SvgDirectory = FlyteDirectory["svg"] + @dataclass class DC: f: SvgDirectory + @task def t1(path: SvgDirectory) -> DC: return DC(f=path) + @workflow def my_wf(path: SvgDirectory) -> DC: dc = t1(path=path) @@ -364,3 +405,20 @@ def my_wf(path: SvgDirectory) -> DC: dc1 = my_wf(path=svg_directory) dc2 = DC(f=svg_directory) assert dc1 == dc2 + + +def test_input_from_flyte_console_attribute_access_flytefile( + local_dummy_directory): + # Flyte Console will send the input data as protobuf Struct + + dict_obj = {"path": local_dummy_directory} + json_str = json.dumps(dict_obj) + upstream_output = Literal( + scalar=Scalar( + generic=_json_format.Parse( + json_str, + _struct.Struct()))) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, FlyteDirectory) + assert isinstance(downstream_input, FlyteDirectory) + assert downstream_input == FlyteDirectory(local_dummy_directory) diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 352984ca37..1b5d32b1f2 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -1,3 +1,4 @@ +import json import os import pathlib import tempfile @@ -20,9 +21,13 @@ from flytekit.models.core.types import BlobType from flytekit.models.literals import LiteralMap from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer - +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct +from flytekit.models.literals import Literal, Scalar # Fixture that ensures a dummy local file + + @pytest.fixture def local_dummy_file(): fd, path = tempfile.mkstemp() @@ -78,11 +83,13 @@ def my_wf() -> FlyteFile[typing.TypeVar("txt")]: def test_matching_file_types_in_workflow(local_dummy_txt_file): # TXT @task - def t1(path: FlyteFile[typing.TypeVar("txt")]) -> FlyteFile[typing.TypeVar("txt")]: + def t1(path: FlyteFile[typing.TypeVar("txt")] + ) -> FlyteFile[typing.TypeVar("txt")]: return path @workflow - def my_wf(path: FlyteFile[typing.TypeVar("txt")]) -> FlyteFile[typing.TypeVar("txt")]: + def my_wf(path: FlyteFile[typing.TypeVar("txt")] + ) -> FlyteFile[typing.TypeVar("txt")]: f = t1(path=path) return f @@ -132,20 +139,24 @@ def my_wf(path: TxtFile) -> DC: assert dc1 == dc2 -@pytest.mark.skipif(not can_import("magic"), reason="Libmagic is not installed") +@pytest.mark.skipif(not can_import("magic"), + reason="Libmagic is not installed") def test_mismatching_file_types(local_dummy_txt_file): @task - def t1(path: FlyteFile[typing.TypeVar("txt")]) -> FlyteFile[typing.TypeVar("jpeg")]: + def t1(path: FlyteFile[typing.TypeVar("txt")] + ) -> FlyteFile[typing.TypeVar("jpeg")]: return path @workflow - def my_wf(path: FlyteFile[typing.TypeVar("txt")]) -> FlyteFile[typing.TypeVar("jpeg")]: + def my_wf(path: FlyteFile[typing.TypeVar("txt")] + ) -> FlyteFile[typing.TypeVar("jpeg")]: f = t1(path=path) return f with pytest.raises(ValueError) as excinfo: my_wf(path=local_dummy_txt_file) - assert "Incorrect file type, expected image/jpeg, got text/plain" in str(excinfo.value) + assert "Incorrect file type, expected image/jpeg, got text/plain" in str( + excinfo.value) def test_get_mime_type_from_extension_success(): @@ -154,14 +165,19 @@ def test_get_mime_type_from_extension_success(): assert transformer.get_mime_type_from_extension("jpeg") == "image/jpeg" assert transformer.get_mime_type_from_extension("png") == "image/png" assert transformer.get_mime_type_from_extension("hdf5") == "text/plain" - assert transformer.get_mime_type_from_extension("joblib") == "application/octet-stream" + assert transformer.get_mime_type_from_extension( + "joblib") == "application/octet-stream" assert transformer.get_mime_type_from_extension("pdf") == "application/pdf" - assert transformer.get_mime_type_from_extension("python_pickle") == "application/octet-stream" - assert transformer.get_mime_type_from_extension("ipynb") == "application/json" + assert transformer.get_mime_type_from_extension( + "python_pickle") == "application/octet-stream" + assert transformer.get_mime_type_from_extension( + "ipynb") == "application/json" assert transformer.get_mime_type_from_extension("svg") == "image/svg+xml" assert transformer.get_mime_type_from_extension("csv") == "text/csv" - assert transformer.get_mime_type_from_extension("onnx") == "application/json" - assert transformer.get_mime_type_from_extension("tfrecord") == "application/octet-stream" + assert transformer.get_mime_type_from_extension( + "onnx") == "application/json" + assert transformer.get_mime_type_from_extension( + "tfrecord") == "application/octet-stream" assert transformer.get_mime_type_from_extension("txt") == "text/plain" @@ -171,7 +187,8 @@ def test_get_mime_type_from_extension_failure(): transformer.get_mime_type_from_extension("unknown_extension") -@pytest.mark.skipif(not can_import("magic"), reason="Libmagic is not installed") +@pytest.mark.skipif(not can_import("magic"), + reason="Libmagic is not installed") def test_validate_file_type_incorrect(): transformer = TypeEngine.get_transformer(FlyteFile) source_path = "/tmp/flytekit_test.png" @@ -183,10 +200,12 @@ def test_validate_file_type_incorrect(): with pytest.raises( ValueError, match=f"Incorrect file type, expected image/jpeg, got {source_file_mime_type}" ): - transformer.validate_file_type(user_defined_format, source_path) + transformer.validate_file_type( + user_defined_format, source_path) -@pytest.mark.skipif(not can_import("magic"), reason="Libmagic is not installed") +@pytest.mark.skipif(not can_import("magic"), + reason="Libmagic is not installed") def test_flyte_file_type_annotated_hashmethod(local_dummy_file): def calc_hash(ff: FlyteFile) -> str: return str(ff.path) @@ -208,7 +227,8 @@ def wf(path: str) -> None: with pytest.raises(ValueError) as excinfo: wf(path=local_dummy_file) - assert "Incorrect file type, expected image/jpeg, got text/plain" in str(excinfo.value) + assert "Incorrect file type, expected image/jpeg, got text/plain" in str( + excinfo.value) def test_file_handling_remote_default_wf_input(): @@ -241,8 +261,13 @@ def t1() -> FlyteFile: def my_wf() -> FlyteFile: return t1() - random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() - fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) + random_dir = FlyteContextManager.current_context( + ).file_access.get_random_local_directory() + fs = FileAccessProvider( + local_sandbox_dir=random_dir, + raw_output_prefix=os.path.join( + random_dir, + "mock_remote")) ctx = FlyteContextManager.current_context() with FlyteContextManager.with_context(ctx.with_file_access(fs)): top_level_files = os.listdir(random_dir) @@ -262,8 +287,13 @@ def t1() -> FlyteFile: def my_wf() -> FlyteFile: return t1() - random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() - fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) + random_dir = FlyteContextManager.current_context( + ).file_access.get_random_local_directory() + fs = FileAccessProvider( + local_sandbox_dir=random_dir, + raw_output_prefix=os.path.join( + random_dir, + "mock_remote")) ctx = FlyteContextManager.current_context() with FlyteContextManager.with_context(ctx.with_file_access(fs)): top_level_files = os.listdir(random_dir) @@ -271,7 +301,8 @@ def my_wf() -> FlyteFile: workflow_output = my_wf() - # After running, this test file should've been copied to the mock remote location. + # After running, this test file should've been copied to the mock + # remote location. assert not os.path.exists(os.path.join(random_dir, "mock_remote")) # Because Flyte doesn't presume to handle a uri that look like a raw path, the path that is returned is @@ -291,10 +322,15 @@ def my_wf() -> FlyteFile: return t1() # This creates a random directory that we know is empty. - random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context( + ).file_access.get_random_local_directory() # Creating a new FileAccessProvider will add two folderst to the random dir print(f"Random {random_dir}") - fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) + fs = FileAccessProvider( + local_sandbox_dir=random_dir, + raw_output_prefix=os.path.join( + random_dir, + "mock_remote")) ctx = FlyteContextManager.current_context() with FlyteContextManager.with_context(ctx.with_file_access(fs)): working_dir = os.listdir(random_dir) @@ -302,12 +338,14 @@ def my_wf() -> FlyteFile: workflow_output = my_wf() - # After running the mock remote dir should still be empty, since the workflow_output has not been used + # After running the mock remote dir should still be empty, since the + # workflow_output has not been used with pytest.raises(FileNotFoundError): os.listdir(os.path.join(random_dir, "mock_remote")) # While the literal returned by t1 does contain the web address as the uri, because it's a remote address, - # flytekit will translate it back into a FlyteFile object on the local drive (but not download it) + # flytekit will translate it back into a FlyteFile object on the local + # drive (but not download it) assert workflow_output.path.startswith(random_dir) # But the remote source should still be the https address assert workflow_output.remote_source == SAMPLE_DATA @@ -335,7 +373,8 @@ def test_file_handling_remote_file_handling_flyte_file(): @task def t1() -> FlyteFile: - # Unlike the test above, this returns the remote path wrapped in a FlyteFile object + # Unlike the test above, this returns the remote path wrapped in a + # FlyteFile object return FlyteFile(SAMPLE_DATA) @workflow @@ -343,25 +382,34 @@ def my_wf() -> FlyteFile: return t1() # This creates a random directory that we know is empty. - random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context( + ).file_access.get_random_local_directory() # Creating a new FileAccessProvider will add two folderst to the random dir - fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) + fs = FileAccessProvider( + local_sandbox_dir=random_dir, + raw_output_prefix=os.path.join( + random_dir, + "mock_remote")) ctx = FlyteContextManager.current_context() with FlyteContextManager.with_context(ctx.with_file_access(fs)): working_dir = os.listdir(random_dir) assert len(working_dir) == 1 # the local_flytekit dir mock_remote_path = os.path.join(random_dir, "mock_remote") - assert not os.path.exists(mock_remote_path) # the persistence layer won't create the folder yet + # the persistence layer won't create the folder yet + assert not os.path.exists(mock_remote_path) workflow_output = my_wf() - # After running the mock remote dir should still be empty, since the workflow_output has not been used + # After running the mock remote dir should still be empty, since the + # workflow_output has not been used assert not os.path.exists(mock_remote_path) # While the literal returned by t1 does contain the web address as the uri, because it's a remote address, - # flytekit will translate it back into a FlyteFile object on the local drive (but not download it) - assert workflow_output.path.startswith(f"{random_dir}{os.sep}local_flytekit") + # flytekit will translate it back into a FlyteFile object on the local + # drive (but not download it) + assert workflow_output.path.startswith( + f"{random_dir}{os.sep}local_flytekit") # But the remote source should still be the https address assert workflow_output.remote_source == SAMPLE_DATA @@ -403,17 +451,24 @@ def dyn(in1: FlyteFile): project="test_proj", domain="test_domain", version="abc", - image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + image_config=ImageConfig( + Image( + name="name", + fqn="image", + tag="name")), env={}, ) ) ): ctx = FlyteContextManager.current_context() with FlyteContextManager.with_context( - ctx.with_execution_state(ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)) + ctx.with_execution_state( + ctx.new_execution_state().with_params( + mode=ExecutionState.Mode.TASK_EXECUTION)) ) as ctx: lit = TypeEngine.to_literal( - ctx, fd, FlyteFile, BlobType("", dimensionality=BlobType.BlobDimensionality.SINGLE) + ctx, fd, FlyteFile, BlobType( + "", dimensionality=BlobType.BlobDimensionality.SINGLE) ) lm = LiteralMap(literals={"in1": lit}) wf = dyn.dispatch_execute(ctx, lm) @@ -421,7 +476,8 @@ def dyn(in1: FlyteFile): with pytest.raises(TypeError, match="No automatic conversion found from type "): TypeEngine.to_literal( - ctx, 3, FlyteFile, BlobType("", dimensionality=BlobType.BlobDimensionality.SINGLE) + ctx, 3, FlyteFile, BlobType( + "", dimensionality=BlobType.BlobDimensionality.SINGLE) ) @@ -453,7 +509,8 @@ def wf1() -> FlyteFile: assert fh.read() == "Hello world" assert wf_out._downloaded - # Remove the file, then call download again, it should not because _downloaded was already set. + # Remove the file, then call download again, it should not because + # _downloaded was already set. os.remove(wf_out.path) p = wf_out.download() assert not os.path.exists(wf_out.path) @@ -591,6 +648,7 @@ def wf(path: str) -> os.PathLike: assert flyte_tmp_dir in wf(path="s3://somewhere").path + def test_flyte_file_name_with_special_chars(): temp_dir = tempfile.TemporaryDirectory() file_path = os.path.join(temp_dir.name, "foo bar") @@ -610,6 +668,7 @@ def wf(f: FlyteFile) -> FlyteFile: finally: temp_dir.cleanup() + def test_flyte_file_annotated_hashmethod(local_dummy_file): def calc_hash(ff: FlyteFile) -> str: return str(ff.path) @@ -672,15 +731,18 @@ def print_file(ff: FlyteFile): local_sandbox_dir=new_sandbox, raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc ) ctx = FlyteContextManager.current_context() - local = ctx.file_access.get_filesystem("file") # get a local file system. + # get a local file system. + local = ctx.file_access.get_filesystem("file") with FlyteContextManager.with_context(ctx.with_file_access(provider)): f = write_this_file_to_s3() copy_file(ff=f) files = local.find(new_sandbox) - # copy_file was done via streaming so no files should have been written + # copy_file was done via streaming so no files should have been + # written assert len(files) == 0 print_file(ff=f) - # print_file uses traditional download semantics so now a file should have been created + # print_file uses traditional download semantics so now a file + # should have been created files = local.find(new_sandbox) assert len(files) == 1 @@ -705,3 +767,18 @@ def test_new_remote_file(): nf = FlyteFile.new_remote_file(name="foo.txt") assert isinstance(nf, FlyteFile) assert nf.path.endswith('foo.txt') + + +def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_file): + # Flyte Console will send the input data as protobuf Struct + + dict_obj = {"path": local_dummy_file} + json_str = json.dumps(dict_obj) + upstream_output = Literal( + scalar=Scalar( + generic=_json_format.Parse( + json_str, + _struct.Struct()))) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, FlyteFile) + assert downstream_input == FlyteFile(local_dummy_file) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 7966b00f2c..0db5f10a46 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3287,7 +3287,7 @@ class InnerWorkflowOutput(DataClassJSONMixin): @task def inner_task(input: float) -> float | None: - if input == 0: + if input == 0.0: return None return input @@ -3322,7 +3322,7 @@ def outer_workflow(input: OuterWorkflowInput) -> OuterWorkflowOutput: float_value_output = outer_workflow(OuterWorkflowInput(input=1.0)).nullable_output assert float_value_output == 1.0, f"Float value was {float_value_output}, not 1.0 as expected" - none_value_output = outer_workflow(OuterWorkflowInput(input=0)).nullable_output + none_value_output = outer_workflow(OuterWorkflowInput(input=0.0)).nullable_output assert none_value_output is None, f"None value was {none_value_output}, not None as expected" diff --git a/tests/flytekit/unit/core/test_type_engine_binary_idl.py b/tests/flytekit/unit/core/test_type_engine_binary_idl.py index 986fac0c7c..3426e8021d 100644 --- a/tests/flytekit/unit/core/test_type_engine_binary_idl.py +++ b/tests/flytekit/unit/core/test_type_engine_binary_idl.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from datetime import date, datetime, timedelta from enum import Enum -from typing import Dict, List +from typing import Dict, List, Optional, Union import pytest from google.protobuf import json_format as _json_format @@ -33,7 +33,11 @@ def test_simple_type_transformer(): encoder = MessagePackEncoder(int) for int_input in int_inputs: int_msgpack_bytes = encoder.encode(int_input) - lv = Literal(scalar=Scalar(binary=Binary(value=int_msgpack_bytes, tag="msgpack"))) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=int_msgpack_bytes, + tag="msgpack"))) int_output = TypeEngine.to_python_value(ctx, lv, int) assert int_input == int_output @@ -41,7 +45,11 @@ def test_simple_type_transformer(): encoder = MessagePackEncoder(float) for float_input in float_inputs: float_msgpack_bytes = encoder.encode(float_input) - lv = Literal(scalar=Scalar(binary=Binary(value=float_msgpack_bytes, tag="msgpack"))) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=float_msgpack_bytes, + tag="msgpack"))) float_output = TypeEngine.to_python_value(ctx, lv, float) assert float_input == float_output @@ -49,7 +57,11 @@ def test_simple_type_transformer(): encoder = MessagePackEncoder(bool) for bool_input in bool_inputs: bool_msgpack_bytes = encoder.encode(bool_input) - lv = Literal(scalar=Scalar(binary=Binary(value=bool_msgpack_bytes, tag="msgpack"))) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=bool_msgpack_bytes, + tag="msgpack"))) bool_output = TypeEngine.to_python_value(ctx, lv, bool) assert bool_input == bool_output @@ -57,20 +69,28 @@ def test_simple_type_transformer(): encoder = MessagePackEncoder(str) for str_input in str_inputs: str_msgpack_bytes = encoder.encode(str_input) - lv = Literal(scalar=Scalar(binary=Binary(value=str_msgpack_bytes, tag="msgpack"))) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=str_msgpack_bytes, + tag="msgpack"))) str_output = TypeEngine.to_python_value(ctx, lv, str) assert str_input == str_output datetime_inputs = [datetime.now(), - datetime(2024, 9, 18), - datetime(2024, 9, 18, 1), - datetime(2024, 9, 18, 1, 1), - datetime(2024, 9, 18, 1, 1, 1), - datetime(2024, 9, 18, 1, 1, 1, 1)] + datetime(2024, 9, 18), + datetime(2024, 9, 18, 1), + datetime(2024, 9, 18, 1, 1), + datetime(2024, 9, 18, 1, 1, 1), + datetime(2024, 9, 18, 1, 1, 1, 1)] encoder = MessagePackEncoder(datetime) for datetime_input in datetime_inputs: datetime_msgpack_bytes = encoder.encode(datetime_input) - lv = Literal(scalar=Scalar(binary=Binary(value=datetime_msgpack_bytes, tag="msgpack"))) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=datetime_msgpack_bytes, + tag="msgpack"))) datetime_output = TypeEngine.to_python_value(ctx, lv, datetime) assert datetime_input == datetime_output @@ -79,22 +99,52 @@ def test_simple_type_transformer(): encoder = MessagePackEncoder(date) for date_input in date_inputs: date_msgpack_bytes = encoder.encode(date_input) - lv = Literal(scalar=Scalar(binary=Binary(value=date_msgpack_bytes, tag="msgpack"))) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=date_msgpack_bytes, + tag="msgpack"))) date_output = TypeEngine.to_python_value(ctx, lv, date) assert date_input == date_output timedelta_inputs = [timedelta(days=1), timedelta(days=1, seconds=1), timedelta(days=1, seconds=1, microseconds=1), - timedelta(days=1, seconds=1, microseconds=1, milliseconds=1), - timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1), - timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1), - timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1, weeks=1), - timedelta(days=-1, seconds=-1, microseconds=-1, milliseconds=-1, minutes=-1, hours=-1, weeks=-1)] + timedelta( + days=1, + seconds=1, + microseconds=1, + milliseconds=1), + timedelta( + days=1, + seconds=1, + microseconds=1, + milliseconds=1, + minutes=1), + timedelta( + days=1, + seconds=1, + microseconds=1, + milliseconds=1, + minutes=1, + hours=1), + timedelta( + days=1, + seconds=1, + microseconds=1, + milliseconds=1, + minutes=1, + hours=1, + weeks=1), + timedelta(days=-1, seconds=-1, microseconds=-1, milliseconds=-1, minutes=-1, hours=-1, weeks=-1)] encoder = MessagePackEncoder(timedelta) for timedelta_input in timedelta_inputs: timedelta_msgpack_bytes = encoder.encode(timedelta_input) - lv = Literal(scalar=Scalar(binary=Binary(value=timedelta_msgpack_bytes, tag="msgpack"))) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=timedelta_msgpack_bytes, + tag="msgpack"))) timedelta_output = TypeEngine.to_python_value(ctx, lv, timedelta) assert timedelta_input == timedelta_output @@ -117,7 +167,8 @@ def test_untyped_dict(): }, { "list_in_dict": [ - {"inner_dict_1": [1, -2.5, "a"], "inner_dict_2": [True, False, 3.14]}, + {"inner_dict_1": [1, -2.5, "a"], + "inner_dict_2": [True, False, 3.14]}, [1, -2, 3, {"nested_list_dict": [False, "test"]}], ] }, @@ -150,7 +201,10 @@ def test_untyped_dict(): # dict_msgpack_bytes = msgpack.dumps(dict_input) dict_msgpack_bytes = MessagePackEncoder(dict).encode(dict_input) lv = Literal( - scalar=Scalar(binary=Binary(value=dict_msgpack_bytes, tag="msgpack")) + scalar=Scalar( + binary=Binary( + value=dict_msgpack_bytes, + tag="msgpack")) ) dict_output = TypeEngine.to_python_value(ctx, lv, dict) assert dict_input == dict_output @@ -163,7 +217,10 @@ def test_list_transformer(): encoder = MessagePackEncoder(List[int]) list_int_msgpack_bytes = encoder.encode(list_int_input) lv = Literal( - scalar=Scalar(binary=Binary(value=list_int_msgpack_bytes, tag="msgpack")) + scalar=Scalar( + binary=Binary( + value=list_int_msgpack_bytes, + tag="msgpack")) ) list_int_output = TypeEngine.to_python_value(ctx, lv, List[int]) assert list_int_input == list_int_output @@ -172,7 +229,10 @@ def test_list_transformer(): encoder = MessagePackEncoder(List[float]) list_float_msgpack_bytes = encoder.encode(list_float_input) lv = Literal( - scalar=Scalar(binary=Binary(value=list_float_msgpack_bytes, tag="msgpack")) + scalar=Scalar( + binary=Binary( + value=list_float_msgpack_bytes, + tag="msgpack")) ) list_float_output = TypeEngine.to_python_value(ctx, lv, List[float]) assert list_float_input == list_float_output @@ -181,7 +241,10 @@ def test_list_transformer(): encoder = MessagePackEncoder(List[str]) list_str_msgpack_bytes = encoder.encode(list_str_input) lv = Literal( - scalar=Scalar(binary=Binary(value=list_str_msgpack_bytes, tag="msgpack")) + scalar=Scalar( + binary=Binary( + value=list_str_msgpack_bytes, + tag="msgpack")) ) list_str_output = TypeEngine.to_python_value(ctx, lv, List[str]) assert list_str_input == list_str_output @@ -190,7 +253,10 @@ def test_list_transformer(): encoder = MessagePackEncoder(List[bool]) list_bool_msgpack_bytes = encoder.encode(list_bool_input) lv = Literal( - scalar=Scalar(binary=Binary(value=list_bool_msgpack_bytes, tag="msgpack")) + scalar=Scalar( + binary=Binary( + value=list_bool_msgpack_bytes, + tag="msgpack")) ) list_bool_output = TypeEngine.to_python_value(ctx, lv, List[bool]) assert list_bool_input == list_bool_output @@ -199,7 +265,10 @@ def test_list_transformer(): encoder = MessagePackEncoder(List[List[int]]) list_list_int_msgpack_bytes = encoder.encode(list_list_int_input) lv = Literal( - scalar=Scalar(binary=Binary(value=list_list_int_msgpack_bytes, tag="msgpack")) + scalar=Scalar( + binary=Binary( + value=list_list_int_msgpack_bytes, + tag="msgpack")) ) list_list_int_output = TypeEngine.to_python_value(ctx, lv, List[List[int]]) assert list_list_int_input == list_list_int_output @@ -208,16 +277,23 @@ def test_list_transformer(): encoder = MessagePackEncoder(List[List[float]]) list_list_float_msgpack_bytes = encoder.encode(list_list_float_input) lv = Literal( - scalar=Scalar(binary=Binary(value=list_list_float_msgpack_bytes, tag="msgpack")) + scalar=Scalar( + binary=Binary( + value=list_list_float_msgpack_bytes, + tag="msgpack")) ) - list_list_float_output = TypeEngine.to_python_value(ctx, lv, List[List[float]]) + list_list_float_output = TypeEngine.to_python_value( + ctx, lv, List[List[float]]) assert list_list_float_input == list_list_float_output list_list_str_input = [["a", "b"], ["c", "d"]] encoder = MessagePackEncoder(List[List[str]]) list_list_str_msgpack_bytes = encoder.encode(list_list_str_input) lv = Literal( - scalar=Scalar(binary=Binary(value=list_list_str_msgpack_bytes, tag="msgpack")) + scalar=Scalar( + binary=Binary( + value=list_list_str_msgpack_bytes, + tag="msgpack")) ) list_list_str_output = TypeEngine.to_python_value(ctx, lv, List[List[str]]) assert list_list_str_input == list_list_str_output @@ -226,9 +302,13 @@ def test_list_transformer(): encoder = MessagePackEncoder(List[List[bool]]) list_list_bool_msgpack_bytes = encoder.encode(list_list_bool_input) lv = Literal( - scalar=Scalar(binary=Binary(value=list_list_bool_msgpack_bytes, tag="msgpack")) + scalar=Scalar( + binary=Binary( + value=list_list_bool_msgpack_bytes, + tag="msgpack")) ) - list_list_bool_output = TypeEngine.to_python_value(ctx, lv, List[List[bool]]) + list_list_bool_output = TypeEngine.to_python_value( + ctx, lv, List[List[bool]]) assert list_list_bool_input == list_list_bool_output list_dict_str_int_input = [{"key1": -1, "key2": 2}] @@ -239,15 +319,19 @@ def test_list_transformer(): binary=Binary(value=list_dict_str_int_msgpack_bytes, tag="msgpack") ) ) - list_dict_str_int_output = TypeEngine.to_python_value(ctx, lv, List[Dict[str, int]]) + list_dict_str_int_output = TypeEngine.to_python_value( + ctx, lv, List[Dict[str, int]]) assert list_dict_str_int_input == list_dict_str_int_output list_dict_str_float_input = [{"key1": 1.0, "key2": -2.0}] encoder = MessagePackEncoder(List[Dict[str, float]]) - list_dict_str_float_msgpack_bytes = encoder.encode(list_dict_str_float_input) + list_dict_str_float_msgpack_bytes = encoder.encode( + list_dict_str_float_input) lv = Literal( scalar=Scalar( - binary=Binary(value=list_dict_str_float_msgpack_bytes, tag="msgpack") + binary=Binary( + value=list_dict_str_float_msgpack_bytes, + tag="msgpack") ) ) list_dict_str_float_output = TypeEngine.to_python_value( @@ -263,7 +347,8 @@ def test_list_transformer(): binary=Binary(value=list_dict_str_str_msgpack_bytes, tag="msgpack") ) ) - list_dict_str_str_output = TypeEngine.to_python_value(ctx, lv, List[Dict[str, str]]) + list_dict_str_str_output = TypeEngine.to_python_value( + ctx, lv, List[Dict[str, str]]) assert list_dict_str_str_input == list_dict_str_str_output list_dict_str_bool_input = [{"key1": True, "key2": False}] @@ -271,7 +356,9 @@ def test_list_transformer(): list_dict_str_bool_msgpack_bytes = encoder.encode(list_dict_str_bool_input) lv = Literal( scalar=Scalar( - binary=Binary(value=list_dict_str_bool_msgpack_bytes, tag="msgpack") + binary=Binary( + value=list_dict_str_bool_msgpack_bytes, + tag="msgpack") ) ) list_dict_str_bool_output = TypeEngine.to_python_value( @@ -293,8 +380,10 @@ class InnerDC: h: Dict[int, bool] = field( default_factory=lambda: {0: False, 1: True, -1: False} ) - i: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) - j: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) + i: Dict[int, List[int]] = field( + default_factory=lambda: {0: [0, 1, -1]}) + j: Dict[int, Dict[int, int]] = field( + default_factory=lambda: {1: {-1: 0}}) k: dict = field(default_factory=lambda: {"key": "value"}) enum_status: Status = field(default=Status.PENDING) @@ -312,18 +401,24 @@ class DC: h: Dict[int, bool] = field( default_factory=lambda: {0: False, 1: True, -1: False} ) - i: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) - j: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) + i: Dict[int, List[int]] = field( + default_factory=lambda: {0: [0, 1, -1]}) + j: Dict[int, Dict[int, int]] = field( + default_factory=lambda: {1: {-1: 0}}) k: dict = field(default_factory=lambda: {"key": "value"}) inner_dc: InnerDC = field(default_factory=lambda: InnerDC()) enum_status: Status = field(default=Status.PENDING) - list_dict_int_inner_dc_input = [{1: InnerDC(), -1: InnerDC(), 0: InnerDC()}] + list_dict_int_inner_dc_input = [ + {1: InnerDC(), -1: InnerDC(), 0: InnerDC()}] encoder = MessagePackEncoder(List[Dict[int, InnerDC]]) - list_dict_int_inner_dc_msgpack_bytes = encoder.encode(list_dict_int_inner_dc_input) + list_dict_int_inner_dc_msgpack_bytes = encoder.encode( + list_dict_int_inner_dc_input) lv = Literal( scalar=Scalar( - binary=Binary(value=list_dict_int_inner_dc_msgpack_bytes, tag="msgpack") + binary=Binary( + value=list_dict_int_inner_dc_msgpack_bytes, + tag="msgpack") ) ) list_dict_int_inner_dc_output = TypeEngine.to_python_value( @@ -339,7 +434,8 @@ class DC: binary=Binary(value=list_dict_int_dc_msgpack_bytes, tag="msgpack") ) ) - list_dict_int_dc_output = TypeEngine.to_python_value(ctx, lv, List[Dict[int, DC]]) + list_dict_int_dc_output = TypeEngine.to_python_value( + ctx, lv, List[Dict[int, DC]]) assert list_dict_int_dc_input == list_dict_int_dc_output list_list_inner_dc_input = [[InnerDC(), InnerDC(), InnerDC()]] @@ -347,17 +443,23 @@ class DC: list_list_inner_dc_msgpack_bytes = encoder.encode(list_list_inner_dc_input) lv = Literal( scalar=Scalar( - binary=Binary(value=list_list_inner_dc_msgpack_bytes, tag="msgpack") + binary=Binary( + value=list_list_inner_dc_msgpack_bytes, + tag="msgpack") ) ) - list_list_inner_dc_output = TypeEngine.to_python_value(ctx, lv, List[List[InnerDC]]) + list_list_inner_dc_output = TypeEngine.to_python_value( + ctx, lv, List[List[InnerDC]]) assert list_list_inner_dc_input == list_list_inner_dc_output list_list_dc_input = [[DC(), DC(), DC()]] encoder = MessagePackEncoder(List[List[DC]]) list_list_dc_msgpack_bytes = encoder.encode(list_list_dc_input) lv = Literal( - scalar=Scalar(binary=Binary(value=list_list_dc_msgpack_bytes, tag="msgpack")) + scalar=Scalar( + binary=Binary( + value=list_list_dc_msgpack_bytes, + tag="msgpack")) ) list_list_dc_output = TypeEngine.to_python_value(ctx, lv, List[List[DC]]) assert list_list_dc_input == list_list_dc_output @@ -369,72 +471,124 @@ def test_dict_transformer(local_dummy_file, local_dummy_directory): dict_str_int_input = {"key1": 1, "key2": -2} encoder = MessagePackEncoder(Dict[str, int]) dict_str_int_msgpack_bytes = encoder.encode(dict_str_int_input) - lv = Literal(scalar=Scalar(binary=Binary(value=dict_str_int_msgpack_bytes, tag="msgpack"))) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=dict_str_int_msgpack_bytes, + tag="msgpack"))) dict_str_int_output = TypeEngine.to_python_value(ctx, lv, Dict[str, int]) assert dict_str_int_input == dict_str_int_output dict_str_float_input = {"key1": 1.0, "key2": -2.0} encoder = MessagePackEncoder(Dict[str, float]) dict_str_float_msgpack_bytes = encoder.encode(dict_str_float_input) - lv = Literal(scalar=Scalar(binary=Binary(value=dict_str_float_msgpack_bytes, tag="msgpack"))) - dict_str_float_output = TypeEngine.to_python_value(ctx, lv, Dict[str, float]) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=dict_str_float_msgpack_bytes, + tag="msgpack"))) + dict_str_float_output = TypeEngine.to_python_value( + ctx, lv, Dict[str, float]) assert dict_str_float_input == dict_str_float_output dict_str_str_input = {"key1": "a", "key2": "b"} encoder = MessagePackEncoder(Dict[str, str]) dict_str_str_msgpack_bytes = encoder.encode(dict_str_str_input) - lv = Literal(scalar=Scalar(binary=Binary(value=dict_str_str_msgpack_bytes, tag="msgpack"))) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=dict_str_str_msgpack_bytes, + tag="msgpack"))) dict_str_str_output = TypeEngine.to_python_value(ctx, lv, Dict[str, str]) assert dict_str_str_input == dict_str_str_output dict_str_bool_input = {"key1": True, "key2": False} encoder = MessagePackEncoder(Dict[str, bool]) dict_str_bool_msgpack_bytes = encoder.encode(dict_str_bool_input) - lv = Literal(scalar=Scalar(binary=Binary(value=dict_str_bool_msgpack_bytes, tag="msgpack"))) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=dict_str_bool_msgpack_bytes, + tag="msgpack"))) dict_str_bool_output = TypeEngine.to_python_value(ctx, lv, Dict[str, bool]) assert dict_str_bool_input == dict_str_bool_output dict_str_list_int_input = {"key1": [1, -2, 3]} encoder = MessagePackEncoder(Dict[str, List[int]]) dict_str_list_int_msgpack_bytes = encoder.encode(dict_str_list_int_input) - lv = Literal(scalar=Scalar(binary=Binary(value=dict_str_list_int_msgpack_bytes, tag="msgpack"))) - dict_str_list_int_output = TypeEngine.to_python_value(ctx, lv, Dict[str, List[int]]) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=dict_str_list_int_msgpack_bytes, + tag="msgpack"))) + dict_str_list_int_output = TypeEngine.to_python_value( + ctx, lv, Dict[str, List[int]]) assert dict_str_list_int_input == dict_str_list_int_output dict_str_dict_str_int_input = {"key1": {"subkey1": 1, "subkey2": -2}} encoder = MessagePackEncoder(Dict[str, Dict[str, int]]) - dict_str_dict_str_int_msgpack_bytes = encoder.encode(dict_str_dict_str_int_input) - lv = Literal(scalar=Scalar(binary=Binary(value=dict_str_dict_str_int_msgpack_bytes, tag="msgpack"))) - dict_str_dict_str_int_output = TypeEngine.to_python_value(ctx, lv, Dict[str, Dict[str, int]]) + dict_str_dict_str_int_msgpack_bytes = encoder.encode( + dict_str_dict_str_int_input) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=dict_str_dict_str_int_msgpack_bytes, + tag="msgpack"))) + dict_str_dict_str_int_output = TypeEngine.to_python_value( + ctx, lv, Dict[str, Dict[str, int]]) assert dict_str_dict_str_int_input == dict_str_dict_str_int_output - dict_str_dict_str_list_int_input = {"key1": {"subkey1": [1, -2], "subkey2": [-3, 4]}} + dict_str_dict_str_list_int_input = { + "key1": {"subkey1": [1, -2], "subkey2": [-3, 4]}} encoder = MessagePackEncoder(Dict[str, Dict[str, List[int]]]) - dict_str_dict_str_list_int_msgpack_bytes = encoder.encode(dict_str_dict_str_list_int_input) - lv = Literal(scalar=Scalar(binary=Binary(value=dict_str_dict_str_list_int_msgpack_bytes, tag="msgpack"))) - dict_str_dict_str_list_int_output = TypeEngine.to_python_value(ctx, lv, Dict[str, Dict[str, List[int]]]) + dict_str_dict_str_list_int_msgpack_bytes = encoder.encode( + dict_str_dict_str_list_int_input) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=dict_str_dict_str_list_int_msgpack_bytes, + tag="msgpack"))) + dict_str_dict_str_list_int_output = TypeEngine.to_python_value( + ctx, lv, Dict[str, Dict[str, List[int]]]) assert dict_str_dict_str_list_int_input == dict_str_dict_str_list_int_output - dict_str_list_dict_str_int_input = {"key1": [{"subkey1": -1}, {"subkey2": 2}]} + dict_str_list_dict_str_int_input = { + "key1": [{"subkey1": -1}, {"subkey2": 2}]} encoder = MessagePackEncoder(Dict[str, List[Dict[str, int]]]) - dict_str_list_dict_str_int_msgpack_bytes = encoder.encode(dict_str_list_dict_str_int_input) - lv = Literal(scalar=Scalar(binary=Binary(value=dict_str_list_dict_str_int_msgpack_bytes, tag="msgpack"))) - dict_str_list_dict_str_int_output = TypeEngine.to_python_value(ctx, lv, Dict[str, List[Dict[str, int]]]) + dict_str_list_dict_str_int_msgpack_bytes = encoder.encode( + dict_str_list_dict_str_int_input) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=dict_str_list_dict_str_int_msgpack_bytes, + tag="msgpack"))) + dict_str_list_dict_str_int_output = TypeEngine.to_python_value( + ctx, lv, Dict[str, List[Dict[str, int]]]) assert dict_str_list_dict_str_int_input == dict_str_list_dict_str_int_output # non-strict types dict_int_str_input = {1: "a", -2: "b"} encoder = MessagePackEncoder(dict) dict_int_str_msgpack_bytes = encoder.encode(dict_int_str_input) - lv = Literal(scalar=Scalar(binary=Binary(value=dict_int_str_msgpack_bytes, tag="msgpack"))) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=dict_int_str_msgpack_bytes, + tag="msgpack"))) dict_int_str_output = TypeEngine.to_python_value(ctx, lv, dict) assert dict_int_str_input == dict_int_str_output dict_int_dict_int_list_int_input = {1: {-2: [1, -2]}, -3: {4: [-3, 4]}} encoder = MessagePackEncoder(Dict[int, Dict[int, List[int]]]) - dict_int_dict_int_list_int_msgpack_bytes = encoder.encode(dict_int_dict_int_list_int_input) - lv = Literal(scalar=Scalar(binary=Binary(value=dict_int_dict_int_list_int_msgpack_bytes, tag="msgpack"))) - dict_int_dict_int_list_int_output = TypeEngine.to_python_value(ctx, lv, Dict[int, Dict[int, List[int]]]) + dict_int_dict_int_list_int_msgpack_bytes = encoder.encode( + dict_int_dict_int_list_int_input) + lv = Literal( + scalar=Scalar( + binary=Binary( + value=dict_int_dict_int_list_int_msgpack_bytes, + tag="msgpack"))) + dict_int_dict_int_list_int_output = TypeEngine.to_python_value( + ctx, lv, Dict[int, Dict[int, List[int]]]) assert dict_int_dict_int_list_int_input == dict_int_dict_int_list_int_output @dataclass @@ -451,8 +605,10 @@ class InnerDC: h: Dict[int, bool] = field( default_factory=lambda: {0: False, 1: True, -1: False} ) - i: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) - j: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) + i: Dict[int, List[int]] = field( + default_factory=lambda: {0: [0, 1, -1]}) + j: Dict[int, Dict[int, int]] = field( + default_factory=lambda: {1: {-1: 0}}) k: dict = field(default_factory=lambda: {"key": "value"}) enum_status: Status = field(default=Status.PENDING) @@ -470,8 +626,10 @@ class DC: h: Dict[int, bool] = field( default_factory=lambda: {0: False, 1: True, -1: False} ) - i: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) - j: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) + i: Dict[int, List[int]] = field( + default_factory=lambda: {0: [0, 1, -1]}) + j: Dict[int, Dict[int, int]] = field( + default_factory=lambda: {1: {-1: 0}}) k: dict = field(default_factory=lambda: {"key": "value"}) inner_dc: InnerDC = field(default_factory=lambda: InnerDC()) enum_status: Status = field(default=Status.PENDING) @@ -484,14 +642,18 @@ class DC: binary=Binary(value=dict_int_inner_dc_msgpack_bytes, tag="msgpack") ) ) - dict_int_inner_dc_output = TypeEngine.to_python_value(ctx, lv, Dict[int, InnerDC]) + dict_int_inner_dc_output = TypeEngine.to_python_value( + ctx, lv, Dict[int, InnerDC]) assert dict_int_inner_dc_input == dict_int_inner_dc_output dict_int_dc = {1: DC(), -2: DC(), 0: DC()} encoder = MessagePackEncoder(Dict[int, DC]) dict_int_dc_msgpack_bytes = encoder.encode(dict_int_dc) lv = Literal( - scalar=Scalar(binary=Binary(value=dict_int_dc_msgpack_bytes, tag="msgpack")) + scalar=Scalar( + binary=Binary( + value=dict_int_dc_msgpack_bytes, + tag="msgpack")) ) dict_int_dc_output = TypeEngine.to_python_value(ctx, lv, Dict[int, DC]) assert dict_int_dc == dict_int_dc_output @@ -522,12 +684,17 @@ def local_dummy_directory(): def test_flytetypes_in_dataclass_wf(local_dummy_file, local_dummy_directory): @dataclass class InnerDC: - flytefile: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) - flytedir: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + flytefile: FlyteFile = field( + default_factory=lambda: FlyteFile(local_dummy_file)) + flytedir: FlyteDirectory = field( + default_factory=lambda: FlyteDirectory(local_dummy_directory)) + @dataclass class DC: - flytefile: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) - flytedir: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + flytefile: FlyteFile = field( + default_factory=lambda: FlyteFile(local_dummy_file)) + flytedir: FlyteDirectory = field( + default_factory=lambda: FlyteDirectory(local_dummy_directory)) inner_dc: InnerDC = field(default_factory=lambda: InnerDC()) @task @@ -559,6 +726,7 @@ def wf(dc: DC) -> (FlyteFile, FlyteFile, FlyteDirectory, FlyteDirectory): with open(os.path.join(o4, "file"), "r") as fh: assert fh.read() == "Hello FlyteDirectory" + def test_all_types_in_dataclass_wf(local_dummy_file, local_dummy_directory): @dataclass class InnerDC: @@ -567,18 +735,26 @@ class InnerDC: c: str = "Hello, Flyte" d: bool = False e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2]) - f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_dummy_file)]) + f: List[FlyteFile] = field( + default_factory=lambda: [ + FlyteFile(local_dummy_file)]) g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]]) - h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) - i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False}) + h: List[Dict[int, bool]] = field(default_factory=lambda: [ + {0: False}, {1: True}, {-1: True}]) + i: Dict[int, bool] = field(default_factory=lambda: { + 0: False, 1: True, -1: False}) j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), 1: FlyteFile(local_dummy_file), -1: FlyteFile(local_dummy_file)}) - k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) - l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) + k: Dict[int, List[int]] = field( + default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = field( + default_factory=lambda: {1: {-1: 0}}) m: dict = field(default_factory=lambda: {"key": "value"}) - n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) - o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + n: FlyteFile = field( + default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = field( + default_factory=lambda: FlyteDirectory(local_dummy_directory)) enum_status: Status = field(default=Status.PENDING) @dataclass @@ -588,18 +764,26 @@ class DC: c: str = "Hello, Flyte" d: bool = False e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2]) - f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_dummy_file), ]) + f: List[FlyteFile] = field( + default_factory=lambda: [ + FlyteFile(local_dummy_file), ]) g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]]) - h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) - i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False}) + h: List[Dict[int, bool]] = field(default_factory=lambda: [ + {0: False}, {1: True}, {-1: True}]) + i: Dict[int, bool] = field(default_factory=lambda: { + 0: False, 1: True, -1: False}) j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), 1: FlyteFile(local_dummy_file), -1: FlyteFile(local_dummy_file)}) - k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) - l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) + k: Dict[int, List[int]] = field( + default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = field( + default_factory=lambda: {1: {-1: 0}}) m: dict = field(default_factory=lambda: {"key": "value"}) - n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) - o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + n: FlyteFile = field( + default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = field( + default_factory=lambda: FlyteDirectory(local_dummy_directory)) inner_dc: InnerDC = field(default_factory=lambda: InnerDC()) enum_status: Status = field(default=Status.PENDING) @@ -627,11 +811,10 @@ def t_inner(inner_dc: InnerDC): with open(os.path.join(inner_dc.o, "file"), "r") as fh: assert fh.read() == "Hello FlyteDirectory" assert inner_dc.o.downloaded - print("Test InnerDC Successfully Passed") + # enum: Status assert inner_dc.enum_status == Status.PENDING - @task def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: List[FlyteFile], g: List[List[int]], h: List[Dict[int, bool]], i: Dict[int, bool], j: Dict[int, FlyteFile], @@ -645,10 +828,12 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li assert isinstance(d, bool), f"d is not bool, it's {type(d)}" # Strict type checks for List[int] - assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]" + assert isinstance(e, list) and all(isinstance(i, int) + for i in e), "e is not List[int]" # Strict type checks for List[FlyteFile] - assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]" + assert isinstance(f, list) and all(isinstance(i, FlyteFile) + for i in f), "f is not List[FlyteFile]" # Strict type checks for List[List[int]] assert isinstance(g, list) and all( @@ -690,8 +875,6 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li # Strict type check for Enum assert isinstance(enum_status, Status), "enum_status is not Status" - print("All attributes passed strict type checks.") - @workflow def wf(dc: DC): t_inner(dc.inner_dc) @@ -709,7 +892,12 @@ def wf(dc: DC): wf(dc=DC()) -def test_backward_compatible_with_dataclass_in_protobuf_struct(local_dummy_file, local_dummy_directory): + +def test_backward_compatible_with_dataclass_in_protobuf_struct( + local_dummy_file, local_dummy_directory): + # Flyte Console will send the input data as protobuf Struct + # This test also test how Flyte Console with attribute access on the + # Struct object @dataclass class InnerDC: @@ -718,18 +906,26 @@ class InnerDC: c: str = "Hello, Flyte" d: bool = False e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2]) - f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_dummy_file)]) + f: List[FlyteFile] = field( + default_factory=lambda: [ + FlyteFile(local_dummy_file)]) g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]]) - h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) - i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False}) + h: List[Dict[int, bool]] = field(default_factory=lambda: [ + {0: False}, {1: True}, {-1: True}]) + i: Dict[int, bool] = field(default_factory=lambda: { + 0: False, 1: True, -1: False}) j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), 1: FlyteFile(local_dummy_file), -1: FlyteFile(local_dummy_file)}) - k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) - l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) + k: Dict[int, List[int]] = field( + default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = field( + default_factory=lambda: {1: {-1: 0}}) m: dict = field(default_factory=lambda: {"key": "value"}) - n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) - o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + n: FlyteFile = field( + default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = field( + default_factory=lambda: FlyteDirectory(local_dummy_directory)) enum_status: Status = field(default=Status.PENDING) @dataclass @@ -739,18 +935,26 @@ class DC: c: str = "Hello, Flyte" d: bool = False e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2]) - f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_dummy_file), ]) + f: List[FlyteFile] = field( + default_factory=lambda: [ + FlyteFile(local_dummy_file), ]) g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]]) - h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) - i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False}) + h: List[Dict[int, bool]] = field(default_factory=lambda: [ + {0: False}, {1: True}, {-1: True}]) + i: Dict[int, bool] = field(default_factory=lambda: { + 0: False, 1: True, -1: False}) j: Dict[int, FlyteFile] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), 1: FlyteFile(local_dummy_file), -1: FlyteFile(local_dummy_file)}) - k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]}) - l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}}) + k: Dict[int, List[int]] = field( + default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = field( + default_factory=lambda: {1: {-1: 0}}) m: dict = field(default_factory=lambda: {"key": "value"}) - n: FlyteFile = field(default_factory=lambda: FlyteFile(local_dummy_file)) - o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_dummy_directory)) + n: FlyteFile = field( + default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = field( + default_factory=lambda: FlyteDirectory(local_dummy_directory)) inner_dc: InnerDC = field(default_factory=lambda: InnerDC()) enum_status: Status = field(default=Status.PENDING) @@ -777,7 +981,7 @@ def t_inner(inner_dc: InnerDC): with open(os.path.join(inner_dc.o, "file"), "r") as fh: assert fh.read() == "Hello FlyteDirectory" assert inner_dc.o.downloaded - print("Test InnerDC Successfully Passed") + # enum: Status assert inner_dc.enum_status == Status.PENDING @@ -793,10 +997,12 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li assert isinstance(d, bool), f"d is not bool, it's {type(d)}" # Strict type checks for List[int] - assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]" + assert isinstance(e, list) and all(isinstance(i, int) + for i in e), "e is not List[int]" # Strict type checks for List[FlyteFile] - assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]" + assert isinstance(f, list) and all(isinstance(i, FlyteFile) + for i in f), "f is not List[FlyteFile]" # Strict type checks for List[List[int]] assert isinstance(g, list) and all( @@ -838,16 +1044,19 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li # Strict type check for Enum assert isinstance(enum_status, Status), "enum_status is not Status" - print("All attributes passed strict type checks.") - # This is the old dataclass serialization behavior. # https://github.com/flyteorg/flytekit/blob/94786cfd4a5c2c3b23ac29dcd6f04d0553fa1beb/flytekit/core/type_engine.py#L702-L728 dc = DC() DataclassTransformer()._make_dataclass_serializable(python_val=dc, python_type=DC) json_str = JSONEncoder(DC).encode(dc) - upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) + upstream_output = Literal( + scalar=Scalar( + generic=_json_format.Parse( + json_str, + _struct.Struct()))) - downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, DC) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, DC) t_inner(downstream_input.inner_dc) t_test_all_attributes(a=downstream_input.a, b=downstream_input.b, c=downstream_input.c, d=downstream_input.d, e=downstream_input.e, f=downstream_input.f, @@ -862,10 +1071,11 @@ def t_test_all_attributes(a: int, b: float, c: str, d: bool, e: List[int], f: Li m=downstream_input.inner_dc.m, n=downstream_input.inner_dc.n, o=downstream_input.inner_dc.o, enum_status=downstream_input.inner_dc.enum_status) + def test_backward_compatible_with_untyped_dict_in_protobuf_struct(): # This is the old dataclass serialization behavior. # https://github.com/flyteorg/flytekit/blob/94786cfd4a5c2c3b23ac29dcd6f04d0553fa1beb/flytekit/core/type_engine.py#L1699-L1720 - dict_input = {"a" : 1.0, "b": "str", + dict_input = {"a": 1.0, "b": "str", "c": False, "d": True, "e": [1.0, 2.0, -1.0, 0.0], "f": {"a": {"b": [1.0, -1.0]}}} @@ -873,5 +1083,413 @@ def test_backward_compatible_with_untyped_dict_in_protobuf_struct(): upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json.dumps(dict_input), _struct.Struct())), metadata={"format": "json"}) - downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, dict) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, dict) assert dict_input == downstream_input + + +def test_flyte_console_input_with_typed_dict_with_flyte_types_in_dataclass_in_protobuf_struct( + local_dummy_file, local_dummy_directory): + # TODO: We can add more nested cases for non-flyte types. + """ + Handles the case where Flyte Console provides input as a protobuf struct. + When resolving an attribute like 'dc.dict_int_ff', FlytePropeller retrieves a dictionary. + Mashumaro's decoder can convert this dictionary to the expected Python object if the correct type is provided. + Since Flyte Types handle their own deserialization, the dictionary is automatically converted to the expected Python object. + + Example Code: + @dataclass + class DC: + dict_int_ff: Dict[int, FlyteFile] + + @workflow + def wf(dc: DC): + t_ff(dc.dict_int_ff) + + Life Cycle: + json str -> protobuf struct -> resolved protobuf struct -> dictionary -> expected Python object + (console user input) (console output) (propeller) (flytekit dict transformer) (mashumaro decoder) + + Related PR: + - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro + - Link: https://github.com/flyteorg/flytekit/pull/2554 + - Title: Binary IDL With MessagePack + - Link: https://github.com/flyteorg/flytekit/pull/2760 + """ + + dict_int_flyte_file = {"1": {"path": local_dummy_file}} + json_str = json.dumps(dict_int_flyte_file) + upstream_output = Literal( + scalar=Scalar( + generic=_json_format.Parse( + json_str, + _struct.Struct())), + metadata={ + "format": "json"}) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, Dict[int, FlyteFile]) + assert downstream_input == {1: FlyteFile(local_dummy_file)} + + # FlyteConsole trims trailing ".0" when converting float-like strings + dict_float_flyte_file = {"1": {"path": local_dummy_file}} + json_str = json.dumps(dict_float_flyte_file) + upstream_output = Literal( + scalar=Scalar( + generic=_json_format.Parse( + json_str, + _struct.Struct())), + metadata={ + "format": "json"}) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteFile]) + assert downstream_input == {1.0: FlyteFile(local_dummy_file)} + + dict_float_flyte_file = {"1.0": {"path": local_dummy_file}} + json_str = json.dumps(dict_float_flyte_file) + upstream_output = Literal( + scalar=Scalar( + generic=_json_format.Parse( + json_str, + _struct.Struct())), + metadata={ + "format": "json"}) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteFile]) + assert downstream_input == {1.0: FlyteFile(local_dummy_file)} + + dict_str_flyte_file = {"1": {"path": local_dummy_file}} + json_str = json.dumps(dict_str_flyte_file) + upstream_output = Literal( + scalar=Scalar( + generic=_json_format.Parse( + json_str, + _struct.Struct())), + metadata={ + "format": "json"}) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, Dict[str, FlyteFile]) + assert downstream_input == {"1": FlyteFile(local_dummy_file)} + + dict_int_flyte_directory = {"1": {"path": local_dummy_directory}} + json_str = json.dumps(dict_int_flyte_directory) + upstream_output = Literal( + scalar=Scalar( + generic=_json_format.Parse( + json_str, + _struct.Struct())), + metadata={ + "format": "json"}) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, Dict[int, FlyteDirectory]) + assert downstream_input == {1: FlyteDirectory(local_dummy_directory)} + + # FlyteConsole trims trailing ".0" when converting float-like strings + dict_float_flyte_directory = {"1": {"path": local_dummy_directory}} + json_str = json.dumps(dict_float_flyte_directory) + upstream_output = Literal( + scalar=Scalar( + generic=_json_format.Parse( + json_str, + _struct.Struct())), + metadata={ + "format": "json"}) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteDirectory]) + assert downstream_input == {1.0: FlyteDirectory(local_dummy_directory)} + + dict_float_flyte_directory = {"1.0": {"path": local_dummy_directory}} + json_str = json.dumps(dict_float_flyte_directory) + upstream_output = Literal( + scalar=Scalar( + generic=_json_format.Parse( + json_str, + _struct.Struct())), + metadata={ + "format": "json"}) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, Dict[float, FlyteDirectory]) + assert downstream_input == {1.0: FlyteDirectory(local_dummy_directory)} + + dict_str_flyte_file = {"1": {"path": local_dummy_file}} + json_str = json.dumps(dict_str_flyte_file) + upstream_output = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())), + metadata={"format": "json"}) + downstream_input = TypeEngine.to_python_value(FlyteContextManager.current_context(), upstream_output, + Dict[str, FlyteFile]) + assert downstream_input == {"1": FlyteFile(local_dummy_file)} + + +def test_all_types_with_optional_in_dataclass_basemodel_wf( + local_dummy_file, local_dummy_directory): + @dataclass + class InnerDC: + a: Optional[int] = -1 + b: Optional[float] = 2.1 + c: Optional[str] = "Hello, Flyte" + d: Optional[bool] = False + e: Optional[List[int]] = field( + default_factory=lambda: [0, 1, 2, -1, -2]) + f: Optional[List[FlyteFile]] = field( + default_factory=lambda: [FlyteFile(local_dummy_file)]) + g: Optional[List[List[int]]] = field( + default_factory=lambda: [[0], [1], [-1]]) + h: Optional[List[Dict[int, bool]]] = field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) + i: Optional[Dict[int, bool]] = field( + default_factory=lambda: {0: False, 1: True, -1: False}) + j: Optional[Dict[int, FlyteFile]] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file)}) + k: Optional[Dict[int, List[int]]] = field( + default_factory=lambda: {0: [0, 1, -1]}) + l: Optional[Dict[int, Dict[int, int]]] = field( + default_factory=lambda: {1: {-1: 0}}) + m: Optional[dict] = field(default_factory=lambda: {"key": "value"}) + n: Optional[FlyteFile] = field( + default_factory=lambda: FlyteFile(local_dummy_file)) + o: Optional[FlyteDirectory] = field( + default_factory=lambda: FlyteDirectory(local_dummy_directory)) + enum_status: Optional[Status] = field(default=Status.PENDING) + + @dataclass + class DC: + a: Optional[int] = -1 + b: Optional[float] = 2.1 + c: Optional[str] = "Hello, Flyte" + d: Optional[bool] = False + e: Optional[List[int]] = field( + default_factory=lambda: [0, 1, 2, -1, -2]) + f: Optional[List[FlyteFile]] = field( + default_factory=lambda: [FlyteFile(local_dummy_file)]) + g: Optional[List[List[int]]] = field( + default_factory=lambda: [[0], [1], [-1]]) + h: Optional[List[Dict[int, bool]]] = field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}]) + i: Optional[Dict[int, bool]] = field( + default_factory=lambda: {0: False, 1: True, -1: False}) + j: Optional[Dict[int, FlyteFile]] = field(default_factory=lambda: {0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file)}) + k: Optional[Dict[int, List[int]]] = field( + default_factory=lambda: {0: [0, 1, -1]}) + l: Optional[Dict[int, Dict[int, int]]] = field( + default_factory=lambda: {1: {-1: 0}}) + m: Optional[dict] = field(default_factory=lambda: {"key": "value"}) + n: Optional[FlyteFile] = field( + default_factory=lambda: FlyteFile(local_dummy_file)) + o: Optional[FlyteDirectory] = field( + default_factory=lambda: FlyteDirectory(local_dummy_directory)) + inner_dc: Optional[InnerDC] = field(default_factory=lambda: InnerDC()) + enum_status: Optional[Status] = field(default=Status.PENDING) + + @task + def t_inner(inner_dc: InnerDC): + assert type(inner_dc) is InnerDC + + # f: List[FlyteFile] + for ff in inner_dc.f: # type: ignore + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # j: Dict[int, FlyteFile] + for _, ff in inner_dc.j.items(): # type: ignore + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # n: FlyteFile + assert type(inner_dc.n) is FlyteFile + with open(inner_dc.n, "r") as f: + assert f.read() == "Hello FlyteFile" + # o: FlyteDirectory + assert type(inner_dc.o) is FlyteDirectory + assert not inner_dc.o.downloaded + with open(os.path.join(inner_dc.o, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + assert inner_dc.o.downloaded + + # enum: Status + assert inner_dc.enum_status == Status.PENDING + + @task + def t_test_all_attributes(a: Optional[int], b: Optional[float], c: Optional[str], d: Optional[bool], + e: Optional[List[int]], f: Optional[List[FlyteFile]], + g: Optional[List[List[int]]], + h: Optional[List[Dict[int, bool]]], i: Optional[Dict[int, bool]], + j: Optional[Dict[int, FlyteFile]], + k: Optional[Dict[int, List[int]]], l: Optional[Dict[int, Dict[int, int]]], + m: Optional[dict], + n: Optional[FlyteFile], o: Optional[FlyteDirectory], + enum_status: Optional[Status]): + # Strict type checks for simple types + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + + # Strict type checks for List[int] + assert isinstance(e, list) and all(isinstance(i, int) + for i in e), "e is not List[int]" + + # Strict type checks for List[FlyteFile] + assert isinstance(f, list) and all(isinstance(i, FlyteFile) + for i in f), "f is not List[FlyteFile]" + + # Strict type checks for List[List[int]] + assert isinstance(g, list) and all( + isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]" + + # Strict type checks for List[Dict[int, bool]] + assert isinstance(h, list) and all( + isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h + ), "h is not List[Dict[int, bool]]" + + # Strict type checks for Dict[int, bool] + assert isinstance(i, dict) and all( + isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]" + + # Strict type checks for Dict[int, FlyteFile] + assert isinstance(j, dict) and all( + isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]" + + # Strict type checks for Dict[int, List[int]] + assert isinstance(k, dict) and all( + isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in + k.items()), "k is not Dict[int, List[int]]" + + # Strict type checks for Dict[int, Dict[int, int]] + assert isinstance(l, dict) and all( + isinstance(k, int) and isinstance(v, dict) and all( + isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items()) + for k, v in l.items()), "l is not Dict[int, Dict[int, int]]" + + # Strict type check for a generic dict + assert isinstance(m, dict), "m is not dict" + + # Strict type check for FlyteFile + assert isinstance(n, FlyteFile), "n is not FlyteFile" + + # Strict type check for FlyteDirectory + assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory" + + # Strict type check for Enum + assert isinstance(enum_status, Status), "enum_status is not Status" + + @workflow + def wf(dc: DC): + t_inner(dc.inner_dc) + t_test_all_attributes(a=dc.a, b=dc.b, c=dc.c, + d=dc.d, e=dc.e, f=dc.f, + g=dc.g, h=dc.h, i=dc.i, + j=dc.j, k=dc.k, l=dc.l, + m=dc.m, n=dc.n, o=dc.o, + enum_status=dc.enum_status) + + wf(dc=DC()) + + +def test_all_types_with_optional_and_none_in_dataclass_wf(): + @dataclass + class InnerDC: + a: Optional[int] = None + b: Optional[float] = None + c: Optional[str] = None + d: Optional[bool] = None + e: Optional[List[int]] = None + f: Optional[List[FlyteFile]] = None + g: Optional[List[List[int]]] = None + h: Optional[List[Dict[int, bool]]] = None + i: Optional[Dict[int, bool]] = None + j: Optional[Dict[int, FlyteFile]] = None + k: Optional[Dict[int, List[int]]] = None + l: Optional[Dict[int, Dict[int, int]]] = None + m: Optional[dict] = None + n: Optional[FlyteFile] = None + o: Optional[FlyteDirectory] = None + enum_status: Optional[Status] = None + + @dataclass + class DC: + a: Optional[int] = None + b: Optional[float] = None + c: Optional[str] = None + d: Optional[bool] = None + e: Optional[List[int]] = None + f: Optional[List[FlyteFile]] = None + g: Optional[List[List[int]]] = None + h: Optional[List[Dict[int, bool]]] = None + i: Optional[Dict[int, bool]] = None + j: Optional[Dict[int, FlyteFile]] = None + k: Optional[Dict[int, List[int]]] = None + l: Optional[Dict[int, Dict[int, int]]] = None + m: Optional[dict] = None + n: Optional[FlyteFile] = None + o: Optional[FlyteDirectory] = None + inner_dc: Optional[InnerDC] = None + enum_status: Optional[Status] = None + + @task + def t_inner(inner_dc: Optional[InnerDC]): + return inner_dc + + @task + def t_test_all_attributes(a: Optional[int], b: Optional[float], c: Optional[str], d: Optional[bool], + e: Optional[List[int]], f: Optional[List[FlyteFile]], + g: Optional[List[List[int]]], + h: Optional[List[Dict[int, bool]]], i: Optional[Dict[int, bool]], + j: Optional[Dict[int, FlyteFile]], + k: Optional[Dict[int, List[int]]], l: Optional[Dict[int, Dict[int, int]]], + m: Optional[dict], + n: Optional[FlyteFile], o: Optional[FlyteDirectory], + enum_status: Optional[Status]): + return + + @workflow + def wf(dc: DC): + t_inner(dc.inner_dc) + t_test_all_attributes(a=dc.a, b=dc.b, c=dc.c, + d=dc.d, e=dc.e, f=dc.f, + g=dc.g, h=dc.h, i=dc.i, + j=dc.j, k=dc.k, l=dc.l, + m=dc.m, n=dc.n, o=dc.o, + enum_status=dc.enum_status) + + wf(dc=DC()) + + +def test_union_in_dataclass_wf(): + @dataclass + class DC: + a: Union[int, bool, str, float] + b: Union[int, bool, str, float] + + @task + def add(a: Union[int, bool, str, float], b: Union[int, + bool, str, float]) -> Union[int, bool, str, float]: + return a + b # type: ignore + + @workflow + def wf(dc: DC) -> Union[int, bool, str, float]: + return add(dc.a, dc.b) + + assert wf(dc=DC(a=1, b=2)) == 3 + assert wf(dc=DC(a=True, b=False)) == True + assert wf(dc=DC(a=False, b=False)) == False + assert wf(dc=DC(a="hello", b="world")) == "helloworld" + assert wf(dc=DC(a=1.0, b=2.0)) == 3.0 + + @task + def add(dc1: DC, dc2: DC) -> Union[int, bool, str, float]: + return dc1.a + dc2.b # type: ignore + + @workflow + def wf(dc: DC) -> Union[int, bool, str, float]: + return add(dc, dc) + + assert wf(dc=DC(a=1, b=2)) == 3 + + @workflow + def wf(dc: DC) -> DC: + return dc + + assert wf(dc=DC(a=1, b=2)) == DC(a=1, b=2) diff --git a/tests/flytekit/unit/extras/pydantic_transformer/test_dataclass_in_pydantic_basemodel.py b/tests/flytekit/unit/extras/pydantic_transformer/test_dataclass_in_pydantic_basemodel.py new file mode 100644 index 0000000000..d77f7b17b0 --- /dev/null +++ b/tests/flytekit/unit/extras/pydantic_transformer/test_dataclass_in_pydantic_basemodel.py @@ -0,0 +1,105 @@ +from pydantic import BaseModel, Field + +from flytekit import task, workflow + + +def test_dataclasss_in_pydantic_basemodel(): + from dataclasses import dataclass + + @dataclass + class InnerBM: + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + + class BM(BaseModel): + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + inner_bm: InnerBM = Field(default_factory=lambda: InnerBM()) + + @task + def t_bm(bm: BM): + assert isinstance(bm, BM) + assert isinstance(bm.inner_bm, InnerBM) + + @task + def t_inner(inner_bm: InnerBM): + assert isinstance(inner_bm, InnerBM) + + @task + def t_test_primitive_attributes(a: int, b: float, c: str, d: bool): + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert b == 3.14 + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert c == "Hello, Flyte" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + assert d is False + print("All primitive attributes passed strict type checks.") + + @workflow + def wf(bm: BM): + t_bm(bm=bm) + t_inner(inner_bm=bm.inner_bm) + t_test_primitive_attributes(a=bm.a, b=bm.b, c=bm.c, d=bm.d) + t_test_primitive_attributes( + a=bm.inner_bm.a, b=bm.inner_bm.b, c=bm.inner_bm.c, d=bm.inner_bm.d + ) + + bm = BM() + wf(bm=bm) + + +def test_pydantic_dataclasss_in_pydantic_basemodel(): + from pydantic.dataclasses import dataclass + + @dataclass + class InnerBM: + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + + class BM(BaseModel): + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + inner_bm: InnerBM = Field(default_factory=lambda: InnerBM()) + + @task + def t_bm(bm: BM): + assert isinstance(bm, BM) + assert isinstance(bm.inner_bm, InnerBM) + + @task + def t_inner(inner_bm: InnerBM): + assert isinstance(inner_bm, InnerBM) + + @task + def t_test_primitive_attributes(a: int, b: float, c: str, d: bool): + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert b == 3.14 + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert c == "Hello, Flyte" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + assert d is False + print("All primitive attributes passed strict type checks.") + + @workflow + def wf(bm: BM): + t_bm(bm=bm) + t_inner(inner_bm=bm.inner_bm) + t_test_primitive_attributes(a=bm.a, b=bm.b, c=bm.c, d=bm.d) + t_test_primitive_attributes( + a=bm.inner_bm.a, b=bm.inner_bm.b, c=bm.inner_bm.c, d=bm.inner_bm.d + ) + + bm = BM() + wf(bm=bm) diff --git a/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_in_dataclass.py b/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_in_dataclass.py new file mode 100644 index 0000000000..9de8dfa41f --- /dev/null +++ b/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_in_dataclass.py @@ -0,0 +1,138 @@ +import pytest +from pydantic import BaseModel + +from flytekit import task, workflow + +""" +This should be supported in the future. +Issue Link: https://github.com/flyteorg/flyte/issues/5925 +""" + + +def test_pydantic_basemodel_in_dataclass(): + from dataclasses import dataclass, field + + # Define InnerBM using Pydantic BaseModel + class InnerBM(BaseModel): + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + + # Define the dataclass DC + @dataclass + class DC: + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + inner_bm: InnerBM = field(default_factory=lambda: InnerBM()) + + # Task to check DC instance + @task + def t_dc(dc: DC): + assert isinstance(dc, DC) + assert isinstance(dc.inner_bm, InnerBM) + + # Task to check InnerBM instance + @task + def t_inner(inner_bm: InnerBM): + assert isinstance(inner_bm, InnerBM) + + # Task to check primitive attributes + @task + def t_test_primitive_attributes(a: int, b: float, c: str, d: bool): + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert b == 3.14 + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert c == "Hello, Flyte" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + assert d is False + print("All primitive attributes passed strict type checks.") + + # Define the workflow + @workflow + def wf(dc: DC): + t_dc(dc=dc) + t_inner(inner_bm=dc.inner_bm) + t_test_primitive_attributes(a=dc.a, b=dc.b, c=dc.c, d=dc.d) + t_test_primitive_attributes( + a=dc.inner_bm.a, b=dc.inner_bm.b, c=dc.inner_bm.c, d=dc.inner_bm.d + ) + + # Create an instance of DC and run the workflow + dc_instance = DC() + with pytest.raises(Exception) as excinfo: + wf(dc=dc_instance) + + # Assert that the error message contains "UnserializableField" + assert "is not serializable" in str( + excinfo.value + ), f"Unexpected error: {excinfo.value}" + + +def test_pydantic_basemodel_in_pydantic_dataclass(): + from pydantic import Field + from pydantic.dataclasses import dataclass + + # Define InnerBM using Pydantic BaseModel + class InnerBM(BaseModel): + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + + # Define the Pydantic dataclass DC + @dataclass + class DC: + a: int = -1 + b: float = 3.14 + c: str = "Hello, Flyte" + d: bool = False + inner_bm: InnerBM = Field(default_factory=lambda: InnerBM()) + + # Task to check DC instance + @task + def t_dc(dc: DC): + assert isinstance(dc, DC) + assert isinstance(dc.inner_bm, InnerBM) + + # Task to check InnerBM instance + @task + def t_inner(inner_bm: InnerBM): + assert isinstance(inner_bm, InnerBM) + + # Task to check primitive attributes + @task + def t_test_primitive_attributes(a: int, b: float, c: str, d: bool): + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert b == 3.14 + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert c == "Hello, Flyte" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + assert d is False + print("All primitive attributes passed strict type checks.") + + # Define the workflow + @workflow + def wf(dc: DC): + t_dc(dc=dc) + t_inner(inner_bm=dc.inner_bm) + t_test_primitive_attributes(a=dc.a, b=dc.b, c=dc.c, d=dc.d) + t_test_primitive_attributes( + a=dc.inner_bm.a, b=dc.inner_bm.b, c=dc.inner_bm.c, d=dc.inner_bm.d + ) + + # Create an instance of DC and run the workflow + dc_instance = DC() + with pytest.raises(Exception) as excinfo: + wf(dc=dc_instance) + + # Assert that the error message contains "UnserializableField" + assert "is not serializable" in str( + excinfo.value + ), f"Unexpected error: {excinfo.value}" diff --git a/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py b/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py new file mode 100644 index 0000000000..05ba54903f --- /dev/null +++ b/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py @@ -0,0 +1,982 @@ +import os +import tempfile +from enum import Enum +from typing import Dict, List, Optional, Union +from unittest.mock import patch + +import pytest +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct +from pydantic import BaseModel, Field + +from flytekit import task, workflow +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +from flytekit.models.literals import Literal, Scalar +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile +from flytekit.types.schema import FlyteSchema +from flytekit.types.structured import StructuredDataset + + +class Status(Enum): + PENDING = "pending" + APPROVED = "approved" + REJECTED = "rejected" + + +@pytest.fixture +def local_dummy_file(): + fd, path = tempfile.mkstemp() + try: + with os.fdopen(fd, "w") as tmp: + tmp.write("Hello FlyteFile") + yield path + finally: + os.remove(path) + + +@pytest.fixture +def local_dummy_directory(): + temp_dir = tempfile.TemporaryDirectory() + try: + with open(os.path.join(temp_dir.name, "file"), "w") as tmp: + tmp.write("Hello FlyteDirectory") + yield temp_dir.name + finally: + temp_dir.cleanup() + + +def test_flytetypes_in_pydantic_basemodel_wf(local_dummy_file, local_dummy_directory): + class InnerBM(BaseModel): + flytefile: FlyteFile = Field( + default_factory=lambda: FlyteFile(local_dummy_file) + ) + flytedir: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + + class BM(BaseModel): + flytefile: FlyteFile = Field( + default_factory=lambda: FlyteFile(local_dummy_file) + ) + flytedir: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + inner_bm: InnerBM = Field(default_factory=lambda: InnerBM()) + + @task + def t1(path: FlyteFile) -> FlyteFile: + return path + + @task + def t2(path: FlyteDirectory) -> FlyteDirectory: + return path + + @workflow + def wf(bm: BM) -> (FlyteFile, FlyteFile, FlyteDirectory, FlyteDirectory): + f1 = t1(path=bm.flytefile) + f2 = t1(path=bm.inner_bm.flytefile) + d1 = t2(path=bm.flytedir) + d2 = t2(path=bm.inner_bm.flytedir) + return f1, f2, d1, d2 + + o1, o2, o3, o4 = wf(bm=BM()) + with open(o1, "r") as fh: + assert fh.read() == "Hello FlyteFile" + + with open(o2, "r") as fh: + assert fh.read() == "Hello FlyteFile" + + with open(os.path.join(o3, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + + with open(os.path.join(o4, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + + +def test_all_types_in_pydantic_basemodel_wf(local_dummy_file, local_dummy_directory): + class InnerBM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str = "Hello, Flyte" + d: bool = False + e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: List[FlyteFile] = Field( + default_factory=lambda: [FlyteFile(local_dummy_file)] + ) + g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]]) + h: List[Dict[int, bool]] = Field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) + i: Dict[int, bool] = Field( + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Dict[int, FlyteFile] = Field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) + k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}}) + m: dict = Field(default_factory=lambda: {"key": "value"}) + n: FlyteFile = Field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + enum_status: Status = Field(default=Status.PENDING) + + class BM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str = "Hello, Flyte" + d: bool = False + e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: List[FlyteFile] = Field( + default_factory=lambda: [ + FlyteFile(local_dummy_file), + ] + ) + g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]]) + h: List[Dict[int, bool]] = Field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) + i: Dict[int, bool] = Field( + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Dict[int, FlyteFile] = Field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) + k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}}) + m: dict = Field(default_factory=lambda: {"key": "value"}) + n: FlyteFile = Field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + inner_bm: InnerBM = Field(default_factory=lambda: InnerBM()) + enum_status: Status = Field(default=Status.PENDING) + + @task + def t_inner(inner_bm: InnerBM): + assert type(inner_bm) is InnerBM + + # f: List[FlyteFile] + for ff in inner_bm.f: + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # j: Dict[int, FlyteFile] + for _, ff in inner_bm.j.items(): + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # n: FlyteFile + assert type(inner_bm.n) is FlyteFile + with open(inner_bm.n, "r") as f: + assert f.read() == "Hello FlyteFile" + # o: FlyteDirectory + assert type(inner_bm.o) is FlyteDirectory + assert not inner_bm.o.downloaded + with open(os.path.join(inner_bm.o, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + assert inner_bm.o.downloaded + + # enum: Status + assert inner_bm.enum_status == Status.PENDING + + @task + def t_test_all_attributes( + a: int, + b: float, + c: str, + d: bool, + e: List[int], + f: List[FlyteFile], + g: List[List[int]], + h: List[Dict[int, bool]], + i: Dict[int, bool], + j: Dict[int, FlyteFile], + k: Dict[int, List[int]], + l: Dict[int, Dict[int, int]], + m: dict, + n: FlyteFile, + o: FlyteDirectory, + enum_status: Status, + ): + # Strict type checks for simple types + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + + # Strict type checks for List[int] + assert isinstance(e, list) and all( + isinstance(i, int) for i in e + ), "e is not List[int]" + + # Strict type checks for List[FlyteFile] + assert isinstance(f, list) and all( + isinstance(i, FlyteFile) for i in f + ), "f is not List[FlyteFile]" + + # Strict type checks for List[List[int]] + assert isinstance(g, list) and all( + isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g + ), "g is not List[List[int]]" + + # Strict type checks for List[Dict[int, bool]] + assert isinstance(h, list) and all( + isinstance(i, dict) + and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) + for i in h + ), "h is not List[Dict[int, bool]]" + + # Strict type checks for Dict[int, bool] + assert isinstance(i, dict) and all( + isinstance(k, int) and isinstance(v, bool) for k, v in i.items() + ), "i is not Dict[int, bool]" + + # Strict type checks for Dict[int, FlyteFile] + assert isinstance(j, dict) and all( + isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items() + ), "j is not Dict[int, FlyteFile]" + + # Strict type checks for Dict[int, List[int]] + assert isinstance(k, dict) and all( + isinstance(k, int) + and isinstance(v, list) + and all(isinstance(i, int) for i in v) + for k, v in k.items() + ), "k is not Dict[int, List[int]]" + + # Strict type checks for Dict[int, Dict[int, int]] + assert isinstance(l, dict) and all( + isinstance(k, int) + and isinstance(v, dict) + and all( + isinstance(sub_k, int) and isinstance(sub_v, int) + for sub_k, sub_v in v.items() + ) + for k, v in l.items() + ), "l is not Dict[int, Dict[int, int]]" + + # Strict type check for a generic dict + assert isinstance(m, dict), "m is not dict" + + # Strict type check for FlyteFile + assert isinstance(n, FlyteFile), "n is not FlyteFile" + + # Strict type check for FlyteDirectory + assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory" + + # Strict type check for Enum + assert isinstance(enum_status, Status), "enum_status is not Status" + + print("All attributes passed strict type checks.") + + @workflow + def wf(bm: BM): + t_inner(bm.inner_bm) + t_test_all_attributes( + a=bm.a, + b=bm.b, + c=bm.c, + d=bm.d, + e=bm.e, + f=bm.f, + g=bm.g, + h=bm.h, + i=bm.i, + j=bm.j, + k=bm.k, + l=bm.l, + m=bm.m, + n=bm.n, + o=bm.o, + enum_status=bm.enum_status, + ) + + t_test_all_attributes( + a=bm.inner_bm.a, + b=bm.inner_bm.b, + c=bm.inner_bm.c, + d=bm.inner_bm.d, + e=bm.inner_bm.e, + f=bm.inner_bm.f, + g=bm.inner_bm.g, + h=bm.inner_bm.h, + i=bm.inner_bm.i, + j=bm.inner_bm.j, + k=bm.inner_bm.k, + l=bm.inner_bm.l, + m=bm.inner_bm.m, + n=bm.inner_bm.n, + o=bm.inner_bm.o, + enum_status=bm.inner_bm.enum_status, + ) + + wf(bm=BM()) + + +def test_all_types_with_optional_in_pydantic_basemodel_wf( + local_dummy_file, local_dummy_directory +): + class InnerBM(BaseModel): + a: Optional[int] = -1 + b: Optional[float] = 2.1 + c: Optional[str] = "Hello, Flyte" + d: Optional[bool] = False + e: Optional[List[int]] = Field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: Optional[List[FlyteFile]] = Field( + default_factory=lambda: [FlyteFile(local_dummy_file)] + ) + g: Optional[List[List[int]]] = Field(default_factory=lambda: [[0], [1], [-1]]) + h: Optional[List[Dict[int, bool]]] = Field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) + i: Optional[Dict[int, bool]] = Field( + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Optional[Dict[int, FlyteFile]] = Field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) + k: Optional[Dict[int, List[int]]] = Field( + default_factory=lambda: {0: [0, 1, -1]} + ) + l: Optional[Dict[int, Dict[int, int]]] = Field( + default_factory=lambda: {1: {-1: 0}} + ) + m: Optional[dict] = Field(default_factory=lambda: {"key": "value"}) + n: Optional[FlyteFile] = Field( + default_factory=lambda: FlyteFile(local_dummy_file) + ) + o: Optional[FlyteDirectory] = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + enum_status: Optional[Status] = Field(default=Status.PENDING) + + class BM(BaseModel): + a: Optional[int] = -1 + b: Optional[float] = 2.1 + c: Optional[str] = "Hello, Flyte" + d: Optional[bool] = False + e: Optional[List[int]] = Field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: Optional[List[FlyteFile]] = Field( + default_factory=lambda: [FlyteFile(local_dummy_file)] + ) + g: Optional[List[List[int]]] = Field(default_factory=lambda: [[0], [1], [-1]]) + h: Optional[List[Dict[int, bool]]] = Field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) + i: Optional[Dict[int, bool]] = Field( + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Optional[Dict[int, FlyteFile]] = Field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) + k: Optional[Dict[int, List[int]]] = Field( + default_factory=lambda: {0: [0, 1, -1]} + ) + l: Optional[Dict[int, Dict[int, int]]] = Field( + default_factory=lambda: {1: {-1: 0}} + ) + m: Optional[dict] = Field(default_factory=lambda: {"key": "value"}) + n: Optional[FlyteFile] = Field( + default_factory=lambda: FlyteFile(local_dummy_file) + ) + o: Optional[FlyteDirectory] = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + inner_bm: Optional[InnerBM] = Field(default_factory=lambda: InnerBM()) + enum_status: Optional[Status] = Field(default=Status.PENDING) + + @task + def t_inner(inner_bm: InnerBM): + assert type(inner_bm) is InnerBM + + # f: List[FlyteFile] + for ff in inner_bm.f: + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # j: Dict[int, FlyteFile] + for _, ff in inner_bm.j.items(): + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # n: FlyteFile + assert type(inner_bm.n) is FlyteFile + with open(inner_bm.n, "r") as f: + assert f.read() == "Hello FlyteFile" + # o: FlyteDirectory + assert type(inner_bm.o) is FlyteDirectory + assert not inner_bm.o.downloaded + with open(os.path.join(inner_bm.o, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + assert inner_bm.o.downloaded + + # enum: Status + assert inner_bm.enum_status == Status.PENDING + + @task + def t_test_all_attributes( + a: Optional[int], + b: Optional[float], + c: Optional[str], + d: Optional[bool], + e: Optional[List[int]], + f: Optional[List[FlyteFile]], + g: Optional[List[List[int]]], + h: Optional[List[Dict[int, bool]]], + i: Optional[Dict[int, bool]], + j: Optional[Dict[int, FlyteFile]], + k: Optional[Dict[int, List[int]]], + l: Optional[Dict[int, Dict[int, int]]], + m: Optional[dict], + n: Optional[FlyteFile], + o: Optional[FlyteDirectory], + enum_status: Optional[Status], + ): + # Strict type checks for simple types + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + + # Strict type checks for List[int] + assert isinstance(e, list) and all( + isinstance(i, int) for i in e + ), "e is not List[int]" + + # Strict type checks for List[FlyteFile] + assert isinstance(f, list) and all( + isinstance(i, FlyteFile) for i in f + ), "f is not List[FlyteFile]" + + # Strict type checks for List[List[int]] + assert isinstance(g, list) and all( + isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g + ), "g is not List[List[int]]" + + # Strict type checks for List[Dict[int, bool]] + assert isinstance(h, list) and all( + isinstance(i, dict) + and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) + for i in h + ), "h is not List[Dict[int, bool]]" + + # Strict type checks for Dict[int, bool] + assert isinstance(i, dict) and all( + isinstance(k, int) and isinstance(v, bool) for k, v in i.items() + ), "i is not Dict[int, bool]" + + # Strict type checks for Dict[int, FlyteFile] + assert isinstance(j, dict) and all( + isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items() + ), "j is not Dict[int, FlyteFile]" + + # Strict type checks for Dict[int, List[int]] + assert isinstance(k, dict) and all( + isinstance(k, int) + and isinstance(v, list) + and all(isinstance(i, int) for i in v) + for k, v in k.items() + ), "k is not Dict[int, List[int]]" + + # Strict type checks for Dict[int, Dict[int, int]] + assert isinstance(l, dict) and all( + isinstance(k, int) + and isinstance(v, dict) + and all( + isinstance(sub_k, int) and isinstance(sub_v, int) + for sub_k, sub_v in v.items() + ) + for k, v in l.items() + ), "l is not Dict[int, Dict[int, int]]" + + # Strict type check for a generic dict + assert isinstance(m, dict), "m is not dict" + + # Strict type check for FlyteFile + assert isinstance(n, FlyteFile), "n is not FlyteFile" + + # Strict type check for FlyteDirectory + assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory" + + # Strict type check for Enum + assert isinstance(enum_status, Status), "enum_status is not Status" + + @workflow + def wf(bm: BM): + t_inner(bm.inner_bm) + t_test_all_attributes( + a=bm.a, + b=bm.b, + c=bm.c, + d=bm.d, + e=bm.e, + f=bm.f, + g=bm.g, + h=bm.h, + i=bm.i, + j=bm.j, + k=bm.k, + l=bm.l, + m=bm.m, + n=bm.n, + o=bm.o, + enum_status=bm.enum_status, + ) + + wf(bm=BM()) + + +def test_all_types_with_optional_and_none_in_pydantic_basemodel_wf( + local_dummy_file, local_dummy_directory +): + class InnerBM(BaseModel): + a: Optional[int] = None + b: Optional[float] = None + c: Optional[str] = None + d: Optional[bool] = None + e: Optional[List[int]] = None + f: Optional[List[FlyteFile]] = None + g: Optional[List[List[int]]] = None + h: Optional[List[Dict[int, bool]]] = None + i: Optional[Dict[int, bool]] = None + j: Optional[Dict[int, FlyteFile]] = None + k: Optional[Dict[int, List[int]]] = None + l: Optional[Dict[int, Dict[int, int]]] = None + m: Optional[dict] = None + n: Optional[FlyteFile] = None + o: Optional[FlyteDirectory] = None + enum_status: Optional[Status] = None + + class BM(BaseModel): + a: Optional[int] = None + b: Optional[float] = None + c: Optional[str] = None + d: Optional[bool] = None + e: Optional[List[int]] = None + f: Optional[List[FlyteFile]] = None + g: Optional[List[List[int]]] = None + h: Optional[List[Dict[int, bool]]] = None + i: Optional[Dict[int, bool]] = None + j: Optional[Dict[int, FlyteFile]] = None + k: Optional[Dict[int, List[int]]] = None + l: Optional[Dict[int, Dict[int, int]]] = None + m: Optional[dict] = None + n: Optional[FlyteFile] = None + o: Optional[FlyteDirectory] = None + inner_bm: Optional[InnerBM] = None + enum_status: Optional[Status] = None + + @task + def t_inner(inner_bm: Optional[InnerBM]): + return inner_bm + + @task + def t_test_all_attributes( + a: Optional[int], + b: Optional[float], + c: Optional[str], + d: Optional[bool], + e: Optional[List[int]], + f: Optional[List[FlyteFile]], + g: Optional[List[List[int]]], + h: Optional[List[Dict[int, bool]]], + i: Optional[Dict[int, bool]], + j: Optional[Dict[int, FlyteFile]], + k: Optional[Dict[int, List[int]]], + l: Optional[Dict[int, Dict[int, int]]], + m: Optional[dict], + n: Optional[FlyteFile], + o: Optional[FlyteDirectory], + enum_status: Optional[Status], + ): + return + + @workflow + def wf(bm: BM): + t_inner(bm.inner_bm) + t_test_all_attributes( + a=bm.a, + b=bm.b, + c=bm.c, + d=bm.d, + e=bm.e, + f=bm.f, + g=bm.g, + h=bm.h, + i=bm.i, + j=bm.j, + k=bm.k, + l=bm.l, + m=bm.m, + n=bm.n, + o=bm.o, + enum_status=bm.enum_status, + ) + + wf(bm=BM()) + + +def test_input_from_flyte_console_pydantic_basemodel( + local_dummy_file, local_dummy_directory +): + # Flyte Console will send the input data as protobuf Struct + + class InnerBM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str = "Hello, Flyte" + d: bool = False + e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: List[FlyteFile] = Field( + default_factory=lambda: [FlyteFile(local_dummy_file)] + ) + g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]]) + h: List[Dict[int, bool]] = Field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) + i: Dict[int, bool] = Field( + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Dict[int, FlyteFile] = Field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) + k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}}) + m: dict = Field(default_factory=lambda: {"key": "value"}) + n: FlyteFile = Field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + enum_status: Status = Field(default=Status.PENDING) + + class BM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str = "Hello, Flyte" + d: bool = False + e: List[int] = Field(default_factory=lambda: [0, 1, 2, -1, -2]) + f: List[FlyteFile] = Field( + default_factory=lambda: [ + FlyteFile(local_dummy_file), + ] + ) + g: List[List[int]] = Field(default_factory=lambda: [[0], [1], [-1]]) + h: List[Dict[int, bool]] = Field( + default_factory=lambda: [{0: False}, {1: True}, {-1: True}] + ) + i: Dict[int, bool] = Field( + default_factory=lambda: {0: False, 1: True, -1: False} + ) + j: Dict[int, FlyteFile] = Field( + default_factory=lambda: { + 0: FlyteFile(local_dummy_file), + 1: FlyteFile(local_dummy_file), + -1: FlyteFile(local_dummy_file), + } + ) + k: Dict[int, List[int]] = Field(default_factory=lambda: {0: [0, 1, -1]}) + l: Dict[int, Dict[int, int]] = Field(default_factory=lambda: {1: {-1: 0}}) + m: dict = Field(default_factory=lambda: {"key": "value"}) + n: FlyteFile = Field(default_factory=lambda: FlyteFile(local_dummy_file)) + o: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + inner_bm: InnerBM = Field(default_factory=lambda: InnerBM()) + enum_status: Status = Field(default=Status.PENDING) + + @task + def t_inner(inner_bm: InnerBM): + assert type(inner_bm) is InnerBM + + # f: List[FlyteFile] + for ff in inner_bm.f: + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # j: Dict[int, FlyteFile] + for _, ff in inner_bm.j.items(): + assert type(ff) is FlyteFile + with open(ff, "r") as f: + assert f.read() == "Hello FlyteFile" + # n: FlyteFile + assert type(inner_bm.n) is FlyteFile + with open(inner_bm.n, "r") as f: + assert f.read() == "Hello FlyteFile" + # o: FlyteDirectory + assert type(inner_bm.o) is FlyteDirectory + assert not inner_bm.o.downloaded + with open(os.path.join(inner_bm.o, "file"), "r") as fh: + assert fh.read() == "Hello FlyteDirectory" + assert inner_bm.o.downloaded + + # enum: Status + assert inner_bm.enum_status == Status.PENDING + + def t_test_all_attributes( + a: int, + b: float, + c: str, + d: bool, + e: List[int], + f: List[FlyteFile], + g: List[List[int]], + h: List[Dict[int, bool]], + i: Dict[int, bool], + j: Dict[int, FlyteFile], + k: Dict[int, List[int]], + l: Dict[int, Dict[int, int]], + m: dict, + n: FlyteFile, + o: FlyteDirectory, + enum_status: Status, + ): + # Strict type checks for simple types + assert isinstance(a, int), f"a is not int, it's {type(a)}" + assert a == -1 + assert isinstance(b, float), f"b is not float, it's {type(b)}" + assert isinstance(c, str), f"c is not str, it's {type(c)}" + assert isinstance(d, bool), f"d is not bool, it's {type(d)}" + + # Strict type checks for List[int] + assert isinstance(e, list) and all( + isinstance(i, int) for i in e + ), "e is not List[int]" + + # Strict type checks for List[FlyteFile] + assert isinstance(f, list) and all( + isinstance(i, FlyteFile) for i in f + ), "f is not List[FlyteFile]" + + # Strict type checks for List[List[int]] + assert isinstance(g, list) and all( + isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g + ), "g is not List[List[int]]" + + # Strict type checks for List[Dict[int, bool]] + assert isinstance(h, list) and all( + isinstance(i, dict) + and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) + for i in h + ), "h is not List[Dict[int, bool]]" + + # Strict type checks for Dict[int, bool] + assert isinstance(i, dict) and all( + isinstance(k, int) and isinstance(v, bool) for k, v in i.items() + ), "i is not Dict[int, bool]" + + # Strict type checks for Dict[int, FlyteFile] + assert isinstance(j, dict) and all( + isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items() + ), "j is not Dict[int, FlyteFile]" + + # Strict type checks for Dict[int, List[int]] + assert isinstance(k, dict) and all( + isinstance(k, int) + and isinstance(v, list) + and all(isinstance(i, int) for i in v) + for k, v in k.items() + ), "k is not Dict[int, List[int]]" + + # Strict type checks for Dict[int, Dict[int, int]] + assert isinstance(l, dict) and all( + isinstance(k, int) + and isinstance(v, dict) + and all( + isinstance(sub_k, int) and isinstance(sub_v, int) + for sub_k, sub_v in v.items() + ) + for k, v in l.items() + ), "l is not Dict[int, Dict[int, int]]" + + # Strict type check for a generic dict + assert isinstance(m, dict), "m is not dict" + + # Strict type check for FlyteFile + assert isinstance(n, FlyteFile), "n is not FlyteFile" + + # Strict type check for FlyteDirectory + assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory" + + # Strict type check for Enum + assert isinstance(enum_status, Status), "enum_status is not Status" + + print("All attributes passed strict type checks.") + + # This is the old dataclass serialization behavior. + # https://github.com/flyteorg/flytekit/blob/94786cfd4a5c2c3b23ac29bmd6f04d0553fa1beb/flytekit/core/type_engine.py#L702-L728 + bm = BM() + json_str = bm.model_dump_json() + upstream_output = Literal( + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())) + ) + + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, BM + ) + t_inner(downstream_input.inner_bm) + t_test_all_attributes( + a=downstream_input.a, + b=downstream_input.b, + c=downstream_input.c, + d=downstream_input.d, + e=downstream_input.e, + f=downstream_input.f, + g=downstream_input.g, + h=downstream_input.h, + i=downstream_input.i, + j=downstream_input.j, + k=downstream_input.k, + l=downstream_input.l, + m=downstream_input.m, + n=downstream_input.n, + o=downstream_input.o, + enum_status=downstream_input.enum_status, + ) + t_test_all_attributes( + a=downstream_input.inner_bm.a, + b=downstream_input.inner_bm.b, + c=downstream_input.inner_bm.c, + d=downstream_input.inner_bm.d, + e=downstream_input.inner_bm.e, + f=downstream_input.inner_bm.f, + g=downstream_input.inner_bm.g, + h=downstream_input.inner_bm.h, + i=downstream_input.inner_bm.i, + j=downstream_input.inner_bm.j, + k=downstream_input.inner_bm.k, + l=downstream_input.inner_bm.l, + m=downstream_input.inner_bm.m, + n=downstream_input.inner_bm.n, + o=downstream_input.inner_bm.o, + enum_status=downstream_input.inner_bm.enum_status, + ) + + +def test_flyte_types_deserialization_not_called_when_using_constructor( + local_dummy_file, local_dummy_directory +): + # Mocking both FlyteFilePathTransformer and FlyteDirectoryPathTransformer + with patch( + "flytekit.types.file.FlyteFilePathTransformer.to_python_value" + ) as mock_file_to_python_value, patch( + "flytekit.types.directory.FlyteDirToMultipartBlobTransformer.to_python_value" + ) as mock_directory_to_python_value, patch( + "flytekit.types.structured.StructuredDatasetTransformerEngine.to_python_value" + ) as mock_structured_dataset_to_python_value, patch( + "flytekit.types.schema.FlyteSchemaTransformer.to_python_value" + ) as mock_schema_to_python_value: + + # Define your Pydantic model + class BM(BaseModel): + ff: FlyteFile = Field(default_factory=lambda: FlyteFile(local_dummy_file)) + fd: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + sd: StructuredDataset = Field(default_factory=lambda: StructuredDataset()) + fsc: FlyteSchema = Field(default_factory=lambda: FlyteSchema()) + + # Create an instance of BM (should not call the deserialization) + BM() + + mock_file_to_python_value.assert_not_called() + mock_directory_to_python_value.assert_not_called() + mock_structured_dataset_to_python_value.assert_not_called() + mock_schema_to_python_value.assert_not_called() + + +def test_flyte_types_deserialization_called_once_when_using_model_validate_json( + local_dummy_file, local_dummy_directory +): + """ + It's hard to mock flyte schema and structured dataset in tests, so we will only test FlyteFile and FlyteDirectory + """ + with patch( + "flytekit.types.file.FlyteFilePathTransformer.to_python_value" + ) as mock_file_to_python_value, patch( + "flytekit.types.directory.FlyteDirToMultipartBlobTransformer.to_python_value" + ) as mock_directory_to_python_value: + # Define your Pydantic model + class BM(BaseModel): + ff: FlyteFile = Field(default_factory=lambda: FlyteFile(local_dummy_file)) + fd: FlyteDirectory = Field( + default_factory=lambda: FlyteDirectory(local_dummy_directory) + ) + + # Create instances of FlyteFile and FlyteDirectory + bm = BM( + ff=FlyteFile(local_dummy_file), fd=FlyteDirectory(local_dummy_directory) + ) + + # Serialize and Deserialize with model_validate_json + json_str = bm.model_dump_json() + bm.model_validate_json( + json_data=json_str, strict=False, context={"deserialize": True} + ) + + # Assert that the to_python_value method was called once + mock_file_to_python_value.assert_called_once() + mock_directory_to_python_value.assert_called_once() + + +def test_union_in_basemodel_wf(): + class bm(BaseModel): + a: Union[int, bool, str, float] + b: Union[int, bool, str, float] + + @task + def add( + a: Union[int, bool, str, float], b: Union[int, bool, str, float] + ) -> Union[int, bool, str, float]: + return a + b # type: ignore + + @workflow + def wf(bm: bm) -> Union[int, bool, str, float]: + return add(bm.a, bm.b) + + assert wf(bm=bm(a=1, b=2)) == 3 + assert wf(bm=bm(a=True, b=False)) == True + assert wf(bm=bm(a=False, b=False)) == False + assert wf(bm=bm(a="hello", b="world")) == "helloworld" + assert wf(bm=bm(a=1.0, b=2.0)) == 3.0 + + @task + def add_bm(bm1: bm, bm2: bm) -> Union[int, bool, str, float]: + return bm1.a + bm2.b # type: ignore + + @workflow + def wf_add_bm(bm: bm) -> Union[int, bool, str, float]: + return add_bm(bm, bm) + + assert wf_add_bm(bm=bm(a=1, b=2)) == 3 + + @workflow + def wf_return_bm(bm: bm) -> bm: + return bm + + assert wf_return_bm(bm=bm(a=1, b=2)) == bm(a=1, b=2) diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py index f107384b96..3cc19f295c 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py @@ -16,7 +16,6 @@ from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow -from flytekit.exceptions.user import FlyteAssertion from flytekit.lazy_import.lazy_module import is_imported from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata @@ -35,7 +34,8 @@ pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") -my_cols = kwtypes(w=typing.Dict[str, typing.Dict[str, int]], x=typing.List[typing.List[int]], y=int, z=str) +my_cols = kwtypes(w=typing.Dict[str, typing.Dict[str, int]], + x=typing.List[typing.List[int]], y=int, z=str) fields = [("some_int", pa.int32()), ("some_string", pa.string())] arrow_schema = pa.schema(fields) @@ -72,7 +72,8 @@ def t1(a: pd.DataFrame) -> pd.DataFrame: ctx = FlyteContextManager.current_context() with FlyteContextManager.with_context( ctx.with_execution_state( - ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) + ctx.new_execution_state().with_params( + mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) ) ): result = t1(a=generate_pandas()) @@ -143,7 +144,8 @@ def test_types_annotated(): pt = Annotated[pd.DataFrame, PARQUET, arrow_schema] lt = TypeEngine.to_literal_type(pt) assert lt.structured_dataset_type.external_schema_type == "arrow" - assert "some_string" in str(lt.structured_dataset_type.external_schema_bytes) + assert "some_string" in str( + lt.structured_dataset_type.external_schema_bytes) pt = Annotated[pd.DataFrame, kwtypes(a=None)] with pytest.raises(AssertionError, match="type None is currently not supported by StructuredDataset"): @@ -171,8 +173,10 @@ def test_types_sd(): def test_retrieving(): - assert StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET) is not None - # Asking for a generic means you're okay with any one registered for that type assuming there's just one. + assert StructuredDatasetTransformerEngine.get_encoder( + pd.DataFrame, "file", PARQUET) is not None + # Asking for a generic means you're okay with any one registered for that + # type assuming there's just one. assert StructuredDatasetTransformerEngine.get_encoder( pd.DataFrame, "file", "" ) is StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET) @@ -184,9 +188,11 @@ def __init__(self, protocol): def encode(self): ... - StructuredDatasetTransformerEngine.register(TempEncoder("gs"), default_for_type=False) + StructuredDatasetTransformerEngine.register( + TempEncoder("gs"), default_for_type=False) with pytest.raises(ValueError): - StructuredDatasetTransformerEngine.register(TempEncoder("gs://"), default_for_type=False) + StructuredDatasetTransformerEngine.register( + TempEncoder("gs://"), default_for_type=False) with pytest.raises(ValueError, match="Use None instead"): e = TempEncoder("") @@ -197,7 +203,8 @@ class TempEncoder: pass with pytest.raises(TypeError, match="We don't support this type of handler"): - StructuredDatasetTransformerEngine.register(TempEncoder, default_for_type=False) + StructuredDatasetTransformerEngine.register( + TempEncoder, default_for_type=False) def test_to_literal(): @@ -215,16 +222,29 @@ def test_to_literal(): sd_with_literal_and_df._literal_sd = lit with pytest.raises(ValueError, match="Shouldn't have specified both literal"): - fdt.to_literal(ctx, sd_with_literal_and_df, python_type=StructuredDataset, expected=lt) + fdt.to_literal( + ctx, + sd_with_literal_and_df, + python_type=StructuredDataset, + expected=lt) sd_with_nothing = StructuredDataset() with pytest.raises(ValueError, match="If dataframe is not specified"): - fdt.to_literal(ctx, sd_with_nothing, python_type=StructuredDataset, expected=lt) + fdt.to_literal( + ctx, + sd_with_nothing, + python_type=StructuredDataset, + expected=lt) sd_with_uri = StructuredDataset(uri="s3://some/extant/df.parquet") - lt = TypeEngine.to_literal_type(Annotated[StructuredDataset, {}, "new-df-format"]) - lit = fdt.to_literal(ctx, sd_with_uri, python_type=StructuredDataset, expected=lt) + lt = TypeEngine.to_literal_type( + Annotated[StructuredDataset, {}, "new-df-format"]) + lit = fdt.to_literal( + ctx, + sd_with_uri, + python_type=StructuredDataset, + expected=lt) assert lit.scalar.structured_dataset.uri == "s3://some/extant/df.parquet" assert lit.scalar.structured_dataset.metadata.structured_dataset_type.format == "new-df-format" @@ -247,7 +267,8 @@ def encode( return literals.StructuredDataset(uri="") default_encoder = TempEncoder("myavro") - StructuredDatasetTransformerEngine.register(default_encoder, default_for_type=True) + StructuredDatasetTransformerEngine.register( + default_encoder, default_for_type=True) lt = TypeEngine.to_literal_type(MyDF) assert lt.structured_dataset_type.format == "" @@ -255,14 +276,18 @@ def encode( fdt = StructuredDatasetTransformerEngine() sd = StructuredDataset(dataframe=MyDF()) l = fdt.to_literal(ctx, sd, MyDF, lt) - # Test that the literal type is filled in even though the encode function above doesn't do it. + # Test that the literal type is filled in even though the encode function + # above doesn't do it. assert l.scalar.structured_dataset.metadata.structured_dataset_type.format == "myavro" - # Test that looking up encoders/decoders falls back to the "" encoder/decoder + # Test that looking up encoders/decoders falls back to the "" + # encoder/decoder empty_format_temp_encoder = TempEncoder("") - StructuredDatasetTransformerEngine.register(empty_format_temp_encoder, default_for_type=False) + StructuredDatasetTransformerEngine.register( + empty_format_temp_encoder, default_for_type=False) - res = StructuredDatasetTransformerEngine.get_encoder(MyDF, "tmpfs", "rando") + res = StructuredDatasetTransformerEngine.get_encoder( + MyDF, "tmpfs", "rando") assert res is empty_format_temp_encoder @@ -283,7 +308,8 @@ def encode( StructuredDatasetTransformerEngine.register(TempEncoder("/")) res = StructuredDatasetTransformerEngine.get_encoder(MyDF, "file", "/") # Test that the one we got was registered under fsspec - assert res is StructuredDatasetTransformerEngine.ENCODERS[MyDF].get("fsspec")["/"] + assert res is StructuredDatasetTransformerEngine.ENCODERS[MyDF].get("fsspec")[ + "/"] assert res is not None @@ -342,12 +368,18 @@ def decode( def test_convert_schema_type_to_structured_dataset_type(): schema_ct = SchemaType.SchemaColumn.SchemaColumnType - assert convert_schema_type_to_structured_dataset_type(schema_ct.INTEGER) == SimpleType.INTEGER - assert convert_schema_type_to_structured_dataset_type(schema_ct.FLOAT) == SimpleType.FLOAT - assert convert_schema_type_to_structured_dataset_type(schema_ct.STRING) == SimpleType.STRING - assert convert_schema_type_to_structured_dataset_type(schema_ct.DATETIME) == SimpleType.DATETIME - assert convert_schema_type_to_structured_dataset_type(schema_ct.DURATION) == SimpleType.DURATION - assert convert_schema_type_to_structured_dataset_type(schema_ct.BOOLEAN) == SimpleType.BOOLEAN + assert convert_schema_type_to_structured_dataset_type( + schema_ct.INTEGER) == SimpleType.INTEGER + assert convert_schema_type_to_structured_dataset_type( + schema_ct.FLOAT) == SimpleType.FLOAT + assert convert_schema_type_to_structured_dataset_type( + schema_ct.STRING) == SimpleType.STRING + assert convert_schema_type_to_structured_dataset_type( + schema_ct.DATETIME) == SimpleType.DATETIME + assert convert_schema_type_to_structured_dataset_type( + schema_ct.DURATION) == SimpleType.DURATION + assert convert_schema_type_to_structured_dataset_type( + schema_ct.BOOLEAN) == SimpleType.BOOLEAN with pytest.raises(AssertionError, match="Unrecognized SchemaColumnType"): convert_schema_type_to_structured_dataset_type(int) @@ -363,7 +395,8 @@ def test_to_python_value_with_incoming_columns(): df = generate_pandas() fdt = StructuredDatasetTransformerEngine() lit = fdt.to_literal(ctx, df, python_type=original_type, expected=lt) - assert len(lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 2 + assert len( + lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 2 # declare a new type that only has one column # get the dataframe, make sure it has the column that was asked for. @@ -373,7 +406,8 @@ def test_to_python_value_with_incoming_columns(): sub_df = sd.open(pd.DataFrame).all() assert sub_df.shape[1] == 1 - # check when columns are not specified, should pull both and add column information. + # check when columns are not specified, should pull both and add column + # information. sd = fdt.to_python_value(ctx, lit, StructuredDataset) assert len(sd.metadata.structured_dataset_type.columns) == 2 @@ -390,7 +424,8 @@ def test_to_python_value_without_incoming_columns(): df = generate_pandas() fdt = StructuredDatasetTransformerEngine() lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt) - assert len(lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 0 + assert len( + lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 0 # declare a new type that only has one column # get the dataframe, make sure it has the column that was asked for. @@ -402,7 +437,8 @@ def test_to_python_value_without_incoming_columns(): # check when columns are not specified, should pull both and add column information. # todo: see the todos in the open_as, and iter_as functions in StructuredDatasetTransformerEngine - # we have to recreate the literal because the test case above filled in the metadata + # we have to recreate the literal because the test case above filled in + # the metadata lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt) sd = fdt.to_python_value(ctx, lit, StructuredDataset) assert sd.metadata.structured_dataset_type.columns == [] @@ -434,7 +470,8 @@ def encode( ctx = FlyteContextManager.current_context() df = pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]}) - annotated_sd_type = Annotated[StructuredDataset, "avro", kwtypes(name=str, age=int)] + annotated_sd_type = Annotated[StructuredDataset, + "avro", kwtypes(name=str, age=int)] df_literal_type = TypeEngine.to_literal_type(annotated_sd_type) assert df_literal_type.structured_dataset_type is not None assert len(df_literal_type.structured_dataset_type.columns) == 2 @@ -446,11 +483,20 @@ def encode( sd = annotated_sd_type(df) with pytest.raises(ValueError, match="Failed to find a handler"): - TypeEngine.to_literal(ctx, sd, python_type=annotated_sd_type, expected=df_literal_type) + TypeEngine.to_literal( + ctx, + sd, + python_type=annotated_sd_type, + expected=df_literal_type) - StructuredDatasetTransformerEngine.register(TempEncoder(), default_for_type=False) + StructuredDatasetTransformerEngine.register( + TempEncoder(), default_for_type=False) sd2 = annotated_sd_type(df) - sd_literal = TypeEngine.to_literal(ctx, sd2, python_type=annotated_sd_type, expected=df_literal_type) + sd_literal = TypeEngine.to_literal( + ctx, + sd2, + python_type=annotated_sd_type, + expected=df_literal_type) assert sd_literal.scalar.structured_dataset.metadata.structured_dataset_type.format == "avro" @task @@ -469,12 +515,15 @@ def test_protocol_detection(): assert protocol == "file" with tempfile.TemporaryDirectory() as tmp_dir: - fs = FileAccessProvider(local_sandbox_dir=tmp_dir, raw_output_prefix="s3://fdsa") + fs = FileAccessProvider( + local_sandbox_dir=tmp_dir, + raw_output_prefix="s3://fdsa") ctx2 = ctx.with_file_access(fs).build() protocol = e._protocol_from_type_or_prefix(ctx2, pd.DataFrame) assert protocol == "s3" - protocol = e._protocol_from_type_or_prefix(ctx2, pd.DataFrame, "bq://foo") + protocol = e._protocol_from_type_or_prefix( + ctx2, pd.DataFrame, "bq://foo") assert protocol == "bq" @@ -490,7 +539,8 @@ def to_html(self, input: str) -> str: assert pa.Table in renderers with pytest.raises(NotImplementedError, match="Could not find a renderer for in"): - StructuredDatasetTransformerEngine().to_html(FlyteContextManager.current_context(), 3, int) + StructuredDatasetTransformerEngine().to_html( + FlyteContextManager.current_context(), 3, int) def test_list_of_annotated(): @@ -524,7 +574,8 @@ def encode( def test_reregister_encoder(): # Test that lazy import can run after a user has already registered a custom handler. - # The default handlers don't have override=True (and should not) but the call should not fail. + # The default handlers don't have override=True (and should not) but the + # call should not fail. dir(google.cloud.bigquery) assert is_imported("google.cloud.bigquery") @@ -533,13 +584,18 @@ def test_reregister_encoder(): ) TypeEngine.lazy_import_transformers() - sd = StructuredDataset(dataframe=pd.DataFrame({"a": [1, 2], "b": [3, 4]}), uri="bq://blah", file_format="bq") + sd = StructuredDataset(dataframe=pd.DataFrame( + {"a": [1, 2], "b": [3, 4]}), uri="bq://blah", file_format="bq") ctx = FlyteContextManager.current_context() df_literal_type = TypeEngine.to_literal_type(pd.DataFrame) - TypeEngine.to_literal(ctx, sd, python_type=pd.DataFrame, expected=df_literal_type) + TypeEngine.to_literal( + ctx, + sd, + python_type=pd.DataFrame, + expected=df_literal_type) def test_default_args_task(): @@ -558,8 +614,10 @@ def wf_no_input() -> pd.DataFrame: def wf_with_input() -> pd.DataFrame: return t1(a=input_val) - wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) - wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + wf_no_input_spec = get_serializable( + OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable( + OrderedDict(), serialization_settings, wf_with_input) assert wf_no_input_spec.template.nodes[0].inputs[ 0