Skip to content

Commit

Permalink
EnumType support in flytekit (#509)
Browse files Browse the repository at this point in the history
Example usage:

```python
from enum import Enum

class Color(Enum):
   RED = "red"
   BLUE = "blue"
   GREEN = "green"

@task
def foo(c: Color) -> str:
   return c.value
```

UI/UX: Drop down. RED will be the default value
flytectl: enforce the value to be limited to one of the enum values

Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Haytham Abuelfutuh <[email protected]>
  • Loading branch information
kumare3 authored and EngHabu committed Jun 25, 2021
1 parent 6ebf69e commit 422f437
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 21 deletions.
50 changes: 50 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dataclasses
import datetime as _datetime
import enum
import json as _json
import mimetypes
import os
Expand Down Expand Up @@ -300,12 +301,39 @@ def register(cls, transformer: TypeTransformer):

@classmethod
def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
"""
The TypeEngine hierarchy for flyteKit. This method looksup and selects the type transformer. The algorithm is
as follows
d = dictionary of registered transformers, where is a python `type`
v = lookup type
Step 1:
find a transformer that matches v exactly
Step 2:
find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc
Step 3:
if v is of type data class, use the dataclass transformer
Step 4:
Walk the inheritance hierarchy of v and find a transformer that matches the first base class.
This is potentially non-deterministic - will depend on the registration pattern.
TODO lets make this deterministic by using an ordered dict
"""
# Step 1
if python_type in cls._REGISTRY:
return cls._REGISTRY[python_type]

# Step 2
if hasattr(python_type, "__origin__"):
if python_type.__origin__ in cls._REGISTRY:
return cls._REGISTRY[python_type.__origin__]
raise ValueError(f"Generic Type {python_type.__origin__} not supported currently in Flytekit.")

# Step 3
if dataclasses.is_dataclass(python_type):
return cls._DATACLASS_TRANSFORMER

Expand Down Expand Up @@ -622,6 +650,27 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
return local_destination_path


class EnumTransformer(TypeTransformer[enum.Enum]):
"""
Enables converting a python type enum.Enum to LiteralType.EnumType
"""

def __init__(self):
super().__init__(name="DefaultEnumTransformer", t=enum.Enum)

def get_literal_type(self, t: Type[T]) -> LiteralType:
values = [v.value for v in t]
if not isinstance(values[0], str):
raise AssertionError("Only EnumTypes with value of string are supported")
return LiteralType(enum_type=_core_types.EnumType(values=values))

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value)))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
return expected_python_type(lv.scalar.primitive.string_value)


def _check_and_covert_float(lv: Literal) -> float:
if lv.scalar.primitive.float_value is not None:
return lv.scalar.primitive.float_value
Expand Down Expand Up @@ -705,6 +754,7 @@ def _register_default_type_transformers():
TypeEngine.register(TextIOTransformer())
TypeEngine.register(PathLikeTransformer())
TypeEngine.register(BinaryIOTransformer())
TypeEngine.register(EnumTransformer())

# inner type is. Also unsupported are typing's Tuples. Even though you can look inside them, Flyte's type system
# doesn't support these currently.
Expand Down
24 changes: 24 additions & 0 deletions flytekit/models/core/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,32 @@
import typing

from flyteidl.core import types_pb2 as _types_pb2

from flytekit.models import common as _common


class EnumType(_common.FlyteIdlEntity):
"""
Models _types_pb2.EnumType
"""

def __init__(self, values: typing.List[str]):
self._values = values

@property
def values(self) -> typing.List[str]:
return self._values

def to_flyte_idl(self) -> _types_pb2.EnumType:
return _types_pb2.EnumType(
values=self._values if self._values else [],
)

@classmethod
def from_flyte_idl(cls, proto: _types_pb2.EnumType):
return cls(values=proto.values)


class BlobType(_common.FlyteIdlEntity):
class BlobDimensionality(object):
SINGLE = _types_pb2.BlobType.SINGLE
Expand Down
36 changes: 16 additions & 20 deletions flytekit/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
collection_type=None,
map_value_type=None,
blob=None,
enum_type=None,
metadata=None,
):
"""
Expand All @@ -116,54 +117,47 @@ def __init__(
:param LiteralType map_value_type: For map objects, this is the type of the value. The key must always be a
string.
:param flytekit.models.core.types.BlobType blob: For blob objects, this describes the type.
:param flytekit.models.core.types.EnumType enum_type: For enum objects, describes an enum
:param dict[Text, T] metadata: Additional data describing the type
"""
self._simple = simple
self._schema = schema
self._collection_type = collection_type
self._map_value_type = map_value_type
self._blob = blob
self._enum_type = enum_type
self._metadata = metadata

@property
def simple(self):
"""
Enum type from SimpleType
:rtype: int
"""
def simple(self) -> SimpleType:
return self._simple

@property
def schema(self):
"""
Type definition for a dataframe-like object.
:rtype: SchemaType
"""
def schema(self) -> SchemaType:
return self._schema

@property
def collection_type(self):
def collection_type(self) -> "LiteralType":
"""
Enum type from SimpleType or SchemaType
:rtype: LiteralType
The collection value type
"""
return self._collection_type

@property
def map_value_type(self):
def map_value_type(self) -> "LiteralType":
"""
Enum type from SimpleType
:rtype: LiteralType
The Value for a dictionary. Key is always string
"""
return self._map_value_type

@property
def blob(self):
"""
:rtype: flytekit.models.core.types.BlobType
"""
def blob(self) -> _core_types.BlobType:
return self._blob

@property
def enum_type(self) -> _core_types.EnumType:
return self._enum_type

@property
def metadata(self):
"""
Expand All @@ -185,6 +179,7 @@ def to_flyte_idl(self):
collection_type=self.collection_type.to_flyte_idl() if self.collection_type is not None else None,
map_value_type=self.map_value_type.to_flyte_idl() if self.map_value_type is not None else None,
blob=self.blob.to_flyte_idl() if self.blob is not None else None,
enum_type=self.enum_type.to_flyte_idl() if self.enum_type else None,
metadata=metadata,
)
return t
Expand All @@ -207,6 +202,7 @@ def from_flyte_idl(cls, proto):
collection_type=collection_type,
map_value_type=map_value_type,
blob=_core_types.BlobType.from_flyte_idl(proto.blob) if proto.HasField("blob") else None,
enum_type=_core_types.EnumType.from_flyte_idl(proto.enum_type) if proto.HasField("enum_type") else None,
metadata=_json_format.MessageToDict(proto.metadata) or None,
)

Expand Down
48 changes: 47 additions & 1 deletion tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import typing
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum

import pytest
from dataclasses_json import dataclass_json
from flyteidl.core import errors_pb2

from flytekit.core.context_manager import FlyteContext
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import (
DataclassTransformer,
DictTransformer,
Expand Down Expand Up @@ -325,3 +326,48 @@ def test_dataclass_transformer():
assert t.simple is not None
assert t.simple == SimpleType.STRUCT
assert t.metadata is None


# Enums should have string values
class Color(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"


# Enums with integer values are not supported
class UnsupportedEnumValues(Enum):
RED = 1
GREEN = 2
BLUE = 3


def test_enum_type():
t = TypeEngine.to_literal_type(Color)
assert t is not None
assert t.enum_type is not None
assert t.enum_type.values
assert t.enum_type.values == [c.value for c in Color]

ctx = FlyteContextManager.current_context()
lv = TypeEngine.to_literal(ctx, Color.RED, Color, TypeEngine.to_literal_type(Color))
assert lv
assert lv.scalar
assert lv.scalar.primitive.string_value == "red"

v = TypeEngine.to_python_value(ctx, lv, Color)
assert v
assert v == Color.RED

v = TypeEngine.to_python_value(ctx, lv, str)
assert v
assert v == "red"

with pytest.raises(ValueError):
TypeEngine.to_python_value(ctx, Literal(scalar=Scalar(primitive=Primitive(string_value=str(Color.RED)))), Color)

with pytest.raises(ValueError):
TypeEngine.to_python_value(ctx, Literal(scalar=Scalar(primitive=Primitive(string_value="bad"))), Color)

with pytest.raises(AssertionError):
TypeEngine.to_literal_type(UnsupportedEnumValues)
17 changes: 17 additions & 0 deletions tests/flytekit/unit/models/core/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,20 @@ def test_blob_type():
assert o == o2
assert o2.format == "csv"
assert o2.dimensionality == _types.BlobType.BlobDimensionality.SINGLE


def test_enum_type():
o = _types.EnumType(values=["x", "y"])
assert o.values == ["x", "y"]
v = o.to_flyte_idl()
assert v
assert v.values == ["x", "y"]

o = _types.EnumType.from_flyte_idl(_types_pb2.EnumType(values=["a", "b"]))
assert o.values == ["a", "b"]

o = _types.EnumType(values=None)
assert not o.values
v = o.to_flyte_idl()
assert v
assert not v.values

0 comments on commit 422f437

Please sign in to comment.