diff --git a/Dockerfile.dev b/Dockerfile.dev index 760648d110..652867c529 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -38,7 +38,6 @@ RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION \ uv pip install --system --no-cache-dir -U \ "git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl" \ -e /flytekit \ - -e /flytekit/plugins/flytekit-k8s-pod \ -e /flytekit/plugins/flytekit-deck-standard \ -e /flytekit/plugins/flytekit-flyteinteractive \ scikit-learn \ @@ -50,5 +49,7 @@ RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION \ && chown flytekit: /home \ && : +ENV PYTHONPATH="/flytekit:" + # Switch to the 'flytekit' user for better security. USER flytekit diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 0552197c0f..4e6286204c 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -63,8 +63,16 @@ def __init__( actual_task = python_function_task # TODO: add support for other Flyte entities - if not (isinstance(actual_task, PythonFunctionTask) or isinstance(actual_task, PythonInstanceTask)): - raise ValueError("Only PythonFunctionTask and PythonInstanceTask are supported in map tasks.") + if not ( + ( + isinstance(actual_task, PythonFunctionTask) + and actual_task.execution_mode == PythonFunctionTask.ExecutionBehavior.DEFAULT + ) + or isinstance(actual_task, PythonInstanceTask) + ): + raise ValueError( + "Only PythonFunctionTask with default execution mode (not @dynamic or @eager) and PythonInstanceTask are supported in map tasks." + ) for k, v in actual_task.python_interface.inputs.items(): if bound_inputs and k in bound_inputs: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 1ce6a05488..6656c0c293 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -360,6 +360,7 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): expected_type = get_underlying_type(expected_type) expected_fields_dict = {} + for f in dataclasses.fields(expected_type): expected_fields_dict[f.name] = f.type @@ -539,11 +540,13 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: field.type = self._get_origin_type_in_annotation(field.type) return python_type - def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T: + def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T | None: # In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict, # so here we convert it back to the Structured Dataset. from flytekit.types.structured import StructuredDataset + if python_val is None: + return python_val if python_type == StructuredDataset and type(python_val) == dict: return StructuredDataset(**python_val) elif get_origin(python_type) is list: @@ -575,9 +578,13 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t return self._make_dataclass_serializable(python_val, get_args(python_type)[0]) if hasattr(python_type, "__origin__") and get_origin(python_type) is list: + if python_val is None: + return None return [self._make_dataclass_serializable(v, get_args(python_type)[0]) for v in cast(list, python_val)] if hasattr(python_type, "__origin__") and get_origin(python_type) is dict: + if python_val is None: + return None return { k: self._make_dataclass_serializable(v, get_args(python_type)[1]) for k, v in cast(dict, python_val).items() diff --git a/flytekit/image_spec/default_builder.py b/flytekit/image_spec/default_builder.py index 32f20d6373..09b874693e 100644 --- a/flytekit/image_spec/default_builder.py +++ b/flytekit/image_spec/default_builder.py @@ -61,7 +61,6 @@ # Configure user space ENV PATH="/opt/micromamba/envs/runtime/bin:$$PATH" \ UV_LINK_MODE=copy \ - UV_PRERELEASE=allow \ FLYTE_SDK_RICH_TRACEBACKS=0 \ SSL_CERT_DIR=/etc/ssl/certs \ $ENV @@ -245,6 +244,7 @@ class DefaultImageBuilder(ImageSpecBuilder): "cudnn", "base_image", "pip_index", + "pip_extra_index_url", # "registry_config", "commands", } diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index e750cc211e..7e2c3acf32 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -3,6 +3,7 @@ import hashlib import os import pathlib +import re import typing from abc import abstractmethod from dataclasses import asdict, dataclass @@ -143,6 +144,10 @@ def exist(self) -> Optional[bool]: if e.response.status_code == 404: return False + if re.match(f"unknown: repository .*{self.name} not found", e.explanation): + click.secho(f"Received 500 error with explanation: {e.explanation}", fg="yellow") + return False + click.secho(f"Failed to check if the image exists with error:\n {e}", fg="red") return None except ImageNotFound: diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 101ecea3d1..04a1848f84 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -1,3 +1,4 @@ +import dataclasses import datetime import enum import json @@ -5,7 +6,7 @@ import os import pathlib import typing -from typing import cast +from typing import cast, get_args import rich_click as click import yaml @@ -22,6 +23,7 @@ from flytekit.types.file import FlyteFile from flytekit.types.iterator.json_iterator import JSONIteratorTransformer from flytekit.types.pickle.pickle import FlytePickleTransformer +from flytekit.types.schema.types import FlyteSchema def is_pydantic_basemodel(python_type: typing.Type) -> bool: @@ -305,11 +307,50 @@ def convert( if value is None: raise click.BadParameter("None value cannot be converted to a Json type.") + FLYTE_TYPES = [FlyteFile, FlyteDirectory, StructuredDataset, FlyteSchema] + + def has_nested_dataclass(t: typing.Type) -> bool: + """ + Recursively checks whether the given type or its nested types contain any dataclass. + + This function is typically called with a dictionary or list type and will return True if + any of the nested types within the dictionary or list is a dataclass. + + Note: + - A single dataclass will return True. + - The function specifically excludes certain Flyte types like FlyteFile, FlyteDirectory, + StructuredDataset, and FlyteSchema from being considered as dataclasses. This is because + these types are handled separately by Flyte and do not need to be converted to dataclasses. + + Args: + t (typing.Type): The type to check for nested dataclasses. + + Returns: + bool: True if the type or its nested types contain a dataclass, False otherwise. + """ + + if dataclasses.is_dataclass(t): + # FlyteTypes is not supported now, we can support it in the future. + return t not in FLYTE_TYPES + + return any(has_nested_dataclass(arg) for arg in get_args(t)) + parsed_value = self._parse(value, param) # We compare the origin type because the json parsed value for list or dict is always a list or dict without # the covariant type information. if type(parsed_value) == typing.get_origin(self._python_type) or type(parsed_value) == self._python_type: + # Indexing the return value of get_args will raise an error for native dict and list types. + # We don't support native list/dict types with nested dataclasses. + if get_args(self._python_type) == (): + return parsed_value + elif isinstance(parsed_value, list) and has_nested_dataclass(get_args(self._python_type)[0]): + j = JsonParamType(get_args(self._python_type)[0]) + return [j.convert(v, param, ctx) for v in parsed_value] + elif isinstance(parsed_value, dict) and has_nested_dataclass(get_args(self._python_type)[1]): + j = JsonParamType(get_args(self._python_type)[1]) + return {k: j.convert(v, param, ctx) for k, v in parsed_value.items()} + return parsed_value if is_pydantic_basemodel(self._python_type): diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index dd0d50b8af..7cbaaa46ca 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -1144,8 +1144,9 @@ def _execute( """ if execution_name is not None and execution_name_prefix is not None: raise ValueError("Only one of execution_name and execution_name_prefix can be set, but got both set") - execution_name_prefix = execution_name_prefix + "-" if execution_name_prefix is not None else None - execution_name = execution_name or (execution_name_prefix or "f") + uuid.uuid4().hex[:19] + # todo: The prefix should be passed to the backend + if execution_name_prefix is not None: + execution_name = execution_name_prefix + "-" + uuid.uuid4().hex[:19] if not options: options = Options() if options.disable_notifications is not None: diff --git a/plugins/flytekit-k8s-pod/README.md b/plugins/flytekit-k8s-pod/README.md index 0c09d96c7d..8b25278124 100644 --- a/plugins/flytekit-k8s-pod/README.md +++ b/plugins/flytekit-k8s-pod/README.md @@ -1,5 +1,11 @@ # Flytekit Kubernetes Pod Plugin +> [!IMPORTANT] +> This plugin is no longer needed and is here only for backwards compatibility. No new versions will be published after v1.13.x +> Please use the `pod_template` and `pod_template_name` args to `@task` as described in https://docs.flyte.org/en/latest/deployment/configuration/general.html#configuring-task-pods-with-k8s-podtemplates +> instead. + + By default, Flyte tasks decorated with `@task` are essentially single functions that are loaded in one container. But often, there is a need to run a job with more than one container. In this case, a regular task is not enough. Hence, Flyte provides a Kubernetes pod abstraction to execute multiple containers, which can be accomplished using Pod's `task_config`. The `task_config` can be leveraged to fully customize the pod spec used to run the task. diff --git a/plugins/flytekit-k8s-pod/flytekitplugins/pod/__init__.py b/plugins/flytekit-k8s-pod/flytekitplugins/pod/__init__.py index 3e68602354..50dd9b5617 100644 --- a/plugins/flytekit-k8s-pod/flytekitplugins/pod/__init__.py +++ b/plugins/flytekit-k8s-pod/flytekitplugins/pod/__init__.py @@ -1,3 +1,7 @@ +import warnings + +from .task import Pod + """ .. currentmodule:: flytekitplugins.pod @@ -10,4 +14,8 @@ Pod """ -from .task import Pod +warnings.warn( + "This pod plugin is no longer necessary, please use the pod_template and pod_template_name options to @task as described " + "in https://docs.flyte.org/en/latest/deployment/configuration/general.html#configuring-task-pods-with-k8s-podtemplates", + FutureWarning, +) diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index f220517849..bbe3e842b3 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -26,8 +26,7 @@ class PolarsDataFrameRenderer: def to_html(self, df: pl.DataFrame) -> str: assert isinstance(df, pl.DataFrame) - describe_df = df.describe() - return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False) + return df.describe().to_pandas().to_html(index=False) class PolarsDataFrameToParquetEncodingHandler(StructuredDatasetEncoder): diff --git a/plugins/flytekit-polars/setup.py b/plugins/flytekit-polars/setup.py index 483c3d18a4..d1a2372eff 100644 --- a/plugins/flytekit-polars/setup.py +++ b/plugins/flytekit-polars/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "polars>=0.8.27,<0.17.0", "pandas"] +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "polars>=0.8.27", "pandas"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py index eecfeb8d78..1283438a93 100644 --- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -4,6 +4,8 @@ import polars as pl from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer from typing_extensions import Annotated +from packaging import version +from polars.testing import assert_frame_equal from flytekit import kwtypes, task, workflow from flytekit.types.structured.structured_dataset import PARQUET, StructuredDataset @@ -11,6 +13,8 @@ subset_schema = Annotated[StructuredDataset, kwtypes(col2=str), PARQUET] full_schema = Annotated[StructuredDataset, PARQUET] +polars_version = pl.__version__ + def test_polars_workflow_subset(): @task @@ -65,9 +69,9 @@ def wf() -> full_schema: def test_polars_renderer(): df = pl.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) - assert PolarsDataFrameRenderer().to_html(df) == pd.DataFrame( - df.describe().transpose(), columns=df.describe().columns - ).to_html(index=False) + assert PolarsDataFrameRenderer().to_html(df) == df.describe().to_pandas().to_html( + index=False + ) def test_parquet_to_polars(): @@ -80,7 +84,7 @@ def create_sd() -> StructuredDataset: sd = create_sd() polars_df = sd.open(pl.DataFrame).all() - assert pl.DataFrame(data).frame_equal(polars_df) + assert_frame_equal(polars_df, pl.DataFrame(data)) tmp = tempfile.mktemp() pl.DataFrame(data).write_parquet(tmp) @@ -90,11 +94,11 @@ def t1(sd: StructuredDataset) -> pl.DataFrame: return sd.open(pl.DataFrame).all() sd = StructuredDataset(uri=tmp) - assert t1(sd=sd).frame_equal(polars_df) + assert_frame_equal(t1(sd=sd), polars_df) @task def t2(sd: StructuredDataset) -> StructuredDataset: return StructuredDataset(dataframe=sd.open(pl.DataFrame).all()) sd = StructuredDataset(uri=tmp) - assert t2(sd=sd).open(pl.DataFrame).all().frame_equal(polars_df) + assert_frame_equal(t2(sd=sd).open(pl.DataFrame).all(), polars_df) diff --git a/tests/flytekit/unit/cli/pyflyte/my_wf_input.json b/tests/flytekit/unit/cli/pyflyte/my_wf_input.json index c20081f3b2..4c596e4d55 100644 --- a/tests/flytekit/unit/cli/pyflyte/my_wf_input.json +++ b/tests/flytekit/unit/cli/pyflyte/my_wf_input.json @@ -42,6 +42,9 @@ }, "p": "None", "q": "tests/flytekit/unit/cli/pyflyte/testdata", + "r": [{"i": 1, "a": ["h", "e"]}], + "s": {"x": {"i": 1, "a": ["h", "e"]}}, + "t": {"i": [{"i":1,"a":["h","e"]}]}, "remote": "tests/flytekit/unit/cli/pyflyte/testdata", "image": "tests/flytekit/unit/cli/pyflyte/testdata" } diff --git a/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml b/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml index 678f5331c8..5f15826b80 100644 --- a/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml +++ b/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml @@ -30,5 +30,22 @@ o: - tests/flytekit/unit/cli/pyflyte/testdata/df.parquet p: 'None' q: tests/flytekit/unit/cli/pyflyte/testdata +r: + - i: 1 + a: + - h + - e +s: + x: + i: 1 + a: + - h + - e +t: + i: + - i: 1 + a: + - h + - e remote: tests/flytekit/unit/cli/pyflyte/testdata image: tests/flytekit/unit/cli/pyflyte/testdata diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index 66967393fb..ec14aa8227 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -48,8 +48,10 @@ def reset_flytectl_config_env_var() -> pytest.fixture(): return os.environ[FLYTECTL_CONFIG_ENV_VAR] +@mock.patch("flytekit.configuration.plugin.get_config_file") @mock.patch("flytekit.configuration.plugin.FlyteRemote") -def test_get_remote(mock_remote, reset_flytectl_config_env_var): +def test_get_remote(mock_remote, mock_config_file, reset_flytectl_config_env_var): + mock_config_file.return_value = None r = FlytekitPlugin.get_remote(None, "p", "d") assert r is not None mock_remote.assert_called_once_with( @@ -57,8 +59,10 @@ def test_get_remote(mock_remote, reset_flytectl_config_env_var): ) +@mock.patch("flytekit.configuration.plugin.get_config_file") @mock.patch("flytekit.configuration.plugin.FlyteRemote") -def test_saving_remote(mock_remote): +def test_saving_remote(mock_remote, mock_config_file): + mock_config_file.return_value = None mock_context = mock.MagicMock mock_context.obj = {} get_and_save_remote_with_click_context(mock_context, "p", "d") diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 475fb42ff1..58c4518f3d 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -201,6 +201,12 @@ def test_pyflyte_run_cli(workflow_file): "Any", "--q", DIR_NAME, + "--r", + json.dumps([{"i": 1, "a": ["h", "e"]}]), + "--s", + json.dumps({"x": {"i": 1, "a": ["h", "e"]}}), + "--t", + json.dumps({"i": [{"i":1,"a":["h","e"]}]}), ], catch_exceptions=False, ) diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index accebf82df..104538c338 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -35,6 +35,9 @@ class MyDataclass(DataClassJsonMixin): i: int a: typing.List[str] +@dataclass +class NestedDataclass(DataClassJsonMixin): + i: typing.List[MyDataclass] class Color(enum.Enum): RED = "RED" @@ -61,8 +64,11 @@ def print_all( o: typing.Dict[str, typing.List[FlyteFile]], p: typing.Any, q: FlyteDirectory, + r: typing.List[MyDataclass], + s: typing.Dict[str, MyDataclass], + t: NestedDataclass, ): - print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}, {p}, {q}") + print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}, {p}, {q}, {r}, {s}, {t}") @task @@ -93,6 +99,9 @@ def my_wf( o: typing.Dict[str, typing.List[FlyteFile]], p: typing.Any, q: FlyteDirectory, + r: typing.List[MyDataclass], + s: typing.Dict[str, MyDataclass], + t: NestedDataclass, remote: pd.DataFrame, image: StructuredDataset, m: dict = {"hello": "world"}, @@ -100,7 +109,7 @@ def my_wf( x = get_subset_df(df=remote) # noqa: shown for demonstration; users should use the same types between tasks show_sd(in_sd=x) show_sd(in_sd=image) - print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o, p=p, q=q) + print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o, p=p, q=q, r=r, s=s, t=t) return x diff --git a/tests/flytekit/unit/configuration/test_file.py b/tests/flytekit/unit/configuration/test_file.py index 3ce03f9c50..42f66d5ff5 100644 --- a/tests/flytekit/unit/configuration/test_file.py +++ b/tests/flytekit/unit/configuration/test_file.py @@ -4,10 +4,11 @@ import mock import pytest +from pathlib import Path from pytimeparse.timeparse import timeparse from flytekit.configuration import ConfigEntry, get_config_file, set_if_exists -from flytekit.configuration.file import LegacyConfigEntry, _exists +from flytekit.configuration.file import LegacyConfigEntry, _exists, FLYTECTL_CONFIG_ENV_VAR, FLYTECTL_CONFIG_ENV_VAR_OVERRIDE from flytekit.configuration.internal import Platform @@ -42,8 +43,23 @@ def test_exists(data, expected): def test_get_config_file(): + def all_path_not_exists(paths): + for path in paths: + if path.exists(): + return False + return True + + paths = [ + Path("flytekit.config"), + Path(Path.home(), ".flyte", "config"), + Path(Path.home(), ".flyte", "config.yaml") + ] + config_file = os.getenv(FLYTECTL_CONFIG_ENV_VAR_OVERRIDE, os.getenv(FLYTECTL_CONFIG_ENV_VAR)) + if config_file: + paths.append(Path(config_file)) + c = get_config_file(None) - assert c is None + assert (c is None) == all_path_not_exists(paths) c = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) assert c is not None assert c.legacy_config is not None diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index 032c6e58f1..74f1868eb4 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -6,13 +6,14 @@ import pytest -from flytekit import map_task, task, workflow +from flytekit import dynamic, map_task, task, workflow from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver from flytekit.core.task import TaskMetadata from flytekit.core.type_engine import TypeEngine from flytekit.extras.accelerators import GPUAccelerator +from flytekit.experimental.eager_function import eager from flytekit.tools.translator import get_serializable from flytekit.types.pickle import BatchSize @@ -403,3 +404,34 @@ def wf(x: typing.List[int]): task_spec = od[arraynode_maptask] assert task_spec.template.extended_resources.gpu_accelerator.device == "test_gpu" + + +def test_supported_node_type(): + @task + def test_task(): + ... + + map_task(test_task) + + +def test_unsupported_node_types(): + @dynamic + def test_dynamic(): + ... + + with pytest.raises(ValueError): + map_task(test_dynamic) + + @eager + def test_eager(): + ... + + with pytest.raises(ValueError): + map_task(test_eager) + + @workflow + def test_wf(): + ... + + with pytest.raises(ValueError): + map_task(test_wf) diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py index a9ccfe61b3..11cfb374d8 100644 --- a/tests/flytekit/unit/interaction/test_click_types.py +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -1,3 +1,4 @@ +from dataclasses import field import json import tempfile import typing @@ -270,3 +271,230 @@ class Datum: assert v.y == "2" assert v.z == {1: "one", 2: "two"} assert v.w == [1, 2, 3] + + +def test_nested_dataclass_type(): + from dataclasses import dataclass + + @dataclass + class Datum: + w: int + x: str = "default" + y: typing.Dict[str, str] = field(default_factory=lambda: {"key": "value"}) + z: typing.List[int] = field(default_factory=lambda: [1, 2, 3]) + + @dataclass + class NestedDatum: + w: Datum + x: typing.List[Datum] + y: typing.Dict[str, Datum] = field(default_factory=lambda: {"key": Datum(1)}) + + + # typing.List[Datum] + value = '[{ "w": 1 }]' + t = JsonParamType(typing.List[Datum]) + v = t.convert(value=value, param=None, ctx=None) + + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(typing.List[Datum]) + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=typing.List[Datum], is_remote=False + ) + v = literal_converter.convert(ctx, None, v) + + assert v[0].w == 1 + assert v[0].x == "default" + assert v[0].y == {"key": "value"} + assert v[0].z == [1, 2, 3] + + # typing.Dict[str, Datum] + value = '{ "x": { "w": 1 } }' + t = JsonParamType(typing.Dict[str, Datum]) + v = t.convert(value=value, param=None, ctx=None) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(typing.Dict[str, Datum]) + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=typing.Dict[str, Datum], is_remote=False + ) + v = literal_converter.convert(ctx, None, v) + + assert v["x"].w == 1 + assert v["x"].x == "default" + assert v["x"].y == {"key": "value"} + assert v["x"].z == [1, 2, 3] + + # typing.List[NestedDatum] + value = '[{"w":{ "w" : 1 },"x":[{ "w" : 1 }]}]' + t = JsonParamType(typing.List[NestedDatum]) + v = t.convert(value=value, param=None, ctx=None) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(typing.List[NestedDatum]) + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=typing.List[NestedDatum], is_remote=False + ) + v = literal_converter.convert(ctx, None, v) + + assert v[0].w.w == 1 + assert v[0].w.x == "default" + assert v[0].w.y == {"key": "value"} + assert v[0].w.z == [1, 2, 3] + assert v[0].x[0].w == 1 + assert v[0].x[0].x == "default" + assert v[0].x[0].y == {"key": "value"} + assert v[0].x[0].z == [1, 2, 3] + + # typing.List[typing.List[Datum]] + value = '[[{ "w": 1 }]]' + t = JsonParamType(typing.List[typing.List[Datum]]) + v = t.convert(value=value, param=None, ctx=None) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(typing.List[typing.List[Datum]]) + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=typing.List[typing.List[Datum]], is_remote=False + ) + v = literal_converter.convert(ctx, None, v) + + assert v[0][0].w == 1 + assert v[0][0].x == "default" + assert v[0][0].y == {"key": "value"} + assert v[0][0].z == [1, 2, 3] + +def test_dataclass_with_default_none(): + from dataclasses import dataclass + + @dataclass + class Datum: + x: int + y: str = None + z: typing.Dict[int, str] = None + w: typing.List[int] = None + + t = JsonParamType(Datum) + value = '{ "x": 1 }' + v = t.convert(value=value, param=None, ctx=None) + lt = TypeEngine.to_literal_type(Datum) + ctx = FlyteContextManager.current_context() + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=Datum, is_remote=False + ) + v = literal_converter.convert(ctx=ctx, param=None, value=v) + + assert v.x == 1 + assert v.y is None + assert v.z is None + assert v.w is None + + +def test_dataclass_with_flyte_type_exception(): + from dataclasses import dataclass + from flytekit import StructuredDataset + from flytekit.types.directory import FlyteDirectory + from flytekit.types.file import FlyteFile + import os + + DIR_NAME = os.path.dirname(os.path.realpath(__file__)) + parquet_file = os.path.join(DIR_NAME, "testdata/df.parquet") + + @dataclass + class Datum: + x: FlyteFile + y: FlyteDirectory + z: StructuredDataset + + t = JsonParamType(Datum) + value = { "x": parquet_file, "y": DIR_NAME, "z": os.path.join(DIR_NAME, "testdata")} + + with pytest.raises(AttributeError): + t.convert(value=value, param=None, ctx=None) + +def test_dataclass_with_optional_fields(): + from dataclasses import dataclass + from typing import Optional + + @dataclass + class Datum: + x: int + y: Optional[str] = None + z: Optional[typing.Dict[int, str]] = None + w: Optional[typing.List[int]] = None + + t = JsonParamType(Datum) + value = '{ "x": 1 }' + v = t.convert(value=value, param=None, ctx=None) + lt = TypeEngine.to_literal_type(Datum) + ctx = FlyteContextManager.current_context() + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=Datum, is_remote=False + ) + v = literal_converter.convert(ctx=ctx, param=None, value=v) + + # Assertions to check the Optional fields + assert v.x == 1 + assert v.y is None # Optional field with no value provided + assert v.z is None # Optional field with no value provided + assert v.w is None # Optional field with no value provided + + # Test with all fields provided + value = '{ "x": 2, "y": "test", "z": {"1": "value"}, "w": [1, 2, 3] }' + v = t.convert(value=value, param=None, ctx=None) + v = literal_converter.convert(ctx=ctx, param=None, value=v) + + assert v.x == 2 + assert v.y == "test" + assert v.z == {1: "value"} + assert v.w == [1, 2, 3] + +def test_nested_dataclass_with_optional_fields(): + from dataclasses import dataclass + from typing import Optional, List, Dict + + @dataclass + class InnerDatum: + a: int + b: Optional[str] = None + + @dataclass + class Datum: + x: int + y: Optional[InnerDatum] = None + z: Optional[Dict[str, InnerDatum]] = None + w: Optional[List[InnerDatum]] = None + + t = JsonParamType(Datum) + + # Case 1: Only required field provided + value = '{ "x": 1 }' + v = t.convert(value=value, param=None, ctx=None) + lt = TypeEngine.to_literal_type(Datum) + ctx = FlyteContextManager.current_context() + literal_converter = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=Datum, is_remote=False + ) + v = literal_converter.convert(ctx=ctx, param=None, value=v) + + # Assertions to check the Optional fields + assert v.x == 1 + assert v.y is None # Optional field with no value provided + assert v.z is None # Optional field with no value provided + assert v.w is None # Optional field with no value provided + + # Case 2: All fields provided with nested structures + value = ''' + { + "x": 2, + "y": {"a": 10, "b": "inner"}, + "z": {"key": {"a": 20, "b": "nested"}}, + "w": [{"a": 30, "b": "list_item"}] + } + ''' + v = t.convert(value=value, param=None, ctx=None) + v = literal_converter.convert(ctx=ctx, param=None, value=v) + + # Assertions for nested structure + assert v.x == 2 + assert v.y.a == 10 + assert v.y.b == "inner" + assert v.z["key"].a == 20 + assert v.z["key"].b == "nested" + assert v.w[0].a == 30 + assert v.w[0].b == "list_item" diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 3852da9a31..81e70e0a21 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -598,7 +598,7 @@ def test_execution_name(mock_client, mock_uuid): [ mock.call(ANY, ANY, "execution-test", ANY, ANY), mock.call(ANY, ANY, "execution-test-" + test_uuid.hex[:19], ANY, ANY), - mock.call(ANY, ANY, "f" + test_uuid.hex[:19], ANY, ANY), + mock.call(ANY, ANY, None, ANY, ANY), ] ) with pytest.raises(