diff --git a/plugins/flytekit-hydra/README.md b/plugins/flytekit-hydra/README.md new file mode 100644 index 00000000000..1caae55d17e --- /dev/null +++ b/plugins/flytekit-hydra/README.md @@ -0,0 +1,81 @@ +# Flytekit Hydra Plugin + +Flytekit python natively supports serialization of many data types for exchanging information between tasks. +The Flytekit Hydra Plugin extends these by the `DictConfig` type from the +[OmegaConf package](https://omegaconf.readthedocs.io/) as well as related types +that are being used by the [hydra package](https://hydra.cc/) for configuration management. + + + +## Task example +```python +from dataclasses import dataclass +import flytekitplugins.hydra # noqa F401 +from flytekit import task, workflow +from omegaconf import DictConfig + +@dataclass +class MySimpleConf: + _target_: str = "lightning_module.MyEncoderModule" + learning_rate: float = 0.0001 + +@task +def my_task(cfg: DictConfig) -> None: + print(f"Doing things with {cfg.learning_rate=}") + + +@workflow +def pipeline(cfg: DictConfig) -> None: + my_task(cfg=cfg) + + +if __name__ == "__main__": + from omegaconf import OmegaConf + + cfg = OmegaConf.structured(MySimpleConf) + pipeline(cfg=cfg) +``` + +## Transformer configuration + +The transformer can be set to one of three modes: + +`DataClass` - This mode should be used with a StructuredConfig and will reconstruct the config from the matching dataclass +during deserialisation in order to make typing information from the dataclass and continued validation thereof available. +This requires the dataclass definition to be available via python import in the Flyte execution environment in which +objects are (de-)serialised. + +`DictConfig` - This mode will deserialize the config into a DictConfig object. In particular, dataclasses are translated +into DictConfig objects and only primitive types are being checked. The definition of underlying dataclasses for +structured configs is only required during the initial serialization for this mode. 
+ +`Auto` - This mode will try to deserialize according to the Dataclass mode and fall back to the DictConfig mode if the +dataclass definition is not available. This is the default mode. + +To set the mode either initialise the transformer with the `mode` argument or set the mode of the config directly: + +```python +from flytekitplugins.hydra.config import SharedConfig, OmegaConfTransformerMode +from flytekitplugins.hydra import DictConfigTransformer + +# Set the mode directly on the transformer +transformer_slim = DictConfigTransformer(mode=OmegaConfTransformerMode.DictConfig) + +# Set the mode directly in the config +SharedConfig.set_mode(OmegaConfTransformerMode.DictConfig) +``` + +```note +Since the DictConfig is flattened and keys transformed into dot notation, the keys of the DictConfig must not contain +dots. +``` + +```note +Warning: This plugin overwrites the default serializer for Enum-objects to also allow for non-string-valued enum definitions. +Please check carefully if existing workflows are compatible with the new version. +``` + +```note +Warning: This plugin attempts serialisation of objects with different transformers. In the process exceptions during +serialisation are suppressed. 
class OmegaConfTransformerMode(Enum):
    """
    Operation mode indicating whether a (potentially unannotated) DictConfig object or a structured config using the
    underlying dataclass is returned.

    Note: Returning a structured config (``DataClass``) requires the use of structured configs.
    Note: We define a single shared config across all transformers as recursive calls should refer to the same config.
    """

    DictConfig = "DictConfig"
    DataClass = "DataClass"
    Auto = "Auto"


class SharedConfig:
    """Class-level holder of the serialisation mode shared by all omegaconf transformers."""

    # Stored on the class (not on instances) so that every reader sees the same value.
    _mode: OmegaConfTransformerMode = OmegaConfTransformerMode.Auto

    @classmethod
    def get_mode(cls) -> OmegaConfTransformerMode:
        """Get the current mode for serialising omegaconf objects."""
        return cls._mode

    @classmethod
    def set_mode(cls, new_mode: OmegaConfTransformerMode) -> None:
        """
        Set the current mode for serialising omegaconf objects.

        Args:
            new_mode: The mode to activate. ``None`` is accepted and leaves the current mode untouched. Besides
                enum members, the value of a member (e.g. ``"DictConfig"``) is accepted and converted.

        Any other value leaves the current mode unchanged and emits a warning instead of being dropped silently.
        """
        if new_mode is None:
            # "No preference": keep whatever mode is currently active.
            return

        try:
            # Enum lookup by value returns the member itself when handed a member.
            cls._mode = OmegaConfTransformerMode(new_mode)
        except ValueError:
            import warnings

            warnings.warn(
                f"Ignoring invalid omegaconf transformer mode {new_mode!r}; "
                f"expected one of {[mode.value for mode in OmegaConfTransformerMode]}. "
                f"The mode remains {cls._mode.value}.",
                stacklevel=2,
            )
class DictConfigTransformer(TypeTransformer[DictConfig]):
    """
    Flytekit type transformer for OmegaConf ``DictConfig`` objects.

    A ``DictConfig`` is serialised into one generic ``Struct`` literal holding, per top-level key, the type
    description (``types``) and the serialised value (``values``), plus the qualified name of the underlying config
    type (``base_dataclass``).
    """

    def __init__(self, mode: typing.Optional[OmegaConfTransformerMode] = None):
        """
        Construct DictConfigTransformer.

        Args:
            mode: Serialisation mode that is forwarded to the shared config.
        """
        super().__init__(name="OmegaConf DictConfig", t=DictConfig)
        self.mode = mode

    @property
    def mode(self) -> OmegaConfTransformerMode:
        """Serialisation mode for omegaconf objects defined in shared config."""
        return SharedConfig.get_mode()

    @mode.setter
    def mode(self, new_mode: OmegaConfTransformerMode) -> None:
        """Updates the central shared config with a new serialisation mode."""
        SharedConfig.set_mode(new_mode)

    def get_literal_type(self, t: Type[DictConfig]) -> LiteralType:
        """
        Provide type hint for Flytekit type system.

        Every DictConfig is declared as a generic struct; the typing of the individual nodes is not visible in the
        Flyte LiteralType.
        """
        return LiteralType(simple=SimpleType.STRUCT)

    @staticmethod
    def _import_attr(qualified_name: str) -> typing.Any:
        """Import ``module.attribute`` given as one dotted string and return the attribute."""
        module_name, attr_name = qualified_name.rsplit(".", 1)
        return getattr(importlib.import_module(module_name), attr_name)

    @classmethod
    def _resolve_node_type(cls, type_desc: str) -> typing.Any:
        """Turn a stored type description (``pkg.Type`` or ``pkg.Origin[pkg.A, NoneType]``) back into a type."""
        if re.match(r".+\[.*]", type_desc):
            parts = type_desc.split("[")
            origin = cls._import_attr(parts[0])
            # parts[1] is "A, B]"; drop the closing bracket before splitting the arguments
            sub_types = tuple(
                # NoneType cannot be imported by name and is therefore stored under a fixed marker
                type(None) if name == "NoneType" else cls._import_attr(name)
                for name in parts[1][:-1].split(", ")
            )
            return origin[sub_types]
        return cls._import_attr(type_desc)

    def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        """
        Convert from given python type object ``DictConfig`` to the Literal representation.

        Since the DictConfig type does not offer additional type hints for its nodes, typing information is stored
        within the literal itself rather than the Flyte LiteralType.

        Raises:
            AssertionError: If ``python_val`` is not a DictConfig or cannot be flattened.
            ValueError: If none of the candidate transformers could serialise a node.
        """
        # instead of raising TypeError, raising AssertError so that flytekit can catch it in
        # https://github.com/flyteorg/flytekit/blob/60c982e4b065fdb3aba0b957e506f652a2674c00/flytekit/core/
        # type_engine.py#L1222
        assert isinstance(python_val, DictConfig), f"Invalid type {type(python_val)}, can only serialise DictConfigs"
        assert is_flatable(
            python_val
        ), f"{python_val} cannot be flattened as it contains non-string keys or keys containing dots."

        base_config = OmegaConf.get_type(python_val)
        type_map = {}
        value_map = {}
        for key in python_val.keys():
            if OmegaConf.is_missing(python_val, key):
                # MISSING nodes carry no value; only a marker is stored in the type map
                type_map[key] = "MISSING"
                continue

            node_type, type_name = extract_node_type(python_val, key)
            type_map[key] = type_name

            transformation_logs = ""
            for transformer in iterate_get_transformers(node_type):
                try:
                    literal_type = transformer.get_literal_type(node_type)

                    value_map[key] = MessageToDict(
                        transformer.to_literal(ctx, python_val[key], node_type, literal_type).to_flyte_idl()
                    )
                    break
                except Exception:
                    # Deliberately broad: a failing candidate only means that the next one should be tried.
                    failure = (
                        f"Serialisation with transformer {type(transformer)} failed:\n"
                        f"{traceback.format_exc()}\n\n"
                    )
                    # Log as well, so the attempt is not lost if the error below is never reached.
                    logger.debug(failure)
                    transformation_logs += failure

            if key not in value_map:
                raise ValueError(
                    f"Could not identify matching transformer for object {python_val[key]} of type "
                    f"{type(python_val[key])}. This may either indicate that no such transformer "
                    "exists or the appropriate transformer cannot serialise this object. Attempted the following "
                    f"transformers:\n{transformation_logs}"
                )

        wrapper = Struct()
        wrapper.update(
            flatten_dict.flatten(
                {
                    "types": type_map,
                    "values": value_map,
                    "base_dataclass": f"{base_config.__module__}.{base_config.__name__}",
                },
                reducer="dot",
                keep_empty_types=(dict,),
            )
        )
        return Literal(scalar=Scalar(generic=wrapper))

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[DictConfig]) -> DictConfig:
        """
        Re-hydrate the custom object from Flyte Literal value.

        Unless the mode is ``DictConfig``, a config that was serialised from a structured config is rebuilt from
        its dataclass. If that fails with an import or attribute error, mode ``DataClass`` re-raises the error,
        while any other mode falls back to a plain DictConfig.

        Raises:
            TypeTransformerFailedError: If ``lv`` is empty or holds no scalar.
            ValueError: If none of the candidate transformers could deserialise a node.
        """
        if not lv or lv.scalar is None:
            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

        nested_dict = flatten_dict.unflatten(MessageToDict(lv.scalar.generic), splitter="dot")
        cfg_dict = {}
        for key, type_desc in nested_dict["types"].items():
            if type_desc == "MISSING":
                cfg_dict[key] = omegaconf.MISSING
                continue

            node_type = self._resolve_node_type(type_desc)

            transformation_logs = ""
            for transformer in iterate_get_transformers(node_type):
                try:
                    value_literal = Literal.from_flyte_idl(ParseDict(nested_dict["values"][key], PB_Literal()))
                    cfg_dict[key] = transformer.to_python_value(ctx, value_literal, node_type)
                    break
                except Exception:
                    # Deliberately broad: a failing candidate only means that the next one should be tried.
                    failure = (
                        f"Deserialisation with transformer {type(transformer)} failed:\n"
                        f"{traceback.format_exc()}\n\n"
                    )
                    logger.debug(failure)
                    transformation_logs += failure

            if key not in cfg_dict:
                raise ValueError(
                    f"Could not identify matching transformer for object {nested_dict['values'][key]} of "
                    f"proposed type {node_type}. This may either indicate that no such transformer "
                    "exists or the appropriate transformer cannot deserialise this object. Attempted the "
                    f"following transformers:\n{transformation_logs}"
                )

        base_dataclass = nested_dict["base_dataclass"]
        if base_dataclass != "builtins.dict" and self.mode != OmegaConfTransformerMode.DictConfig:
            # Explicitly instantiate dataclass and create DictConfig from there in order to have typing information
            module_name, class_name = base_dataclass.rsplit(".", 1)
            try:
                config_class = getattr(importlib.import_module(module_name), class_name)
                return OmegaConf.structured(config_class(**cfg_dict))
            except (ModuleNotFoundError, AttributeError):
                if self.mode == OmegaConfTransformerMode.DataClass:
                    logger.error(
                        f"Could not load dataclass {module_name}.{class_name}. If you want to deserialise to "
                        f"DictConfig, set the mode to OmegaConfTransformerMode.DictConfig."
                    )
                    raise
                # Any other mode: falling back is the documented behaviour, so this is not an error.
                logger.warning(
                    f"Could not load dataclass {module_name}.{class_name}; falling back to a plain DictConfig. "
                    f"Set the mode to OmegaConfTransformerMode.DictConfig to always deserialise to DictConfig."
                )
        return OmegaConf.create(cfg_dict)
GenericEnumTransformer(TypeTransformer[enum.Enum]): + def __init__(self): + """Transformer for arbitrary enum.Enum objects.""" + super().__init__(name="GenericEnumTransformer", t=enum.Enum) + + def get_literal_type(self, t: Type[T]) -> LiteralType: + """ + Provide type hint for Flytekit type system. + + To support the multivariate typing of generic enums, we encode them as binaries (no introspection) with + multiple files. + """ + return LiteralType(simple=SimpleType.STRUCT) + + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + """Transform python enum to Flyte literal for serialisation""" + if not isinstance(python_val, enum.Enum): + raise TypeTransformerFailedError("Expected an enum") + + enum_type = type(python_val) + node_type = type(python_val.value) + transformer = TypeEngine.get_transformer(node_type) + literal_type = transformer.get_literal_type(node_type) + + return Literal( + map=LiteralMap( + literals={ + "enum_type": Literal( + scalar=Scalar(primitive=Primitive(string_value=f"{enum_type.__module__}/{enum_type.__name__}")) + ), + "type": Literal( + scalar=Scalar(primitive=Primitive(string_value=f"{node_type.__module__}/{node_type.__name__}")) + ), + "value": transformer.to_literal(ctx, python_val.value, node_type, literal_type), + } + ) + ) + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + """Transform Literal back to python object (Enum).""" + module_name, class_name = lv.map.literals["type"].scalar.primitive.string_value.split("/") + value_literal = lv.map.literals["value"] + node_type = importlib.import_module(module_name).__getattribute__(class_name) + transformer = TypeEngine.get_transformer(node_type) + base_value = transformer.to_python_value(ctx, value_literal, node_type) + + enum_module, enum_class = lv.map.literals["enum_type"].scalar.primitive.string_value.split("/") + enum_type = 
def iterate_get_transformers(python_type: Type) -> TypeTransformer:
    """
    Yield every transformer that is a candidate for ``python_type``, most specific lookup first.

    This is a copy of flytekit.core.type_engine.TypeEngine.get_transformer. Instead of simply returning the first
    appropriate Transformer, it yields all admissible candidates and leaves identification of the correct one and
    error handling to the caller of the transformer itself.

    This function is a generator: the return annotation names the type of the yielded items. A transformer that
    matches several lookup steps is yielded only once, so callers do not retry a candidate that already failed.

    Raises:
        ValueError: For generic types, once all candidates found for the generic have been consumed.
    """
    already_yielded = []
    for candidate in _iterate_transformer_candidates(python_type):
        # Compare by identity; transformers are not guaranteed to be hashable or comparable.
        if any(candidate is previous for previous in already_yielded):
            continue
        already_yielded.append(candidate)
        yield candidate


def _iterate_transformer_candidates(python_type: Type) -> TypeTransformer:
    """Yield transformer candidates for ``python_type`` in lookup order, possibly with repetitions."""
    TypeEngine.lazy_import_transformers()
    # Step 1
    if is_annotated(python_type):
        args = get_args(python_type)
        for annotation in args:
            if isinstance(annotation, TypeTransformer):
                yield annotation

        python_type = args[0]

    if python_type in TypeEngine._REGISTRY:
        yield TypeEngine._REGISTRY[python_type]

    # Step 2
    if hasattr(python_type, "__origin__"):
        # Handling of annotated generics, eg:
        # Annotated[typing.List[int], 'foo']
        if is_annotated(python_type):
            yield TypeEngine.get_transformer(get_args(python_type)[0])

        if python_type.__origin__ in TypeEngine._REGISTRY:
            yield TypeEngine._REGISTRY[python_type.__origin__]

        # Only reached when the caller asks for yet another candidate, i.e. all previous ones were rejected.
        raise ValueError(
            f"Could not find suitable transformer for generic type {python_type.__origin__}. "
            f"Please check the (info-)logs for caught exceptions in serialisation attempts using existing transformers. "
            f"If no suitable transformer was chosen, this type may not currently be supported in Flytekit."
        )

    # Step 3
    # To facilitate cases where users may specify one transformer for multiple types that all inherit from one
    # parent.
    # Iterate over a snapshot: this generator is suspended while the caller runs transformer code, and a registry
    # that changes size in the meantime would otherwise abort the iteration with a RuntimeError.
    for base_type, transformer in list(TypeEngine._REGISTRY.items()):
        if base_type is None:
            continue  # None is actually one of the keys, but isinstance/issubclass doesn't work on it
        try:
            if isinstance(python_type, base_type) or (
                inspect.isclass(python_type) and issubclass(python_type, base_type)
            ):
                yield transformer
        except TypeError:
            # As of python 3.9, calls to isinstance raise a TypeError if the base type is not a valid type, which
            # is the case for one of the restricted types, namely NamedTuple.
            logger.debug(f"Invalid base type {base_type} in call to isinstance", exc_info=True)

    # Step 4
    if dataclasses.is_dataclass(python_type):
        yield TypeEngine._DATACLASS_TRANSFORMER
class ListConfigTransformer(TypeTransformer[ListConfig]):
    """
    Flytekit type transformer for OmegaConf ``ListConfig`` objects.

    A ``ListConfig`` is serialised into one generic ``Struct`` literal holding two parallel lists: the qualified
    type name of every element (``types``) and the serialised elements (``values``).
    """

    def __init__(self, mode: Optional[OmegaConfTransformerMode] = None):
        """
        Construct ListConfigTransformer.

        Args:
            mode: Serialisation mode that is forwarded to the shared config.
        """
        super().__init__(name="OmegaConf ListConfig", t=ListConfig)
        self.mode = mode

    @property
    def mode(self) -> OmegaConfTransformerMode:
        """Serialisation mode for omegaconf objects defined in shared config."""
        return SharedConfig.get_mode()

    @mode.setter
    def mode(self, new_mode: OmegaConfTransformerMode) -> None:
        """Updates the central shared config with a new serialisation mode."""
        SharedConfig.set_mode(new_mode)

    def get_literal_type(self, t: Type[ListConfig]) -> LiteralType:
        """
        Provide type hint for Flytekit type system.

        Every ListConfig is declared as a generic struct; the typing of the individual elements is not visible in
        the Flyte LiteralType.
        """
        return LiteralType(simple=SimpleType.STRUCT)

    def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        """
        Convert from given python type object ``ListConfig`` to the Literal representation.

        Since the ListConfig type does not offer additional type hints for its nodes, typing information is stored
        within the literal itself rather than the Flyte LiteralType.

        Raises:
            AssertionError: If ``python_val`` is not a ListConfig.
            ValueError: If none of the candidate transformers could serialise an element.
        """
        # instead of raising TypeError, raising AssertError so that flytekit can catch it in
        # https://github.com/flyteorg/flytekit/blob/60c982e4b065fdb3aba0b957e506f652a2674c00/flytekit/core/
        # type_engine.py#L1222
        assert isinstance(python_val, ListConfig), f"Invalid type {type(python_val)}, can only serialise ListConfigs"

        type_list = []
        value_list = []
        # Elements are addressed by index so that OmegaConf.is_missing can be asked before a value is read.
        for idx in range(len(python_val)):
            if OmegaConf.is_missing(python_val, idx):
                type_list.append("MISSING")
                # Placeholder value keeps ``types`` and ``values`` aligned index by index
                value_list.append(
                    MessageToDict(Literal(scalar=Scalar(primitive=Primitive(string_value="MISSING"))).to_flyte_idl())
                )
                continue

            node_type, type_name = extract_node_type(python_val, idx)
            type_list.append(type_name)

            transformation_logs = ""
            for transformer in iterate_get_transformers(node_type):
                try:
                    literal_type = transformer.get_literal_type(node_type)
                    value_list.append(
                        MessageToDict(
                            transformer.to_literal(ctx, python_val[idx], node_type, literal_type).to_flyte_idl()
                        )
                    )
                    break
                except Exception:
                    # Deliberately broad: a failing candidate only means that the next one should be tried.
                    failure = (
                        f"Serialisation with transformer {type(transformer)} failed:\n"
                        f"{traceback.format_exc()}\n\n"
                    )
                    # Log as well, so the attempt is not lost if the error below is never reached.
                    logger.debug(failure)
                    transformation_logs += failure

            # A type without a matching value means that no transformer succeeded for this element
            if len(type_list) != len(value_list):
                raise ValueError(
                    f"Could not identify matching transformer for object {python_val[idx]} of type "
                    f"{type(python_val[idx])}. This may either indicate that no such transformer "
                    "exists or the appropriate transformer cannot serialise this object. Attempted the following "
                    f"transformers:\n{transformation_logs}"
                )

        wrapper = Struct()
        wrapper.update({"types": type_list, "values": value_list})
        return Literal(scalar=Scalar(generic=wrapper))

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ListConfig]) -> ListConfig:
        """
        Re-hydrate the custom object from Flyte Literal value.

        Raises:
            TypeTransformerFailedError: If ``lv`` is empty or holds no scalar.
            ValueError: If none of the candidate transformers could deserialise an element.
        """
        if not lv or lv.scalar is None:
            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

        # Convert the protobuf struct once and read both lists from the result
        payload = MessageToDict(lv.scalar.generic)
        type_list = payload["types"]
        value_list = payload["values"]

        cfg_literal = []
        for i, type_name in enumerate(type_list):
            if type_name == "MISSING":
                cfg_literal.append(omegaconf.MISSING)
                continue

            module_name, class_name = type_name.rsplit(".", 1)
            node_type = getattr(importlib.import_module(module_name), class_name)

            value_literal = Literal.from_flyte_idl(ParseDict(value_list[i], PB_Literal()))

            transformation_logs = ""
            for transformer in iterate_get_transformers(node_type):
                try:
                    cfg_literal.append(transformer.to_python_value(ctx, value_literal, node_type))
                    break
                except Exception:
                    # Deliberately broad: a failing candidate only means that the next one should be tried.
                    failure = (
                        f"Deserialisation with transformer {type(transformer)} failed:\n"
                        f"{traceback.format_exc()}\n\n"
                    )
                    logger.debug(failure)
                    transformation_logs += failure

            if len(cfg_literal) != i + 1:
                # Report the serialised element; the Literal object itself cannot be subscripted.
                raise ValueError(
                    f"Could not identify matching transformer for object {value_list[i]} of proposed type "
                    f"{node_type}. This may either indicate that no such transformer exists or the appropriate "
                    f"transformer cannot deserialise this object. Attempted the following "
                    f"transformers:\n{transformation_logs}"
                )

        return OmegaConf.create(cfg_literal)
def extract_node_type(
    python_val: typing.Union[DictConfig, ListConfig], key: typing.Union[str, int]
) -> typing.Tuple[type, str]:
    """
    Provides typing information about a single node of an omegaconf container.

    If the type underlying the container has annotations, the annotation of ``key`` is used. Otherwise, and for
    nodes annotated as ``typing.Any``, the type is taken from the runtime value.

    Args:
        python_val: The DictConfig or ListConfig holding the node
        key: Key (or index) of the node to analyze

    Returns:
        Type - The extracted type
        str - String representation for (de-)serialisation

    Raises:
        ValueError: If annotations exist but do not cover ``key``, or if no name can be derived for the origin of
            a generic annotation.
    """
    assert isinstance(
        python_val, (DictConfig, ListConfig)
    ), "Can only extract type information from omegaconf objects"

    python_val_node_type = OmegaConf.get_type(python_val)
    python_val_annotations = all_annotations(python_val_node_type)

    # Check if type annotations are available
    if hasattr(python_val_node_type, "__annotations__"):
        if key not in python_val_annotations:
            raise ValueError(
                f"Key '{key}' not found in type annotations {python_val_annotations}. "
                "Check your hydra config for invalid subtrees not covered by your structured config."
            )

        annotation = python_val_annotations[key]
        origin = typing.get_origin(annotation)
        if origin is not None:
            # Abstract types
            if getattr(origin, "__name__", None) is not None:
                origin_name = f"{origin.__module__}.{origin.__name__}"
            elif getattr(origin, "_name", None) is not None:
                # some typing constructs only expose the private ``_name``
                origin_name = f"{origin.__module__}.{origin._name}"
            else:
                raise ValueError(f"Could not extract name from origin type {origin}")

            # Replace list and dict with omegaconf types
            if origin_name in ["builtins.list", "typing.List"]:
                return ListConfig, "omegaconf.listconfig.ListConfig"
            elif origin_name in ["builtins.dict", "typing.Dict"]:
                return DictConfig, "omegaconf.dictconfig.DictConfig"

            sub_types = []
            sub_type_names = []
            for sub_type in typing.get_args(annotation):
                if sub_type is NoneType:  # NoneType gets special treatment as no import exists
                    sub_types.append(NoneType)
                    sub_type_names.append("NoneType")
                elif dataclasses.is_dataclass(sub_type) and not issubclass(sub_type, DataClassJsonMixin):
                    # Dataclasses have no matching transformers and get replaced by DictConfig
                    # alternatively, dataclasses can use dataclass_json decorator
                    sub_types.append(DictConfig)
                    sub_type_names.append("omegaconf.dictconfig.DictConfig")
                else:
                    sub_type = substitute_types(sub_type)
                    sub_types.append(sub_type)
                    sub_type_names.append(f"{sub_type.__module__}.{sub_type.__name__}")
            return origin[tuple(sub_types)], f"{origin_name}[{', '.join(sub_type_names)}]"

        if dataclasses.is_dataclass(annotation):
            # Dataclasses have no matching transformers and get replaced by DictConfig
            # alternatively, dataclasses can use dataclass_json decorator
            return DictConfig, "omegaconf.dictconfig.DictConfig"

        if annotation != typing.Any:
            # Use (cleaned) annotation if it is meaningful
            node_type = substitute_types(annotation)
            type_name = f"{node_type.__module__}.{node_type.__name__}"
            return node_type, type_name

    # No annotations at all, or the node is annotated as Any: fall back to the runtime value
    logger.debug(
        f"Inferring type information directly from runtime object {python_val[key]} for serialisation purposes. "
        "For more stable type resolution and serialisation provide explict type hints."
    )
    node_type = type(python_val[key])
    type_name = f"{node_type.__module__}.{node_type.__name__}"
    return node_type, type_name
@dataclass
class ExampleNestedConfig:
    """Minimal structured config used as a nested node in the example configs."""

    nested_int_key: int = 2


@dataclass
class ExampleConfig:
    """Structured config with one field per kind of annotation exercised by the tests."""

    int_key: int = 1337
    union_key: t.Union[int, str] = 1337
    any_key: t.Any = "1337"
    optional_key: t.Optional[int] = 1337
    # A dataclass instance is an unhashable default, which dataclasses reject from Python 3.11 on (and which
    # would be shared between instances before that); build a fresh one per instance instead.
    dictconfig_key: ExampleNestedConfig = field(default_factory=ExampleNestedConfig)
    optional_dictconfig_key: t.Optional[ExampleNestedConfig] = None
    listconfig_key: t.List[int] = field(default_factory=lambda: [1, 2, 3])
    tuple_int_key: t.Tuple[int, int] = field(default_factory=lambda: (1, 2))


@dataclass
class ExampleConfigWithNonAnnotatedSubtree:
    """Structured config mixing an annotated field with a plain, unannotated class attribute."""

    # No annotation here, so this is a class attribute and not a dataclass field
    unnanotated_key = 1
    annotated_key: ExampleNestedConfig = field(default_factory=ExampleNestedConfig)
import builtins
import typing as t

import pytest
from omegaconf import DictConfig, ListConfig, OmegaConf

from flytekitplugins.hydra.type_information import extract_node_type
from tests.conftest import ExampleConfig, ExampleConfigWithNonAnnotatedSubtree


class TestExtractNodeType:
    """Checks of ``extract_node_type`` against structured and plain configs."""

    def test_extract_type_and_string_representation(self) -> None:
        """Each key yields the expected node type and its string representation."""

        structured_cfg = OmegaConf.structured(ExampleConfig(union_key="1337", optional_key=None))

        # (key, expected node type, expected type name), checked in this order.
        expectations = [
            ("int_key", int, "builtins.int"),
            ("union_key", t.Union[int, str], "typing.Union[builtins.int, builtins.str]"),
            ("any_key", str, "builtins.str"),
            ("optional_key", t.Optional[int], "typing.Union[builtins.int, NoneType]"),
            ("dictconfig_key", DictConfig, "omegaconf.dictconfig.DictConfig"),
            ("listconfig_key", ListConfig, "omegaconf.listconfig.ListConfig"),
            (
                "optional_dictconfig_key",
                t.Optional[DictConfig],
                "typing.Union[omegaconf.dictconfig.DictConfig, NoneType]",
            ),
            ("tuple_int_key", builtins.tuple[int, int], "builtins.tuple[builtins.int, builtins.int]"),
        ]

        for key, expected_type, expected_name in expectations:
            node_type, type_name = extract_node_type(structured_cfg, key=key)
            assert node_type == expected_type
            assert type_name == expected_name

    def test_raises_nonannotated_subtree(self) -> None:
        """A non-annotated key inside an otherwise typed config raises ``ValueError``."""

        mixed_cfg = OmegaConf.structured(ExampleConfigWithNonAnnotatedSubtree())
        annotated_type, _ = extract_node_type(mixed_cfg, key="annotated_key")
        assert annotated_type == DictConfig

        # Mixing an un-annotated subtree with a typed one is expected to be rejected.
        with pytest.raises(ValueError):
            extract_node_type(mixed_cfg, "unnanotated_key")

    def test_single_unnanotated_node(self) -> None:
        """For a fully un-annotated config the types are inferred from the runtime values."""

        root = OmegaConf.create({"unannotated_dictconfig_key": {"unnanotated_int_key": 2}})
        subtree_type, _ = extract_node_type(root, key="unannotated_dictconfig_key")
        assert subtree_type == DictConfig

        subtree = root.unannotated_dictconfig_key
        leaf_type, _ = extract_node_type(subtree, key="unnanotated_int_key")
        assert leaf_type == int
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Union

from omegaconf import MISSING, OmegaConf


class MyEnum(Enum):
    """Enum whose members hold a str, an int, a float and a bool value."""

    val1 = "str_val"
    val2 = 123
    val3 = 123.3
    val4 = True


class MultiTypeEnum(str, Enum):
    """Enum that additionally derives from ``str``."""

    fifo = "fifo"  # first in first out
    filo = "filo"  # first in last out


@dataclass
class MySubConf:
    """Inner config with an optional int-or-str attribute and a list of ints."""

    my_attr: Optional[Union[int, str]] = 1
    list_attr: List[int] = field(default_factory=list)


@dataclass
class MyConf:
    """Outer config holding an optional ``MySubConf``."""

    my_attr: Optional[MySubConf] = None


class SpecialConf(MyConf):
    """Subclass of ``MyConf``.

    NOTE(review): there is no ``@dataclass`` decorator here, so ``key`` is a
    plain annotated class attribute - confirm this is intended.
    """

    key: int = 1


# Sub-structures of TEST_CFG, named so the overall layout is easier to follow.
_MIXED_DICT_LIST = [
    {"g": 2, "h": 1.2},
    {"j": 0.5, "k": "foo", "l": "bar"},
]

_NESTED_SECTION = {
    "d": 1,
    "e": MISSING,
    "f": _MIXED_DICT_LIST,
}

_ENUM_SECTION = {
    "a": MyEnum.val1,
    "b": MyEnum.val2,
    "c": [MyEnum.val3, {"b": MyEnum.val4}],
}

# Plain (non-structured) config with primitives, a missing value, nested
# dicts/lists and enum members.
TEST_CFG = OmegaConf.create(
    {
        "a": 1,
        "b": 1.0,
        "c": _NESTED_SECTION,
        "en": _ENUM_SECTION,
    }
)
from typing import Any

import flytekitplugins.hydra  # noqa
import pytest
from flyteidl.core.literals_pb2 import Literal, Scalar
from flytekitplugins.hydra import DictConfigTransformer
from flytekitplugins.hydra.config import OmegaConfTransformerMode, SharedConfig
from google.protobuf.struct_pb2 import Struct
from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf, ValidationError
from pytest import mark, param

from flytekit import FlyteContext
from flytekit.core.type_engine import TypeEngine
from tests.test_objects import TEST_CFG, MultiTypeEnum, MyConf, MyEnum, MySubConf, SpecialConf


@mark.parametrize(
    "obj",
    [
        param(
            DictConfig({}),
        ),
        param(
            DictConfig({"a": "b"}),
        ),
        param(
            DictConfig({"a": 1}),
        ),
        param(
            DictConfig({"a": MISSING}),
        ),
        param(
            DictConfig({"tuple": (1, 2, 3)}),
        ),
        param(
            ListConfig(["a", "b"]),
        ),
        param(
            ListConfig(["a", MISSING]),
        ),
        param(
            MyEnum.val1,
        ),
        param(
            MyEnum.val2,
        ),
        param(
            TEST_CFG,
        ),
        param(
            DictConfig({"foo": MultiTypeEnum.fifo}),
        ),
        param(
            DictConfig({"foo": [MultiTypeEnum.fifo]}),
        ),
        param(DictConfig({"cfgs": [MySubConf(1), MySubConf("a"), "arg"]})),
        param(OmegaConf.structured(SpecialConf)),
    ],
)
def test_cfg_roundtrip(obj: Any) -> None:
    """Test casting DictConfig, ListConfig and enum objects to a flyte literal and back."""

    ctx = FlyteContext.current_context()
    expected = TypeEngine.to_literal_type(type(obj))
    transformer = TypeEngine.get_transformer(type(obj))

    # The type engine must resolve to one of the transformers shipped by the plugin.
    assert isinstance(
        transformer,
        (
            flytekitplugins.hydra.dictconfig_transformer.DictConfigTransformer,
            flytekitplugins.hydra.listconfig_transformer.ListConfigTransformer,
            flytekitplugins.hydra.extended_enum_transformer.GenericEnumTransformer,
        ),
    )

    literal = transformer.to_literal(ctx, obj, type(obj), expected)
    reconstructed = transformer.to_python_value(ctx, literal, type(obj))
    assert obj == reconstructed


def test_optional_type() -> None:
    """
    Test serialisation of DictConfigs with various optional entries, whose real types are provided by underlying
    dataclasses.
    """
    candidates = (
        OmegaConf.structured(MySubConf()),
        OmegaConf.structured(MyConf(my_attr=MySubConf())),
        OmegaConf.structured(MyConf()),
    )

    ctx = FlyteContext.current_context()
    expected = TypeEngine.to_literal_type(DictConfig)
    transformer = TypeEngine.get_transformer(DictConfig)

    for candidate in candidates:
        literal = transformer.to_literal(ctx, candidate, DictConfig, expected)
        recon = transformer.to_python_value(ctx, literal, DictConfig)
        assert recon == candidate


def test_plugin_mode() -> None:
    """Test serialisation with different plugin modes configured."""
    obj = OmegaConf.structured(MyConf(my_attr=MySubConf()))

    ctx = FlyteContext.current_context()
    expected = TypeEngine.to_literal_type(DictConfig)

    # The mode lives in the process-wide SharedConfig, so it is always reset to
    # Auto (the mode this test ends on anyway) even if a step below raises;
    # otherwise a failure here could change the behaviour of later tests.
    try:
        transformer = DictConfigTransformer(mode=OmegaConfTransformerMode.DictConfig)
        literal_slim = transformer.to_literal(ctx, obj, DictConfig, expected)
        reconstructed_slim = transformer.to_python_value(ctx, literal_slim, DictConfig)

        SharedConfig.set_mode(OmegaConfTransformerMode.DataClass)
        literal_full = transformer.to_literal(ctx, obj, DictConfig, expected)
        reconstructed_full = transformer.to_python_value(ctx, literal_full, DictConfig)

        SharedConfig.set_mode(OmegaConfTransformerMode.Auto)
        literal_semi = transformer.to_literal(ctx, obj, DictConfig, expected)
        reconstructed_semi = transformer.to_python_value(ctx, literal_semi, DictConfig)
    finally:
        SharedConfig.set_mode(OmegaConfTransformerMode.Auto)

    assert literal_slim == literal_full == literal_semi
    assert reconstructed_slim == reconstructed_full == reconstructed_semi  # comparison by value should pass

    assert OmegaConf.get_type(reconstructed_slim, "my_attr") == dict
    assert OmegaConf.get_type(reconstructed_semi, "my_attr") == MySubConf
    assert OmegaConf.get_type(reconstructed_full, "my_attr") == MySubConf

    reconstructed_slim.my_attr.my_attr = (1,)  # assign a tuple value to Union[int, str] field
    with pytest.raises(ValidationError):
        reconstructed_semi.my_attr.my_attr = (1,)
    with pytest.raises(ValidationError):
        reconstructed_full.my_attr.my_attr = (1,)


def _fallback_literal(sub_base_dataclass: str) -> Literal:
    """Build the serialised ``MyConf`` literal used by ``test_fallback``.

    Only the dotted path recorded as base dataclass of the nested sub-config
    varies between the literals the test needs; every other entry is identical.
    """
    struct = Struct()
    struct.update(
        {
            "values.my_attr.scalar.union.value.scalar.generic.values.my_attr.scalar.union.value.scalar.primitive.integer": 1,  # noqa: E501
            "values.my_attr.scalar.union.value.scalar.generic.values.my_attr.scalar.union.type.structure.tag": "int",
            "values.my_attr.scalar.union.value.scalar.generic.values.my_attr.scalar.union.type.simple": "INTEGER",
            "values.my_attr.scalar.union.value.scalar.generic.values.list_attr.scalar.generic.values": [],
            "values.my_attr.scalar.union.value.scalar.generic.values.list_attr.scalar.generic.types": [],
            "values.my_attr.scalar.union.value.scalar.generic.types.my_attr": "typing.Union[builtins.int, builtins.str, NoneType]",  # noqa: E501
            "values.my_attr.scalar.union.value.scalar.generic.types.list_attr": "omegaconf.listconfig.ListConfig",
            "values.my_attr.scalar.union.value.scalar.generic.base_dataclass": sub_base_dataclass,
            "values.my_attr.scalar.union.type.structure.tag": "OmegaConf DictConfig",
            "values.my_attr.scalar.union.type.simple": "STRUCT",
            "types.my_attr": "typing.Union[omegaconf.dictconfig.DictConfig, NoneType]",
            "base_dataclass": "tests.test_objects.MyConf",
        }
    )
    return Literal(scalar=Scalar(generic=struct))


def test_fallback() -> None:
    """Test if fallback mode recovers basic information if the specified type cannot be found."""
    obj = OmegaConf.structured(MyConf(my_attr=MySubConf()))

    literal = _fallback_literal("tests.test_objects.MySubConf")

    # construct a literal with an unknown subconfig type
    literal2 = _fallback_literal("tests.test_objects.MyFooConf")

    ctx = FlyteContext.current_context()
    transformer = DictConfigTransformer(mode=OmegaConfTransformerMode.DictConfig)
    SharedConfig.set_mode(OmegaConfTransformerMode.Auto)

    reconstructed = transformer.to_python_value(ctx, literal, DictConfig)
    assert obj == reconstructed

    part_reconstructed = transformer.to_python_value(ctx, literal2, DictConfig)
    assert obj == part_reconstructed
    assert OmegaConf.get_type(part_reconstructed, "my_attr") == dict

    part_reconstructed.my_attr.my_attr = (1,)  # assign a tuple value to Union[int, str] field
    with pytest.raises(ValidationError):
        reconstructed.my_attr.my_attr = (1,)