-
Notifications
You must be signed in to change notification settings - Fork 301
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add flytekit-omegaconf plugin (#2299)
* add flytekit-hydra Signed-off-by: mg515 <[email protected]> * fix small typo readme Signed-off-by: mg515 <[email protected]> * ruff ruff Signed-off-by: mg515 <[email protected]> * lint more Signed-off-by: mg515 <[email protected]> * rename plugin into flytekit-omegaconf Signed-off-by: mg515 <[email protected]> * lint sort imports Signed-off-by: mg515 <[email protected]> * use flytekit logger Signed-off-by: mg515 <[email protected]> * use flytekit logger #2 Signed-off-by: mg515 <[email protected]> * fix typing info in is_flatable Signed-off-by: mg515 <[email protected]> * use default_factory instead of mutable default value Signed-off-by: mg515 <[email protected]> * add python3.11 and python3.12 to setup.py Signed-off-by: mg515 <[email protected]> * make fmt Signed-off-by: mg515 <[email protected]> * define error message only once Signed-off-by: mg515 <[email protected]> * add docstring Signed-off-by: mg515 <[email protected]> * remove GenericEnumTransformer and tests Signed-off-by: mg515 <[email protected]> * fallback to TypeEngine.get_transformer(node_type) to find suitable transformer Signed-off-by: mg515 <[email protected]> * explicit valueerrors instead of asserts Signed-off-by: mg515 <[email protected]> * minor style improvements Signed-off-by: mg515 <[email protected]> * remove obsolete warnings Signed-off-by: mg515 <[email protected]> * import flytekit logger instead of instantiating our own Signed-off-by: mg515 <[email protected]> * docstrings in reST format Signed-off-by: mg515 <[email protected]> * refactor transformer mode Signed-off-by: mg515 <[email protected]> * improve docs Signed-off-by: mg515 <[email protected]> * refactor dictconfig class into smaller methods Signed-off-by: mg515 <[email protected]> * add unit tests for dictconfig transformer Signed-off-by: mg515 <[email protected]> * refactor of parse_type_description() Signed-off-by: mg515 <[email protected]> * add omegaconf plugin to pythonbuild.yaml --------- Signed-off-by: mg515 <[email protected]> Signed-off-by: Eduardo Apolinario <[email protected]> Co-authored-by: Eduardo Apolinario <[email protected]>
- Loading branch information
1 parent
3549597
commit df94e1c
Showing
14 changed files
with
981 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Flytekit OmegaConf Plugin | ||
|
||
Flytekit python natively supports serialization of many data types for exchanging information between tasks. | ||
The Flytekit OmegaConf Plugin extends these by the `DictConfig` type from the | ||
[OmegaConf package](https://omegaconf.readthedocs.io/) as well as related types | ||
that are being used by the [hydra package](https://hydra.cc/) for configuration management. | ||
|
||
## Task example | ||
``` | ||
from dataclasses import dataclass | ||
import flytekitplugins.omegaconf # noqa F401 | ||
from flytekit import task, workflow | ||
from omegaconf import DictConfig | ||
@dataclass | ||
class MySimpleConf: | ||
_target_: str = "lightning_module.MyEncoderModule" | ||
learning_rate: float = 0.0001 | ||
@task | ||
def my_task(cfg: DictConfig) -> None: | ||
print(f"Doing things with {cfg.learning_rate=}") | ||
@workflow | ||
def pipeline(cfg: DictConfig) -> None: | ||
my_task(cfg=cfg) | ||
if __name__ == "__main__": | ||
from omegaconf import OmegaConf | ||
cfg = OmegaConf.structured(MySimpleConf) | ||
pipeline(cfg=cfg) | ||
``` | ||
|
||
## Transformer configuration | ||
|
||
The transformer can be set to one of three modes: | ||
|
||
`Dataclass` - This mode should be used with a StructuredConfig and will reconstruct the config from the matching dataclass | ||
during deserialisation in order to make typing information from the dataclass and continued validation thereof available. | ||
This requires the dataclass definition to be available via python import in the Flyte execution environment in which | ||
objects are (de-)serialised. | ||
|
||
`DictConfig` - This mode will deserialize the config into a DictConfig object. In particular, dataclasses are translated | ||
into DictConfig objects and only primitive types are being checked. The definition of underlying dataclasses for | ||
structured configs is only required during the initial serialization for this mode. | ||
|
||
`Auto` - This mode will try to deserialize according to the Dataclass mode and fall back to the DictConfig mode if the | ||
dataclass definition is not available. This is the default mode. | ||
|
||
You can set the transformer mode globally or for the current context only the following ways: | ||
```python | ||
from flytekitplugins.omegaconf import set_transformer_mode, set_local_transformer_mode, OmegaConfTransformerMode | ||
|
||
# Set the global transformer mode using the new function | ||
set_transformer_mode(OmegaConfTransformerMode.DictConfig) | ||
|
||
# You can also the mode for the current context only | ||
with set_local_transformer_mode(OmegaConfTransformerMode.Dataclass): | ||
# This will use the Dataclass mode | ||
pass | ||
``` | ||
|
||
```note | ||
Since the DictConfig is flattened and keys transformed into dot notation, the keys of the DictConfig must not contain | ||
dots. | ||
``` |
33 changes: 33 additions & 0 deletions
33
plugins/flytekit-omegaconf/flytekitplugins/omegaconf/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from contextlib import contextmanager | ||
|
||
from flytekitplugins.omegaconf.config import OmegaConfTransformerMode | ||
from flytekitplugins.omegaconf.dictconfig_transformer import DictConfigTransformer # noqa: F401 | ||
from flytekitplugins.omegaconf.listconfig_transformer import ListConfigTransformer # noqa: F401 | ||
|
||
_TRANSFORMER_MODE = OmegaConfTransformerMode.Auto | ||
|
||
|
||
def set_transformer_mode(mode: OmegaConfTransformerMode) -> None: | ||
"""Set the global serialization mode for OmegaConf objects.""" | ||
global _TRANSFORMER_MODE | ||
_TRANSFORMER_MODE = mode | ||
|
||
|
||
def get_transformer_mode() -> OmegaConfTransformerMode: | ||
"""Get the global serialization mode for OmegaConf objects.""" | ||
return _TRANSFORMER_MODE | ||
|
||
|
||
@contextmanager | ||
def local_transformer_mode(mode: OmegaConfTransformerMode): | ||
"""Context manager to set a local serialization mode for OmegaConf objects.""" | ||
global _TRANSFORMER_MODE | ||
previous_mode = _TRANSFORMER_MODE | ||
set_transformer_mode(mode) | ||
try: | ||
yield | ||
finally: | ||
set_transformer_mode(previous_mode) | ||
|
||
|
||
__all__ = ["set_transformer_mode", "get_transformer_mode", "local_transformer_mode", "OmegaConfTransformerMode"] |
15 changes: 15 additions & 0 deletions
15
plugins/flytekit-omegaconf/flytekitplugins/omegaconf/config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from enum import Enum | ||
|
||
|
||
class OmegaConfTransformerMode(Enum): | ||
""" | ||
Operation Mode indicating whether a (potentially unannotated) DictConfig object or a structured config using the | ||
underlying dataclass is returned. | ||
Note: We define a single shared config across all transformers as recursive calls should refer to the same config | ||
Note: The latter requires the use of structured configs. | ||
""" | ||
|
||
DictConfig = "DictConfig" | ||
DataClass = "DataClass" | ||
Auto = "Auto" |
181 changes: 181 additions & 0 deletions
181
plugins/flytekit-omegaconf/flytekitplugins/omegaconf/dictconfig_transformer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
import importlib | ||
import re | ||
from typing import Any, Dict, Type, TypeVar | ||
|
||
import flatten_dict | ||
import flytekitplugins.omegaconf | ||
from flyteidl.core.literals_pb2 import Literal as PB_Literal | ||
from flytekitplugins.omegaconf.config import OmegaConfTransformerMode | ||
from flytekitplugins.omegaconf.type_information import extract_node_type | ||
from google.protobuf.json_format import MessageToDict, ParseDict | ||
from google.protobuf.struct_pb2 import Struct | ||
|
||
import omegaconf | ||
from flytekit import FlyteContext | ||
from flytekit.core.type_engine import TypeTransformerFailedError | ||
from flytekit.extend import TypeEngine, TypeTransformer | ||
from flytekit.loggers import logger | ||
from flytekit.models.literals import Literal, Scalar | ||
from flytekit.models.types import LiteralType, SimpleType | ||
from omegaconf import DictConfig, OmegaConf | ||
|
||
T = TypeVar("T") | ||
NoneType = type(None) | ||
|
||
|
||
class DictConfigTransformer(TypeTransformer[DictConfig]): | ||
def __init__(self): | ||
"""Construct DictConfigTransformer.""" | ||
super().__init__(name="OmegaConf DictConfig", t=DictConfig) | ||
|
||
def get_literal_type(self, t: Type[DictConfig]) -> LiteralType: | ||
""" | ||
Provide type hint for Flytekit type system. | ||
To support the multivariate typing of nodes in a DictConfig, we encode them as binaries (no introspection) | ||
with multiple files. | ||
""" | ||
return LiteralType(simple=SimpleType.STRUCT) | ||
|
||
def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: | ||
"""Convert from given python type object ``DictConfig`` to the Literal representation.""" | ||
check_if_valid_dictconfig(python_val) | ||
|
||
base_config = OmegaConf.get_type(python_val) | ||
type_map, value_map = extract_type_and_value_maps(ctx, python_val) | ||
wrapper = create_struct(type_map, value_map, base_config) | ||
|
||
return Literal(scalar=Scalar(generic=wrapper)) | ||
|
||
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[DictConfig]) -> DictConfig: | ||
"""Re-hydrate the custom object from Flyte Literal value.""" | ||
if lv and lv.scalar is not None: | ||
nested_dict = flatten_dict.unflatten(MessageToDict(lv.scalar.generic), splitter="dot") | ||
cfg_dict = {} | ||
for key, type_desc in nested_dict["types"].items(): | ||
cfg_dict[key] = parse_node_value(ctx, key, type_desc, nested_dict) | ||
|
||
return handle_base_dataclass(ctx, nested_dict, cfg_dict) | ||
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") | ||
|
||
|
||
def is_flattenable(d: DictConfig) -> bool: | ||
"""Check if a DictConfig can be properly flattened and unflattened, i.e. keys do not contain dots.""" | ||
return all( | ||
isinstance(k, str) # keys are strings ... | ||
and "." not in k # ... and do not contain dots | ||
and ( | ||
OmegaConf.is_missing(d, k) # values are either MISSING ... | ||
or not isinstance(d[k], DictConfig) # ... not nested Dictionaries ... | ||
or is_flattenable(d[k]) | ||
) # or flattenable themselves | ||
for k in d.keys() | ||
) | ||
|
||
|
||
def check_if_valid_dictconfig(python_val: DictConfig) -> None: | ||
"""Validate the DictConfig to ensure it's serializable.""" | ||
if not isinstance(python_val, DictConfig): | ||
raise ValueError(f"Invalid type {type(python_val)}, can only serialize DictConfigs") | ||
if not is_flattenable(python_val): | ||
raise ValueError(f"{python_val} cannot be flattened as it contains non-string keys or keys containing dots.") | ||
|
||
|
||
def extract_type_and_value_maps(ctx: FlyteContext, python_val: DictConfig) -> (Dict[str, str], Dict[str, Any]): | ||
"""Extract type and value maps from the DictConfig.""" | ||
type_map = {} | ||
value_map = {} | ||
for key in python_val.keys(): | ||
if OmegaConf.is_missing(python_val, key): | ||
type_map[key] = "MISSING" | ||
else: | ||
node_type, type_name = extract_node_type(python_val, key) | ||
type_map[key] = type_name | ||
|
||
transformer = TypeEngine.get_transformer(node_type) | ||
literal_type = transformer.get_literal_type(node_type) | ||
|
||
value_map[key] = MessageToDict( | ||
transformer.to_literal(ctx, python_val[key], node_type, literal_type).to_flyte_idl() | ||
) | ||
return type_map, value_map | ||
|
||
|
||
def create_struct(type_map: Dict[str, str], value_map: Dict[str, Any], base_config: Type) -> Struct: | ||
"""Create a protobuf Struct object from type and value maps.""" | ||
wrapper = Struct() | ||
wrapper.update( | ||
flatten_dict.flatten( | ||
{ | ||
"types": type_map, | ||
"values": value_map, | ||
"base_dataclass": f"{base_config.__module__}.{base_config.__name__}", | ||
}, | ||
reducer="dot", | ||
keep_empty_types=(dict,), | ||
) | ||
) | ||
return wrapper | ||
|
||
|
||
def parse_type_description(type_desc: str) -> Type: | ||
"""Parse the type description and return the corresponding type.""" | ||
generic_pattern = re.compile(r"(?P<type>[^\[\]]+)\[(?P<args>[^\[\]]+)\]") | ||
match = generic_pattern.match(type_desc) | ||
|
||
if match: | ||
origin_type = match.group("type") | ||
args = match.group("args").split(", ") | ||
|
||
origin_module, origin_class = origin_type.rsplit(".", 1) | ||
origin = importlib.import_module(origin_module).__getattribute__(origin_class) | ||
|
||
sub_types = [] | ||
for arg in args: | ||
if arg == "NoneType": | ||
sub_types.append(type(None)) | ||
else: | ||
module_name, class_name = arg.rsplit(".", 1) | ||
sub_type = importlib.import_module(module_name).__getattribute__(class_name) | ||
sub_types.append(sub_type) | ||
|
||
if origin_class == "Optional": | ||
return origin[sub_types[0]] | ||
return origin[tuple(sub_types)] | ||
else: | ||
module_name, class_name = type_desc.rsplit(".", 1) | ||
return importlib.import_module(module_name).__getattribute__(class_name) | ||
|
||
|
||
def parse_node_value(ctx: FlyteContext, key: str, type_desc: str, nested_dict: Dict[str, Any]) -> Any: | ||
"""Parse the node value from the nested dictionary.""" | ||
if type_desc == "MISSING": | ||
return omegaconf.MISSING | ||
|
||
node_type = parse_type_description(type_desc) | ||
transformer = TypeEngine.get_transformer(node_type) | ||
value_literal = Literal.from_flyte_idl(ParseDict(nested_dict["values"][key], PB_Literal())) | ||
return transformer.to_python_value(ctx, value_literal, node_type) | ||
|
||
|
||
def handle_base_dataclass(ctx: FlyteContext, nested_dict: Dict[str, Any], cfg_dict: Dict[str, Any]) -> DictConfig: | ||
"""Handle the base dataclass and create the DictConfig.""" | ||
if ( | ||
nested_dict["base_dataclass"] != "builtins.dict" | ||
and flytekitplugins.omegaconf.get_transformer_mode() != OmegaConfTransformerMode.DictConfig | ||
): | ||
# Explicitly instantiate dataclass and create DictConfig from there in order to have typing information | ||
module_name, class_name = nested_dict["base_dataclass"].rsplit(".", 1) | ||
try: | ||
return OmegaConf.structured(importlib.import_module(module_name).__getattribute__(class_name)(**cfg_dict)) | ||
except (ModuleNotFoundError, AttributeError) as e: | ||
logger.error( | ||
f"Could not import module {module_name}. If you want to deserialise to DictConfig, " | ||
f"set the mode to DictConfigTransformerMode.DictConfig." | ||
) | ||
if flytekitplugins.omegaconf.get_transformer_mode() == OmegaConfTransformerMode.DataClass: | ||
raise e | ||
return OmegaConf.create(cfg_dict) | ||
|
||
|
||
TypeEngine.register(DictConfigTransformer()) |
Oops, something went wrong.