From 3e080b5621b120c5f37e0effd6f6786cc5006a22 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 10 Sep 2022 02:58:19 +0800 Subject: [PATCH] Add literal type to union literal (#1144) * Add literal type to union literal Signed-off-by: Kevin Su * fix test Signed-off-by: Kevin Su * Add tests Signed-off-by: Kevin Su * more tests Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- flytekit/clis/sdk_in_container/run.py | 5 +- tests/flytekit/unit/cli/pyflyte/test_run.py | 64 ++++++++++++++++++++- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 935cfc1ad3..d0b890ba7b 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -33,7 +33,7 @@ from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.models import literals from flytekit.models.interface import Variable -from flytekit.models.literals import Blob, BlobMetadata, Primitive +from flytekit.models.literals import Blob, BlobMetadata, Primitive, Union from flytekit.models.types import LiteralType, SimpleType from flytekit.remote.executions import FlyteWorkflowExecution from flytekit.tools import module_loader, script_mode @@ -270,8 +270,7 @@ def convert_to_union( # and then use flyte converter to convert it to literal. python_val = converter._click_type.convert(value, param, ctx) literal = converter.convert_to_literal(ctx, param, python_val) - self._python_type = python_type - return literal + return Literal(scalar=Scalar(union=Union(literal, variant))) except (Exception or AttributeError) as e: logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e) raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}") diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 5bc94592b9..b7ac80cce4 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -1,10 +1,15 @@ +import functools import os import pathlib +import typing +from enum import Enum +import click import mock import pytest from click.testing import CliRunner +from flytekit import FlyteContextManager from flytekit.clis.sdk_in_container import pyflyte from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY @@ -12,11 +17,15 @@ REMOTE_FLAG_KEY, RUN_LEVEL_PARAMS_KEY, FileParamType, + FlyteLiteralConverter, get_entities_in_file, run_command, ) -from flytekit.configuration import Image, ImageConfig +from flytekit.configuration import Config, Image, ImageConfig from flytekit.core.task import task +from flytekit.core.type_engine import TypeEngine +from flytekit.models.types import SimpleType +from flytekit.remote import FlyteRemote WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py") @@ -266,3 +275,56 @@ def test_file_param(): assert l.local r = FileParamType().convert("https://tmp/file", m, m) assert r.local is False + + +class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + +@pytest.mark.parametrize( + "python_type, python_value", + [ + (typing.Union[typing.List[int], str, Color], "flyte"), + (typing.Union[typing.List[int], str, Color], "red"), + (typing.Union[typing.List[int], str, Color], [1, 2, 3]), + (typing.List[int], [1, 2, 3]), + (typing.Dict[str, int], {"flyte": 2}), + ], +) +def test_literal_converter(python_type, python_value): + get_upload_url_fn = functools.partial( + FlyteRemote(Config.auto()).client.get_upload_signed_url, project="p", domain="d" + ) + click_ctx = click.Context(click.Command("test_command"), obj={"remote": True}) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(python_type) + + lc = FlyteLiteralConverter( + click_ctx, ctx, literal_type=lt, python_type=python_type, get_upload_url_fn=get_upload_url_fn + ) + + assert lc.convert(click_ctx, ctx, python_value) == TypeEngine.to_literal(ctx, python_value, python_type, lt) + + +def test_enum_converter(): + pt = typing.Union[str, Color] + + get_upload_url_fn = functools.partial(FlyteRemote(Config.auto()).client.get_upload_signed_url) + click_ctx = click.Context(click.Command("test_command"), obj={"remote": True}) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(pt) + lc = FlyteLiteralConverter(click_ctx, ctx, literal_type=lt, python_type=pt, get_upload_url_fn=get_upload_url_fn) + union_lt = lc.convert(click_ctx, ctx, "red").scalar.union + + assert union_lt.stored_type.simple == SimpleType.STRING + assert union_lt.stored_type.enum_type is None + + pt = typing.Union[Color, str] + lt = TypeEngine.to_literal_type(typing.Union[Color, str]) + lc = FlyteLiteralConverter(click_ctx, ctx, literal_type=lt, python_type=pt, get_upload_url_fn=get_upload_url_fn) + union_lt = lc.convert(click_ctx, ctx, "red").scalar.union + + assert union_lt.stored_type.simple is None + assert union_lt.stored_type.enum_type.values == ["red", "green", "blue"]