From d82f2cf5def2a9ec2316549bcb5199be3b3127bc Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Mon, 29 Jul 2024 15:03:48 +0800 Subject: [PATCH] [FlyteSchema] Fix numpy problems (#2619) Signed-off-by: Future-Outlier --- flytekit/core/type_engine.py | 5 ++- flytekit/interaction/click_types.py | 6 ++- flytekit/types/schema/types.py | 41 ++++++++++++------- .../flytekit-envd/tests/test_image_spec.py | 4 +- tests/flytekit/unit/core/test_dataclass.py | 31 ++++++++++++++ .../unit/interaction/test_click_types.py | 3 ++ 6 files changed, 69 insertions(+), 21 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 3165b4cdf5..5b0eb62c65 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1495,6 +1495,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp is_ambiguous = False res = None res_type = None + t = None for i in range(len(get_args(python_type))): try: t = get_args(python_type)[i] @@ -1504,8 +1505,8 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp if found_res: is_ambiguous = True found_res = True - except Exception: - logger.debug(f"Failed to convert from {python_val} to {t}", exc_info=True) + except Exception as e: + logger.debug(f"Failed to convert from {python_val} to {t} with error: {e}", exc_info=True) continue if is_ambiguous: diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 491d2dba3f..101ecea3d1 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -175,7 +175,7 @@ def convert( if isinstance(value, ArtifactQuery): return value - if " " in value: + if isinstance(value, str) and " " in value: import re m = re.match(self._FLOATING_FORMAT_PATTERN, value) @@ -193,7 +193,9 @@ def convert( if parts[1] == "-": return dt - delta return dt + delta - raise click.BadParameter(f"Expected format {self.formats}, got {value}") + else: + value = datetime.datetime.fromisoformat(value) + return self._datetime_from_format(value, param, ctx) diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 88adad2681..2cf0127d4c 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -9,7 +9,6 @@ from pathlib import Path from typing import Type -import numpy as _np from dataclasses_json import config from marshmallow import fields from mashumaro.mixins.json import DataClassJSONMixin @@ -349,27 +348,39 @@ def as_readonly(self) -> FlyteSchema: return s +def _get_numpy_type_mappings() -> typing.Dict[Type, SchemaType.SchemaColumn.SchemaColumnType]: + try: + import numpy as _np + + return { + _np.int32: SchemaType.SchemaColumn.SchemaColumnType.INTEGER, + _np.int64: SchemaType.SchemaColumn.SchemaColumnType.INTEGER, + _np.uint32: SchemaType.SchemaColumn.SchemaColumnType.INTEGER, + _np.uint64: SchemaType.SchemaColumn.SchemaColumnType.INTEGER, + _np.float32: SchemaType.SchemaColumn.SchemaColumnType.FLOAT, + _np.float64: SchemaType.SchemaColumn.SchemaColumnType.FLOAT, + _np.bool_: SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN, # type: ignore + _np.datetime64: SchemaType.SchemaColumn.SchemaColumnType.DATETIME, + _np.timedelta64: SchemaType.SchemaColumn.SchemaColumnType.DURATION, + _np.bytes_: SchemaType.SchemaColumn.SchemaColumnType.STRING, + _np.str_: SchemaType.SchemaColumn.SchemaColumnType.STRING, + _np.object_: SchemaType.SchemaColumn.SchemaColumnType.STRING, + } + except ImportError as e: + logger.warning("Numpy not found, skipping numpy type mappings, error: %s", e) + return {} + + class FlyteSchemaTransformer(TypeTransformer[FlyteSchema]): _SUPPORTED_TYPES: typing.Dict[Type, SchemaType.SchemaColumn.SchemaColumnType] = { - _np.int32: SchemaType.SchemaColumn.SchemaColumnType.INTEGER, - _np.int64: SchemaType.SchemaColumn.SchemaColumnType.INTEGER, - _np.uint32: SchemaType.SchemaColumn.SchemaColumnType.INTEGER, - _np.uint64: SchemaType.SchemaColumn.SchemaColumnType.INTEGER, - int: SchemaType.SchemaColumn.SchemaColumnType.INTEGER, - _np.float32: SchemaType.SchemaColumn.SchemaColumnType.FLOAT, - _np.float64: SchemaType.SchemaColumn.SchemaColumnType.FLOAT, float: SchemaType.SchemaColumn.SchemaColumnType.FLOAT, - _np.bool_: SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN, # type: ignore + int: SchemaType.SchemaColumn.SchemaColumnType.INTEGER, bool: SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN, - _np.datetime64: SchemaType.SchemaColumn.SchemaColumnType.DATETIME, + str: SchemaType.SchemaColumn.SchemaColumnType.STRING, datetime.datetime: SchemaType.SchemaColumn.SchemaColumnType.DATETIME, - _np.timedelta64: SchemaType.SchemaColumn.SchemaColumnType.DURATION, datetime.timedelta: SchemaType.SchemaColumn.SchemaColumnType.DURATION, - _np.bytes_: SchemaType.SchemaColumn.SchemaColumnType.STRING, - _np.str_: SchemaType.SchemaColumn.SchemaColumnType.STRING, - _np.object_: SchemaType.SchemaColumn.SchemaColumnType.STRING, - str: SchemaType.SchemaColumn.SchemaColumnType.STRING, } + _SUPPORTED_TYPES.update(_get_numpy_type_mappings()) def __init__(self): super().__init__("FlyteSchema Transformer", FlyteSchema) diff --git a/plugins/flytekit-envd/tests/test_image_spec.py b/plugins/flytekit-envd/tests/test_image_spec.py index 31cd92effe..cbd1eb761d 100644 --- a/plugins/flytekit-envd/tests/test_image_spec.py +++ b/plugins/flytekit-envd/tests/test_image_spec.py @@ -37,7 +37,7 @@ def test_image_spec(): apt_packages=["git"], python_version="3.8", base_image=base_image, - pip_index="https://private-pip-index/simple", + pip_index="https://pypi.python.org/simple", source_root=os.path.dirname(os.path.realpath(__file__)), ) @@ -58,7 +58,7 @@ def build(): install.python_packages(name=["pandas"]) install.apt_packages(name=["git"]) runtime.environ(env={{'PYTHONPATH': '/root:', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root']) - config.pip_index(url="https://private-pip-index/simple") + config.pip_index(url="https://pypi.python.org/simple") install.python(version="3.8") io.copy(source="./", target="/root") """ diff --git a/tests/flytekit/unit/core/test_dataclass.py b/tests/flytekit/unit/core/test_dataclass.py index f07f51f7ae..654fca0a73 100644 --- a/tests/flytekit/unit/core/test_dataclass.py +++ b/tests/flytekit/unit/core/test_dataclass.py @@ -2,6 +2,7 @@ from dataclasses_json import DataClassJsonMixin from mashumaro.mixins.json import DataClassJSONMixin import os +import sys import tempfile from dataclasses import dataclass from typing import Annotated, List, Dict, Optional @@ -882,3 +883,33 @@ class NestedFlyteTypesWithDataClassJson(DataClassJsonMixin): transformer = DataclassTransformer() lt = transformer.get_literal_type(NestedFlyteTypesWithDataClassJson) assert lt.metadata is not None +@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or higher") +def test_numpy_import_issue_from_flyte_schema_in_dataclass(): + from dataclasses import dataclass + + from flytekit import task, workflow + from flytekit.types.directory import FlyteDirectory + from flytekit.types.file import FlyteFile + + @dataclass + class MyDataClass: + output_file: FlyteFile + output_directory: FlyteDirectory + + @task + def my_flyte_workflow(b: bool) -> list[MyDataClass | None]: + if b: + return [MyDataClass(__file__, ".")] + return [None] + + @task + def my_flyte_task(inputs: list[MyDataClass | None]) -> bool: + return inputs and (inputs[0] is not None) # type: ignore + + @workflow + def main_flyte_workflow(b: bool = False) -> bool: + inputs = my_flyte_workflow(b=b) + return my_flyte_task(inputs=inputs) + + assert main_flyte_workflow(b=True) == True + assert main_flyte_workflow(b=False) == False diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py index d03891e75e..861f666952 100644 --- a/tests/flytekit/unit/interaction/test_click_types.py +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -181,6 +181,9 @@ def test_datetime_type(): with pytest.raises(click.BadParameter): t.convert("aaa + 1d", None, None) + fmt_v = "2024-07-29 13:47:07.643004+00:00" + d = t.convert(fmt_v, None, None) + _datetime_helper(t, fmt_v, d) def test_json_type():