Skip to content

Commit

Permalink
Add literal type to union literal (#1144)
Browse files Browse the repository at this point in the history
* Add literal type to union literal

Signed-off-by: Kevin Su <[email protected]>

* fix test

Signed-off-by: Kevin Su <[email protected]>

* Add tests

Signed-off-by: Kevin Su <[email protected]>

* more tests

Signed-off-by: Kevin Su <[email protected]>

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored and eapolinario committed Sep 15, 2022
1 parent 2aaaeaa commit 1d1fc85
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 4 deletions.
5 changes: 2 additions & 3 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
from flytekit.models import literals
from flytekit.models.interface import Variable
from flytekit.models.literals import Blob, BlobMetadata, Primitive
from flytekit.models.literals import Blob, BlobMetadata, Primitive, Union
from flytekit.models.types import LiteralType, SimpleType
from flytekit.remote.executions import FlyteWorkflowExecution
from flytekit.tools import module_loader, script_mode
Expand Down Expand Up @@ -270,8 +270,7 @@ def convert_to_union(
# and then use flyte converter to convert it to literal.
python_val = converter._click_type.convert(value, param, ctx)
literal = converter.convert_to_literal(ctx, param, python_val)
self._python_type = python_type
return literal
return Literal(scalar=Scalar(union=Union(literal, variant)))
except (Exception or AttributeError) as e:
logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e)
raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}")
Expand Down
64 changes: 63 additions & 1 deletion tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
import functools
import os
import pathlib
import typing
from enum import Enum

import click
import mock
import pytest
from click.testing import CliRunner

from flytekit import FlyteContextManager
from flytekit.clis.sdk_in_container import pyflyte
from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE
from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY
from flytekit.clis.sdk_in_container.run import (
REMOTE_FLAG_KEY,
RUN_LEVEL_PARAMS_KEY,
FileParamType,
FlyteLiteralConverter,
get_entities_in_file,
run_command,
)
from flytekit.configuration import Image, ImageConfig
from flytekit.configuration import Config, Image, ImageConfig
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.models.types import SimpleType
from flytekit.remote import FlyteRemote

WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py")
IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py")
Expand Down Expand Up @@ -267,3 +276,56 @@ def test_file_param():
assert l.local
r = FileParamType().convert("https://tmp/file", m, m)
assert r.local is False


class Color(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"


@pytest.mark.parametrize(
"python_type, python_value",
[
(typing.Union[typing.List[int], str, Color], "flyte"),
(typing.Union[typing.List[int], str, Color], "red"),
(typing.Union[typing.List[int], str, Color], [1, 2, 3]),
(typing.List[int], [1, 2, 3]),
(typing.Dict[str, int], {"flyte": 2}),
],
)
def test_literal_converter(python_type, python_value):
get_upload_url_fn = functools.partial(
FlyteRemote(Config.auto()).client.get_upload_signed_url, project="p", domain="d"
)
click_ctx = click.Context(click.Command("test_command"), obj={"remote": True})
ctx = FlyteContextManager.current_context()
lt = TypeEngine.to_literal_type(python_type)

lc = FlyteLiteralConverter(
click_ctx, ctx, literal_type=lt, python_type=python_type, get_upload_url_fn=get_upload_url_fn
)

assert lc.convert(click_ctx, ctx, python_value) == TypeEngine.to_literal(ctx, python_value, python_type, lt)


def test_enum_converter():
pt = typing.Union[str, Color]

get_upload_url_fn = functools.partial(FlyteRemote(Config.auto()).client.get_upload_signed_url)
click_ctx = click.Context(click.Command("test_command"), obj={"remote": True})
ctx = FlyteContextManager.current_context()
lt = TypeEngine.to_literal_type(pt)
lc = FlyteLiteralConverter(click_ctx, ctx, literal_type=lt, python_type=pt, get_upload_url_fn=get_upload_url_fn)
union_lt = lc.convert(click_ctx, ctx, "red").scalar.union

assert union_lt.stored_type.simple == SimpleType.STRING
assert union_lt.stored_type.enum_type is None

pt = typing.Union[Color, str]
lt = TypeEngine.to_literal_type(typing.Union[Color, str])
lc = FlyteLiteralConverter(click_ctx, ctx, literal_type=lt, python_type=pt, get_upload_url_fn=get_upload_url_fn)
union_lt = lc.convert(click_ctx, ctx, "red").scalar.union

assert union_lt.stored_type.simple is None
assert union_lt.stored_type.enum_type.values == ["red", "green", "blue"]

0 comments on commit 1d1fc85

Please sign in to comment.