diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 19b77a94ed..c973aee3e2 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -346,6 +346,7 @@ jobs: # onnx-tensorflow needs a version of tensorflow that does not work with protobuf>4. # The issue is being tracked on the tensorflow side in https://github.com/tensorflow/tensorflow/issues/53234#issuecomment-1330111693 # flytekit-onnx-tensorflow + - flytekit-omegaconf - flytekit-openai - flytekit-pandera - flytekit-papermill diff --git a/plugins/flytekit-omegaconf/README.md b/plugins/flytekit-omegaconf/README.md new file mode 100644 index 0000000000..cddd406b31 --- /dev/null +++ b/plugins/flytekit-omegaconf/README.md @@ -0,0 +1,69 @@ +# Flytekit OmegaConf Plugin + +Flytekit python natively supports serialization of many data types for exchanging information between tasks. +The Flytekit OmegaConf Plugin extends these by the `DictConfig` type from the +[OmegaConf package](https://omegaconf.readthedocs.io/) as well as related types +that are being used by the [hydra package](https://hydra.cc/) for configuration management. + +## Task example +``` +from dataclasses import dataclass +import flytekitplugins.omegaconf # noqa F401 +from flytekit import task, workflow +from omegaconf import DictConfig + +@dataclass +class MySimpleConf: + _target_: str = "lightning_module.MyEncoderModule" + learning_rate: float = 0.0001 + +@task +def my_task(cfg: DictConfig) -> None: + print(f"Doing things with {cfg.learning_rate=}") + + +@workflow +def pipeline(cfg: DictConfig) -> None: + my_task(cfg=cfg) + + +if __name__ == "__main__": + from omegaconf import OmegaConf + + cfg = OmegaConf.structured(MySimpleConf) + pipeline(cfg=cfg) +``` + +## Transformer configuration + +The transformer can be set to one of three modes: + +`Dataclass` - This mode should be used with a StructuredConfig and will reconstruct the config from the matching dataclass +during deserialisation in order to make typing information from the dataclass and continued validation thereof available. +This requires the dataclass definition to be available via python import in the Flyte execution environment in which +objects are (de-)serialised. + +`DictConfig` - This mode will deserialize the config into a DictConfig object. In particular, dataclasses are translated +into DictConfig objects and only primitive types are being checked. The definition of underlying dataclasses for +structured configs is only required during the initial serialization for this mode. + +`Auto` - This mode will try to deserialize according to the Dataclass mode and fall back to the DictConfig mode if the +dataclass definition is not available. This is the default mode. + +You can set the transformer mode globally or for the current context only the following ways: +```python +from flytekitplugins.omegaconf import set_transformer_mode, set_local_transformer_mode, OmegaConfTransformerMode + +# Set the global transformer mode using the new function +set_transformer_mode(OmegaConfTransformerMode.DictConfig) + +# You can also the mode for the current context only +with set_local_transformer_mode(OmegaConfTransformerMode.Dataclass): + # This will use the Dataclass mode + pass +``` + +```note +Since the DictConfig is flattened and keys transformed into dot notation, the keys of the DictConfig must not contain +dots. +``` diff --git a/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/__init__.py b/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/__init__.py new file mode 100644 index 0000000000..87e2fb8943 --- /dev/null +++ b/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/__init__.py @@ -0,0 +1,33 @@ +from contextlib import contextmanager + +from flytekitplugins.omegaconf.config import OmegaConfTransformerMode +from flytekitplugins.omegaconf.dictconfig_transformer import DictConfigTransformer # noqa: F401 +from flytekitplugins.omegaconf.listconfig_transformer import ListConfigTransformer # noqa: F401 + +_TRANSFORMER_MODE = OmegaConfTransformerMode.Auto + + +def set_transformer_mode(mode: OmegaConfTransformerMode) -> None: + """Set the global serialization mode for OmegaConf objects.""" + global _TRANSFORMER_MODE + _TRANSFORMER_MODE = mode + + +def get_transformer_mode() -> OmegaConfTransformerMode: + """Get the global serialization mode for OmegaConf objects.""" + return _TRANSFORMER_MODE + + +@contextmanager +def local_transformer_mode(mode: OmegaConfTransformerMode): + """Context manager to set a local serialization mode for OmegaConf objects.""" + global _TRANSFORMER_MODE + previous_mode = _TRANSFORMER_MODE + set_transformer_mode(mode) + try: + yield + finally: + set_transformer_mode(previous_mode) + + +__all__ = ["set_transformer_mode", "get_transformer_mode", "local_transformer_mode", "OmegaConfTransformerMode"] diff --git a/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/config.py b/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/config.py new file mode 100644 index 0000000000..5006d5b854 --- /dev/null +++ b/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/config.py @@ -0,0 +1,15 @@ +from enum import Enum + + +class OmegaConfTransformerMode(Enum): + """ + Operation Mode indicating whether a (potentially unannotated) DictConfig object or a structured config using the + underlying dataclass is returned. + + Note: We define a single shared config across all transformers as recursive calls should refer to the same config + Note: The latter requires the use of structured configs. + """ + + DictConfig = "DictConfig" + DataClass = "DataClass" + Auto = "Auto" diff --git a/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/dictconfig_transformer.py b/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/dictconfig_transformer.py new file mode 100644 index 0000000000..0f2b8c63cc --- /dev/null +++ b/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/dictconfig_transformer.py @@ -0,0 +1,181 @@ +import importlib +import re +from typing import Any, Dict, Type, TypeVar + +import flatten_dict +import flytekitplugins.omegaconf +from flyteidl.core.literals_pb2 import Literal as PB_Literal +from flytekitplugins.omegaconf.config import OmegaConfTransformerMode +from flytekitplugins.omegaconf.type_information import extract_node_type +from google.protobuf.json_format import MessageToDict, ParseDict +from google.protobuf.struct_pb2 import Struct + +import omegaconf +from flytekit import FlyteContext +from flytekit.core.type_engine import TypeTransformerFailedError +from flytekit.extend import TypeEngine, TypeTransformer +from flytekit.loggers import logger +from flytekit.models.literals import Literal, Scalar +from flytekit.models.types import LiteralType, SimpleType +from omegaconf import DictConfig, OmegaConf + +T = TypeVar("T") +NoneType = type(None) + + +class DictConfigTransformer(TypeTransformer[DictConfig]): + def __init__(self): + """Construct DictConfigTransformer.""" + super().__init__(name="OmegaConf DictConfig", t=DictConfig) + + def get_literal_type(self, t: Type[DictConfig]) -> LiteralType: + """ + Provide type hint for Flytekit type system. + + To support the multivariate typing of nodes in a DictConfig, we encode them as binaries (no introspection) + with multiple files. + """ + return LiteralType(simple=SimpleType.STRUCT) + + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + """Convert from given python type object ``DictConfig`` to the Literal representation.""" + check_if_valid_dictconfig(python_val) + + base_config = OmegaConf.get_type(python_val) + type_map, value_map = extract_type_and_value_maps(ctx, python_val) + wrapper = create_struct(type_map, value_map, base_config) + + return Literal(scalar=Scalar(generic=wrapper)) + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[DictConfig]) -> DictConfig: + """Re-hydrate the custom object from Flyte Literal value.""" + if lv and lv.scalar is not None: + nested_dict = flatten_dict.unflatten(MessageToDict(lv.scalar.generic), splitter="dot") + cfg_dict = {} + for key, type_desc in nested_dict["types"].items(): + cfg_dict[key] = parse_node_value(ctx, key, type_desc, nested_dict) + + return handle_base_dataclass(ctx, nested_dict, cfg_dict) + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + +def is_flattenable(d: DictConfig) -> bool: + """Check if a DictConfig can be properly flattened and unflattened, i.e. keys do not contain dots.""" + return all( + isinstance(k, str) # keys are strings ... + and "." not in k # ... and do not contain dots + and ( + OmegaConf.is_missing(d, k) # values are either MISSING ... + or not isinstance(d[k], DictConfig) # ... not nested Dictionaries ... + or is_flattenable(d[k]) + ) # or flattenable themselves + for k in d.keys() + ) + + +def check_if_valid_dictconfig(python_val: DictConfig) -> None: + """Validate the DictConfig to ensure it's serializable.""" + if not isinstance(python_val, DictConfig): + raise ValueError(f"Invalid type {type(python_val)}, can only serialize DictConfigs") + if not is_flattenable(python_val): + raise ValueError(f"{python_val} cannot be flattened as it contains non-string keys or keys containing dots.") + + +def extract_type_and_value_maps(ctx: FlyteContext, python_val: DictConfig) -> (Dict[str, str], Dict[str, Any]): + """Extract type and value maps from the DictConfig.""" + type_map = {} + value_map = {} + for key in python_val.keys(): + if OmegaConf.is_missing(python_val, key): + type_map[key] = "MISSING" + else: + node_type, type_name = extract_node_type(python_val, key) + type_map[key] = type_name + + transformer = TypeEngine.get_transformer(node_type) + literal_type = transformer.get_literal_type(node_type) + + value_map[key] = MessageToDict( + transformer.to_literal(ctx, python_val[key], node_type, literal_type).to_flyte_idl() + ) + return type_map, value_map + + +def create_struct(type_map: Dict[str, str], value_map: Dict[str, Any], base_config: Type) -> Struct: + """Create a protobuf Struct object from type and value maps.""" + wrapper = Struct() + wrapper.update( + flatten_dict.flatten( + { + "types": type_map, + "values": value_map, + "base_dataclass": f"{base_config.__module__}.{base_config.__name__}", + }, + reducer="dot", + keep_empty_types=(dict,), + ) + ) + return wrapper + + +def parse_type_description(type_desc: str) -> Type: + """Parse the type description and return the corresponding type.""" + generic_pattern = re.compile(r"(?P[^\[\]]+)\[(?P[^\[\]]+)\]") + match = generic_pattern.match(type_desc) + + if match: + origin_type = match.group("type") + args = match.group("args").split(", ") + + origin_module, origin_class = origin_type.rsplit(".", 1) + origin = importlib.import_module(origin_module).__getattribute__(origin_class) + + sub_types = [] + for arg in args: + if arg == "NoneType": + sub_types.append(type(None)) + else: + module_name, class_name = arg.rsplit(".", 1) + sub_type = importlib.import_module(module_name).__getattribute__(class_name) + sub_types.append(sub_type) + + if origin_class == "Optional": + return origin[sub_types[0]] + return origin[tuple(sub_types)] + else: + module_name, class_name = type_desc.rsplit(".", 1) + return importlib.import_module(module_name).__getattribute__(class_name) + + +def parse_node_value(ctx: FlyteContext, key: str, type_desc: str, nested_dict: Dict[str, Any]) -> Any: + """Parse the node value from the nested dictionary.""" + if type_desc == "MISSING": + return omegaconf.MISSING + + node_type = parse_type_description(type_desc) + transformer = TypeEngine.get_transformer(node_type) + value_literal = Literal.from_flyte_idl(ParseDict(nested_dict["values"][key], PB_Literal())) + return transformer.to_python_value(ctx, value_literal, node_type) + + +def handle_base_dataclass(ctx: FlyteContext, nested_dict: Dict[str, Any], cfg_dict: Dict[str, Any]) -> DictConfig: + """Handle the base dataclass and create the DictConfig.""" + if ( + nested_dict["base_dataclass"] != "builtins.dict" + and flytekitplugins.omegaconf.get_transformer_mode() != OmegaConfTransformerMode.DictConfig + ): + # Explicitly instantiate dataclass and create DictConfig from there in order to have typing information + module_name, class_name = nested_dict["base_dataclass"].rsplit(".", 1) + try: + return OmegaConf.structured(importlib.import_module(module_name).__getattribute__(class_name)(**cfg_dict)) + except (ModuleNotFoundError, AttributeError) as e: + logger.error( + f"Could not import module {module_name}. If you want to deserialise to DictConfig, " + f"set the mode to DictConfigTransformerMode.DictConfig." + ) + if flytekitplugins.omegaconf.get_transformer_mode() == OmegaConfTransformerMode.DataClass: + raise e + return OmegaConf.create(cfg_dict) + + +TypeEngine.register(DictConfigTransformer()) diff --git a/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/listconfig_transformer.py b/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/listconfig_transformer.py new file mode 100644 index 0000000000..8652facbad --- /dev/null +++ b/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/listconfig_transformer.py @@ -0,0 +1,92 @@ +import importlib +from typing import Type, TypeVar + +from flyteidl.core.literals_pb2 import Literal as PB_Literal +from flytekitplugins.omegaconf.type_information import extract_node_type +from google.protobuf.json_format import MessageToDict, ParseDict +from google.protobuf.struct_pb2 import Struct + +import omegaconf +from flytekit import FlyteContext +from flytekit.core.type_engine import TypeTransformerFailedError +from flytekit.extend import TypeEngine, TypeTransformer +from flytekit.models.literals import Literal, Primitive, Scalar +from flytekit.models.types import LiteralType, SimpleType +from omegaconf import ListConfig, OmegaConf + +T = TypeVar("T") + + +class ListConfigTransformer(TypeTransformer[ListConfig]): + def __init__(self): + """Construct ListConfigTransformer.""" + super().__init__(name="OmegaConf ListConfig", t=ListConfig) + + def get_literal_type(self, t: Type[ListConfig]) -> LiteralType: + """ + Provide type hint for Flytekit type system. + + To support the multivariate typing of nodes in a ListConfig, we encode them as binaries (no introspection) + with multiple files. + """ + return LiteralType(simple=SimpleType.STRUCT) + + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + """ + Convert from given python type object ``ListConfig`` to the Literal representation. + + Since the ListConfig type does not offer additional type hints for its nodes, typing information is stored + within the literal itself rather than the Flyte LiteralType. + """ + # instead of raising TypeError, raising AssertError so that flytekit can catch it in + # https://github.com/flyteorg/flytekit/blob/60c982e4b065fdb3aba0b957e506f652a2674c00/flytekit/core/ + # type_engine.py#L1222 + assert isinstance(python_val, ListConfig), f"Invalid type {type(python_val)}, can only serialise ListConfigs" + + type_list = [] + value_list = [] + for idx in range(len(python_val)): + if OmegaConf.is_missing(python_val, idx): + type_list.append("MISSING") + value_list.append( + MessageToDict(Literal(scalar=Scalar(primitive=Primitive(string_value="MISSING"))).to_flyte_idl()) + ) + else: + node_type, type_name = extract_node_type(python_val, idx) + type_list.append(type_name) + + transformer = TypeEngine.get_transformer(node_type) + literal_type = transformer.get_literal_type(node_type) + value_list.append( + MessageToDict(transformer.to_literal(ctx, python_val[idx], node_type, literal_type).to_flyte_idl()) + ) + + wrapper = Struct() + wrapper.update({"types": type_list, "values": value_list}) + return Literal(scalar=Scalar(generic=wrapper)) + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ListConfig]) -> ListConfig: + """Re-hydrate the custom object from Flyte Literal value.""" + if lv and lv.scalar is not None: + MessageToDict(lv.scalar.generic) + + type_list = MessageToDict(lv.scalar.generic)["types"] + value_list = MessageToDict(lv.scalar.generic)["values"] + cfg_literal = [] + for i, type_name in enumerate(type_list): + if type_name == "MISSING": + cfg_literal.append(omegaconf.MISSING) + else: + module_name, class_name = type_name.rsplit(".", 1) + node_type = importlib.import_module(module_name).__getattribute__(class_name) + + value_literal = Literal.from_flyte_idl(ParseDict(value_list[i], PB_Literal())) + + transformer = TypeEngine.get_transformer(node_type) + cfg_literal.append(transformer.to_python_value(ctx, value_literal, node_type)) + + return OmegaConf.create(cfg_literal) + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + +TypeEngine.register(ListConfigTransformer()) diff --git a/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/type_information.py b/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/type_information.py new file mode 100644 index 0000000000..b6a7b247e6 --- /dev/null +++ b/plugins/flytekit-omegaconf/flytekitplugins/omegaconf/type_information.py @@ -0,0 +1,114 @@ +import dataclasses +import typing +from collections import ChainMap + +from dataclasses_json import DataClassJsonMixin + +from flytekit.loggers import logger +from omegaconf import DictConfig, ListConfig, OmegaConf + +NoneType = type(None) + + +def substitute_types(t: typing.Type) -> typing.Type: + """ + Provides a substitute type hint to use when selecting transformers for serialisation. + + :param t: Original type + :return: A corrected typehint + """ + if hasattr(t, "__origin__"): + # Only encode generic type and let appropriate transformer handle the rest + if t.__origin__ in [dict, typing.Dict]: + t = DictConfig + elif t.__origin__ in [list, typing.List]: + t = ListConfig + else: + return t.__origin__ + return t + + +def all_annotations(cls: typing.Type) -> ChainMap: + """ + Returns a dictionary-like ChainMap that includes annotations for all + attributes defined in cls or inherited from superclasses. + """ + return ChainMap(*(c.__annotations__ for c in cls.__mro__ if "__annotations__" in c.__dict__)) + + +def extract_node_type( + python_val: typing.Union[DictConfig, ListConfig], key: typing.Union[str, int] +) -> typing.Tuple[type, str]: + """ + Provides typing information about DictConfig nodes + + :param python_val: A DictConfig + :param key: Key of the node to analyze + :return: + - Type - The extracted type + - str - String representation for (de-)serialisation + """ + assert isinstance(python_val, DictConfig) or isinstance( + python_val, ListConfig + ), "Can only extract type information from omegaconf objects" + + python_val_node_type = OmegaConf.get_type(python_val) + python_val_annotations = all_annotations(python_val_node_type) + + # Check if type annotations are available + if hasattr(python_val_node_type, "__annotations__"): + if key not in python_val_annotations: + raise ValueError( + f"Key '{key}' not found in type annotations {python_val_annotations}. " + "Check your DictConfig object for invalid subtrees not covered by your structured config." + ) + + if typing.get_origin(python_val_annotations[key]) is not None: + # Abstract types + origin = typing.get_origin(python_val_annotations[key]) + if getattr(origin, "__name__", None) is not None: + origin_name = f"{origin.__module__}.{origin.__name__}" + elif getattr(origin, "_name", None) is not None: + origin_name = f"{origin.__module__}.{origin._name}" + else: + raise ValueError(f"Could not extract name from origin type {origin}") + + # Replace list and dict with omegaconf types + if origin_name in ["builtins.list", "typing.List"]: + return ListConfig, "omegaconf.listconfig.ListConfig" + elif origin_name in ["builtins.dict", "typing.Dict"]: + return DictConfig, "omegaconf.dictconfig.DictConfig" + + sub_types = [] + sub_type_names = [] + for sub_type in typing.get_args(python_val_annotations[key]): + if sub_type == NoneType: # NoneType gets special treatment as no import exists + sub_types.append(NoneType) + sub_type_names.append("NoneType") + elif dataclasses.is_dataclass(sub_type) and not issubclass(sub_type, DataClassJsonMixin): + # Dataclasses have no matching transformers and get replaced by DictConfig + # alternatively, dataclasses can use dataclass_json decorator + sub_types.append(DictConfig) + sub_type_names.append("omegaconf.dictconfig.DictConfig") + else: + sub_type = substitute_types(sub_type) + sub_types.append(sub_type) + sub_type_names.append(f"{sub_type.__module__}.{sub_type.__name__}") + return origin[tuple(sub_types)], f"{origin_name}[{', '.join(sub_type_names)}]" + elif dataclasses.is_dataclass(python_val_annotations[key]): + # Dataclasses have no matching transformers and get replaced by DictConfig + # alternatively, dataclasses can use dataclass_json decorator + return DictConfig, "omegaconf.dictconfig.DictConfig" + elif python_val_annotations[key] != typing.Any: + # Use (cleaned) annotation if it is meaningful + node_type = substitute_types(python_val_annotations[key]) + type_name = f"{node_type.__module__}.{node_type.__name__}" + return node_type, type_name + + logger.debug( + f"Inferring type information directly from runtime object {python_val[key]} for serialisation purposes. " + "For more stable type resolution and serialisation provide explicit type hints." + ) + node_type = type(python_val[key]) + type_name = f"{node_type.__module__}.{node_type.__name__}" + return node_type, type_name diff --git a/plugins/flytekit-omegaconf/setup.py b/plugins/flytekit-omegaconf/setup.py new file mode 100644 index 0000000000..3f57594a15 --- /dev/null +++ b/plugins/flytekit-omegaconf/setup.py @@ -0,0 +1,41 @@ +from setuptools import setup + +PLUGIN_NAME = "omegaconf" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.10.0,<2.0.0", "flatten-dict", "omegaconf>=2.3.0"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="OmegaConf plugin for Flytekit", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-omegaconf", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-omegaconf/tests/__init__.py b/plugins/flytekit-omegaconf/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-omegaconf/tests/conftest.py b/plugins/flytekit-omegaconf/tests/conftest.py new file mode 100644 index 0000000000..a3c260e4a1 --- /dev/null +++ b/plugins/flytekit-omegaconf/tests/conftest.py @@ -0,0 +1,24 @@ +import typing as t +from dataclasses import dataclass, field + + +@dataclass +class ExampleNestedConfig: + nested_int_key: int = 2 + + +@dataclass +class ExampleConfig: + int_key: int = 1337 + union_key: t.Union[int, str] = 1337 + any_key: t.Any = "1337" + optional_key: t.Optional[int] = 1337 + dictconfig_key: ExampleNestedConfig = field(default_factory=ExampleNestedConfig) + optional_dictconfig_key: t.Optional[ExampleNestedConfig] = None + listconfig_key: t.List[int] = field(default_factory=lambda: (1, 2, 3)) + + +@dataclass +class ExampleConfigWithNonAnnotatedSubtree: + unnanotated_key = 1 + annotated_key: ExampleNestedConfig = field(default_factory=ExampleNestedConfig) diff --git a/plugins/flytekit-omegaconf/tests/test_dictconfig_transformer.py b/plugins/flytekit-omegaconf/tests/test_dictconfig_transformer.py new file mode 100644 index 0000000000..b4d9115fa9 --- /dev/null +++ b/plugins/flytekit-omegaconf/tests/test_dictconfig_transformer.py @@ -0,0 +1,103 @@ +import typing as t + +import pytest +from flytekitplugins.omegaconf.dictconfig_transformer import ( + check_if_valid_dictconfig, + extract_type_and_value_maps, + is_flattenable, + parse_type_description, +) +from omegaconf import DictConfig, OmegaConf + +from flytekit import FlyteContext + + +@pytest.mark.parametrize( + "config, should_raise, match", + [ + (OmegaConf.create({"key1": "value1", "key2": 123, "key3": True}), False, None), + ({"key1": "value1"}, True, "Invalid type , can only serialize DictConfigs"), + ( + OmegaConf.create({"key1.with.dot": "value1", "key2": 123}), + True, + "cannot be flattened as it contains non-string keys or keys containing dots", + ), + ( + OmegaConf.create({1: "value1", "key2": 123}), + True, + "cannot be flattened as it contains non-string keys or keys containing dots", + ), + ], +) +def test_check_if_valid_dictconfig(config, should_raise, match) -> None: + """Test check_if_valid_dictconfig with various configurations.""" + if should_raise: + with pytest.raises(ValueError, match=match): + check_if_valid_dictconfig(config) + else: + check_if_valid_dictconfig(config) + + +@pytest.mark.parametrize( + "config, should_flatten", + [ + (OmegaConf.create({"key1": "value1", "key2": 123, "key3": True}), True), + (OmegaConf.create({"key1": {"nested_key1": "nested_value1", "nested_key2": 456}, "key2": "value2"}), True), + (OmegaConf.create({"key1.with.dot": "value1", "key2": 123}), False), + (OmegaConf.create({1: "value1", "key2": 123}), False), + ( + OmegaConf.create( + { + "key1": "value1", + "key2": "${oc.env:VAR}", + "key3": OmegaConf.create({"nested_key1": "nested_value1", "nested_key2": "${oc.env:VAR}"}), + } + ), + True, + ), + (OmegaConf.create({"key1": {"nested.key1": "value1"}}), False), + ( + OmegaConf.create( + { + "key1": "value1", + "key2": {"nested_key1": "nested_value1", "nested.key2": "value2"}, + "key3": OmegaConf.create({"nested_key3": "nested_value3"}), + } + ), + False, + ), + ], +) +def test_is_flattenable(config: DictConfig, should_flatten: bool, monkeypatch: pytest.MonkeyPatch) -> None: + """Test flattenable and non-flattenable DictConfigs.""" + monkeypatch.setenv("VAR", "some_value") + assert is_flattenable(config) == should_flatten + + +def test_extract_type_and_value_maps_simple() -> None: + """Test extraction of type and value maps from a simple DictConfig.""" + ctx = FlyteContext.current_context() + config: DictConfig = OmegaConf.create({"key1": "value1", "key2": 123, "key3": True}) + + type_map, value_map = extract_type_and_value_maps(ctx, config) + + expected_type_map = {"key1": "builtins.str", "key2": "builtins.int", "key3": "builtins.bool"} + + assert type_map == expected_type_map + assert "key1" in value_map + assert "key2" in value_map + assert "key3" in value_map + + +@pytest.mark.parametrize( + "type_desc, expected_type", + [ + ("builtins.int", int), + ("typing.List[builtins.int]", t.List[int]), + ("typing.Optional[builtins.int]", t.Optional[int]), + ], +) +def test_parse_type_description(type_desc: str, expected_type: t.Type) -> None: + """Test parsing various type descriptions.""" + parsed_type = parse_type_description(type_desc) + assert parsed_type == expected_type diff --git a/plugins/flytekit-omegaconf/tests/test_extract_node_type.py b/plugins/flytekit-omegaconf/tests/test_extract_node_type.py new file mode 100644 index 0000000000..fbd4628961 --- /dev/null +++ b/plugins/flytekit-omegaconf/tests/test_extract_node_type.py @@ -0,0 +1,71 @@ +import typing as t + +import pytest +from flytekitplugins.omegaconf.type_information import extract_node_type +from omegaconf import DictConfig, ListConfig, OmegaConf + +from tests.conftest import ExampleConfig, ExampleConfigWithNonAnnotatedSubtree + + +class TestExtractNodeType: + def test_extract_type_and_string_representation(self) -> None: + """Tests type extraction and string representation.""" + + python_val = OmegaConf.structured(ExampleConfig(union_key="1337", optional_key=None)) + + # test int + node_type, type_name = extract_node_type(python_val, key="int_key") + assert node_type == int + assert type_name == "builtins.int" + + # test union + node_type, type_name = extract_node_type(python_val, key="union_key") + assert node_type == t.Union[int, str] + assert type_name == "typing.Union[builtins.int, builtins.str]" + + # test any + node_type, type_name = extract_node_type(python_val, key="any_key") + assert node_type == str + assert type_name == "builtins.str" + + # test optional + node_type, type_name = extract_node_type(python_val, key="optional_key") + assert node_type == t.Optional[int] + assert type_name == "typing.Union[builtins.int, NoneType]" + + # test dictconfig + node_type, type_name = extract_node_type(python_val, key="dictconfig_key") + assert node_type == DictConfig + assert type_name == "omegaconf.dictconfig.DictConfig" + + # test listconfig + node_type, type_name = extract_node_type(python_val, key="listconfig_key") + assert node_type == ListConfig + assert type_name == "omegaconf.listconfig.ListConfig" + + # test optional dictconfig + node_type, type_name = extract_node_type(python_val, key="optional_dictconfig_key") + assert node_type == t.Optional[DictConfig] + assert type_name == "typing.Union[omegaconf.dictconfig.DictConfig, NoneType]" + + def test_raises_nonannotated_subtree(self) -> None: + """Test that trying to infer type of a non-annotated subtree raises an error.""" + + python_val = OmegaConf.structured(ExampleConfigWithNonAnnotatedSubtree()) + node_type, type_name = extract_node_type(python_val, key="annotated_key") + assert node_type == DictConfig + + # When we try to infer unnanotated subtree combined with typed subtree, we should raise + with pytest.raises(ValueError): + extract_node_type(python_val, "unnanotated_key") + + def test_single_unnanotated_node(self) -> None: + """Test that inferring a fully unnanotated node works by inferring types from runtime values.""" + + python_val = OmegaConf.create({"unannotated_dictconfig_key": {"unnanotated_int_key": 2}}) + node_type, type_name = extract_node_type(python_val, key="unannotated_dictconfig_key") + assert node_type == DictConfig + + python_val = python_val.unannotated_dictconfig_key + node_type, type_name = extract_node_type(python_val, key="unnanotated_int_key") + assert node_type == int diff --git a/plugins/flytekit-omegaconf/tests/test_objects.py b/plugins/flytekit-omegaconf/tests/test_objects.py new file mode 100644 index 0000000000..912f0bffb3 --- /dev/null +++ b/plugins/flytekit-omegaconf/tests/test_objects.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional, Union + +from omegaconf import MISSING, OmegaConf + + +class MultiTypeEnum(str, Enum): + fifo = "fifo" # first in first out + filo = "filo" # first in last out + + +@dataclass +class MySubConf: + my_attr: Optional[Union[int, str]] = 1 + list_attr: List[int] = field(default_factory=list) + + +@dataclass +class MyConf: + my_attr: Optional[MySubConf] = None + + +class SpecialConf(MyConf): + key: int = 1 + + +TEST_CFG = OmegaConf.create( + { + "a": 1, + "b": 1.0, + "c": { + "d": 1, + "e": MISSING, + "f": [ + { + "g": 2, + "h": 1.2, + }, + {"j": 0.5, "k": "foo", "l": "bar"}, + ], + }, + } +) diff --git a/plugins/flytekit-omegaconf/tests/test_plugin.py b/plugins/flytekit-omegaconf/tests/test_plugin.py new file mode 100644 index 0000000000..e42f5ab73d --- /dev/null +++ b/plugins/flytekit-omegaconf/tests/test_plugin.py @@ -0,0 +1,193 @@ +from typing import Any + +import flytekitplugins.omegaconf +import pytest +from flyteidl.core.literals_pb2 import Literal, Scalar +from flytekitplugins.omegaconf.config import OmegaConfTransformerMode +from flytekitplugins.omegaconf.dictconfig_transformer import DictConfigTransformer +from google.protobuf.struct_pb2 import Struct +from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf, ValidationError +from pytest import mark, param + +from flytekit import FlyteContext +from flytekit.core.type_engine import TypeEngine +from tests.conftest import ExampleConfig, ExampleNestedConfig +from tests.test_objects import TEST_CFG, MultiTypeEnum, MyConf, MySubConf, SpecialConf + + +@mark.parametrize( + ("obj"), + [ + param( + DictConfig({}), + ), + param( + DictConfig({"a": "b"}), + ), + param( + DictConfig({"a": 1}), + ), + param( + DictConfig({"a": MISSING}), + ), + param( + DictConfig({"tuple": (1, 2, 3)}), + ), + param( + ListConfig(["a", "b"]), + ), + param( + ListConfig(["a", MISSING]), + ), + param( + TEST_CFG, + ), + param( + OmegaConf.create(ExampleNestedConfig()), + ), + param( + OmegaConf.create(ExampleConfig()), + ), + param( + DictConfig({"foo": MultiTypeEnum.fifo}), + ), + param( + DictConfig({"foo": [MultiTypeEnum.fifo]}), + ), + param(DictConfig({"cfgs": [MySubConf(1), MySubConf("a"), "arg"]})), + param(OmegaConf.structured(SpecialConf)), + ], +) +def test_cfg_roundtrip(obj: Any) -> None: + """Test casting DictConfig object to flyte literal and back.""" + + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(type(obj)) + transformer = TypeEngine.get_transformer(type(obj)) + + assert isinstance( + transformer, flytekitplugins.omegaconf.dictconfig_transformer.DictConfigTransformer + ) or isinstance(transformer, flytekitplugins.omegaconf.listconfig_transformer.ListConfigTransformer) + + literal = transformer.to_literal(ctx, obj, type(obj), expected) + reconstructed = transformer.to_python_value(ctx, literal, type(obj)) + assert obj == reconstructed + + +def test_optional_type() -> None: + """ + Test serialisation of DictConfigs with various optional entries, whose real types are provided by underlying + dataclasses. + """ + optional_obj: DictConfig = OmegaConf.structured(MySubConf()) + optional_obj1: DictConfig = OmegaConf.structured(MyConf(my_attr=MySubConf())) + optional_obj2: DictConfig = OmegaConf.structured(MyConf()) + + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(DictConfig) + transformer = TypeEngine.get_transformer(DictConfig) + + literal = transformer.to_literal(ctx, optional_obj, DictConfig, expected) + recon = transformer.to_python_value(ctx, literal, DictConfig) + assert recon == optional_obj + + literal1 = transformer.to_literal(ctx, optional_obj1, DictConfig, expected) + recon1 = transformer.to_python_value(ctx, literal1, DictConfig) + assert recon1 == optional_obj1 + + literal2 = transformer.to_literal(ctx, optional_obj2, DictConfig, expected) + recon2 = transformer.to_python_value(ctx, literal2, DictConfig) + assert recon2 == optional_obj2 + + +def test_plugin_mode() -> None: + """Test serialisation with different plugin modes configured.""" + obj = OmegaConf.structured(MyConf(my_attr=MySubConf())) + + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(DictConfig) + + with flytekitplugins.omegaconf.local_transformer_mode(OmegaConfTransformerMode.DictConfig): + transformer = DictConfigTransformer() + literal_slim = transformer.to_literal(ctx, obj, DictConfig, expected) + reconstructed_slim = transformer.to_python_value(ctx, literal_slim, DictConfig) + + with flytekitplugins.omegaconf.local_transformer_mode(OmegaConfTransformerMode.DataClass): + literal_full = transformer.to_literal(ctx, obj, DictConfig, expected) + reconstructed_full = transformer.to_python_value(ctx, literal_full, DictConfig) + + with flytekitplugins.omegaconf.local_transformer_mode(OmegaConfTransformerMode.Auto): + literal_semi = transformer.to_literal(ctx, obj, DictConfig, expected) + reconstructed_semi = transformer.to_python_value(ctx, literal_semi, DictConfig) + + assert literal_slim == literal_full == literal_semi + assert reconstructed_slim == reconstructed_full == reconstructed_semi # comparison by value should pass + + assert OmegaConf.get_type(reconstructed_slim, "my_attr") == dict + assert OmegaConf.get_type(reconstructed_semi, "my_attr") == MySubConf + assert OmegaConf.get_type(reconstructed_full, "my_attr") == MySubConf + + reconstructed_slim.my_attr.my_attr = (1,) # assign a tuple value to Union[int, str] field + with pytest.raises(ValidationError): + reconstructed_semi.my_attr.my_attr = (1,) + with pytest.raises(ValidationError): + reconstructed_full.my_attr.my_attr = (1,) + + +def test_auto_transformer_mode() -> None: + """Test if auto transformer mode recovers basic information if the specified type cannot be found.""" + obj = OmegaConf.structured(MyConf(my_attr=MySubConf())) + + struct = Struct() + struct.update( + { + "values.my_attr.scalar.union.value.scalar.generic.values.my_attr.scalar.union.value.scalar.primitive.integer": 1, # noqa: E501 + "values.my_attr.scalar.union.value.scalar.generic.values.my_attr.scalar.union.type.structure.tag": "int", + "values.my_attr.scalar.union.value.scalar.generic.values.my_attr.scalar.union.type.simple": "INTEGER", + "values.my_attr.scalar.union.value.scalar.generic.values.list_attr.scalar.generic.values": [], + "values.my_attr.scalar.union.value.scalar.generic.values.list_attr.scalar.generic.types": [], + "values.my_attr.scalar.union.value.scalar.generic.types.my_attr": "typing.Union[builtins.int, builtins.str, NoneType]", # noqa: E501 + "values.my_attr.scalar.union.value.scalar.generic.types.list_attr": "omegaconf.listconfig.ListConfig", + "values.my_attr.scalar.union.value.scalar.generic.base_dataclass": "tests.test_objects.MySubConf", + "values.my_attr.scalar.union.type.structure.tag": "OmegaConf DictConfig", + "values.my_attr.scalar.union.type.simple": "STRUCT", + "types.my_attr": "typing.Union[omegaconf.dictconfig.DictConfig, NoneType]", + "base_dataclass": "tests.test_objects.MyConf", + } + ) + literal = Literal(scalar=Scalar(generic=struct)) + + # construct a literal with an unknown subconfig type + struct2 = Struct() + struct2.update( + { + "values.my_attr.scalar.union.value.scalar.generic.values.my_attr.scalar.union.value.scalar.primitive.integer": 1, # noqa: E501 + "values.my_attr.scalar.union.value.scalar.generic.values.my_attr.scalar.union.type.structure.tag": "int", + "values.my_attr.scalar.union.value.scalar.generic.values.my_attr.scalar.union.type.simple": "INTEGER", + "values.my_attr.scalar.union.value.scalar.generic.values.list_attr.scalar.generic.values": [], + "values.my_attr.scalar.union.value.scalar.generic.values.list_attr.scalar.generic.types": [], + "values.my_attr.scalar.union.value.scalar.generic.types.my_attr": "typing.Union[builtins.int, builtins.str, NoneType]", # noqa: E501 + "values.my_attr.scalar.union.value.scalar.generic.types.list_attr": "omegaconf.listconfig.ListConfig", + "values.my_attr.scalar.union.value.scalar.generic.base_dataclass": "tests.test_objects.MyFooConf", + "values.my_attr.scalar.union.type.structure.tag": "OmegaConf DictConfig", + "values.my_attr.scalar.union.type.simple": "STRUCT", + "types.my_attr": "typing.Union[omegaconf.dictconfig.DictConfig, NoneType]", + "base_dataclass": "tests.test_objects.MyConf", + } + ) + literal2 = Literal(scalar=Scalar(generic=struct2)) + + ctx = FlyteContext.current_context() + flytekitplugins.omegaconf.set_transformer_mode(OmegaConfTransformerMode.Auto) + transformer = DictConfigTransformer() + + reconstructed = transformer.to_python_value(ctx, literal, DictConfig) + assert obj == reconstructed + + part_reconstructed = transformer.to_python_value(ctx, literal2, DictConfig) + assert obj == part_reconstructed + assert OmegaConf.get_type(part_reconstructed, "my_attr") == dict + + part_reconstructed.my_attr.my_attr = (1,) # assign a tuple value to Union[int, str] field + with pytest.raises(ValidationError): + reconstructed.my_attr.my_attr = (1,)