-
Notifications
You must be signed in to change notification settings - Fork 301
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,098 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Flytekit Hydra Plugin | ||
|
||
Flytekit python natively supports serialization of many data types for exchanging information between tasks. | ||
The Flytekit Hydra Plugin extends these with the `DictConfig` type from the | ||
[OmegaConf package](https://omegaconf.readthedocs.io/) as well as related types | ||
that are being used by the [hydra package](https://hydra.cc/) for configuration management. | ||
|
||
|
||
|
||
## Task example | ||
``` | ||
from dataclasses import dataclass | ||
import flytekitplugins.hydra # noqa F401 | ||
from flytekit import task, workflow | ||
from omegaconf import DictConfig | ||
@dataclass | ||
class MySimpleConf: | ||
_target_: str = "lightning_module.MyEncoderModule" | ||
learning_rate: float = 0.0001 | ||
@task | ||
def my_task(cfg: DictConfig) -> None: | ||
print(f"Doing things with {cfg.learning_rate=}") | ||
@workflow | ||
def pipeline(cfg: DictConfig) -> None: | ||
my_task(cfg=cfg) | ||
if __name__ == "__main__": | ||
from omegaconf import OmegaConf | ||
cfg = OmegaConf.structured(MySimpleConf) | ||
pipeline(cfg=cfg) | ||
``` | ||
|
||
## Transformer configuration | ||
|
||
The transformer can be set to one of three modes: | ||
|
||
`DataClass` - This mode should be used with a StructuredConfig and will reconstruct the config from the matching dataclass | ||
during deserialisation in order to make typing information from the dataclass and continued validation thereof available. | ||
This requires the dataclass definition to be available via python import in the Flyte execution environment in which | ||
objects are (de-)serialised. | ||
|
||
`DictConfig` - This mode will deserialize the config into a DictConfig object. In particular, dataclasses are translated | ||
into DictConfig objects and only primitive types are being checked. The definition of underlying dataclasses for | ||
structured configs is only required during the initial serialization for this mode. | ||
|
||
`Auto` - This mode will try to deserialize according to the DataClass mode and fall back to the DictConfig mode if the | ||
dataclass definition is not available. This is the default mode. | ||
|
||
To set the mode either initialise the transformer with the `mode` argument or set the mode of the config directly: | ||
|
||
```python | ||
from flytekitplugins.hydra.config import SharedConfig, OmegaConfTransformerMode | ||
from flytekitplugins.hydra import DictConfigTransformer | ||
|
||
# Set the mode directly on the transformer | ||
transformer_slim = DictConfigTransformer(mode=OmegaConfTransformerMode.DictConfig) | ||
|
||
# Set the mode directly in the config | ||
SharedConfig.set_mode(OmegaConfTransformerMode.DictConfig) | ||
``` | ||
|
||
```note | ||
Since the DictConfig is flattened and keys transformed into dot notation, the keys of the DictConfig must not contain | ||
dots. | ||
``` | ||
|
||
```note | ||
Warning: This plugin overwrites the default serializer for Enum-objects to also allow for non-string-valued enum definitions. | ||
Please check carefully if existing workflows are compatible with the new version. | ||
``` | ||
|
||
```note | ||
Warning: This plugin attempts serialisation of objects with different transformers. In the process exceptions during | ||
serialisation are suppressed. | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from flytekitplugins.hydra.dictconfig_transformer import DictConfigTransformer # noqa: F401 | ||
from flytekitplugins.hydra.extended_enum_transformer import GenericEnumTransformer # noqa: F401 | ||
from flytekitplugins.hydra.listconfig_transformer import ListConfigTransformer # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from enum import Enum | ||
|
||
|
||
class OmegaConfTransformerMode(Enum):
    """
    Operation mode selecting what the omegaconf transformers return on deserialisation.

    Either a (potentially unannotated) DictConfig object is returned, or a structured config that is rebuilt
    from the underlying dataclass. The structured variant requires the use of structured configs.

    Note: A single config is shared across all transformers, as recursive calls should refer to the same config.
    """

    # Return a plain DictConfig.
    DictConfig = "DictConfig"
    # Return a structured config built from the underlying dataclass.
    DataClass = "DataClass"
    # Let the transformer choose between the two variants above.
    Auto = "Auto"
|
||
|
||
class SharedConfig:
    """Class-level holder of the mode used for serialising omegaconf objects."""

    # Kept on the class itself (never on an instance), so all users read and write the same value.
    _mode: OmegaConfTransformerMode = OmegaConfTransformerMode.Auto

    @classmethod
    def get_mode(cls) -> OmegaConfTransformerMode:
        """Return the mode currently used for serialising omegaconf objects."""
        return cls._mode

    @classmethod
    def set_mode(cls, new_mode: OmegaConfTransformerMode) -> None:
        """
        Replace the mode used for serialising omegaconf objects.

        Values that are not ``OmegaConfTransformerMode`` members (for instance ``None``) are ignored and the
        current mode stays in place.
        """
        if not isinstance(new_mode, OmegaConfTransformerMode):
            return
        cls._mode = new_mode
196 changes: 196 additions & 0 deletions
196
plugins/flytekit-hydra/flytekitplugins/hydra/dictconfig_transformer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
import importlib | ||
import logging | ||
import re | ||
import traceback | ||
import typing | ||
from typing import Type, TypeVar | ||
|
||
import flatten_dict | ||
import omegaconf | ||
from flyteidl.core.literals_pb2 import Literal as PB_Literal | ||
from flytekit import FlyteContext | ||
from flytekit.core.type_engine import TypeTransformerFailedError | ||
from flytekit.extend import TypeEngine, TypeTransformer | ||
from flytekit.models.literals import Literal, Scalar | ||
from flytekit.models.types import LiteralType, SimpleType | ||
from google.protobuf.json_format import MessageToDict, ParseDict | ||
from google.protobuf.struct_pb2 import Struct | ||
from omegaconf import DictConfig, OmegaConf | ||
|
||
from flytekitplugins.hydra.config import OmegaConfTransformerMode, SharedConfig | ||
from flytekitplugins.hydra.flytekit_patch import iterate_get_transformers | ||
from flytekitplugins.hydra.type_information import extract_node_type | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
T = TypeVar("T") | ||
NoneType = type(None) | ||
|
||
|
||
def is_flatable(d: typing.Union[DictConfig]) -> bool:
    """
    Tell whether a DictConfig can be flattened and unflattened again without loss.

    This holds when every key is a string that contains no dot and every nested DictConfig value fulfils the
    same condition. MISSING values are accepted without being accessed.
    """
    for key in d.keys():
        # Dots are used as separator when flattening, so they must not occur inside a key.
        if not isinstance(key, str) or "." in key:
            return False
        if OmegaConf.is_missing(d, key):
            continue
        child = d[key]
        if isinstance(child, DictConfig) and not is_flatable(child):
            return False
    return True
|
||
|
||
class DictConfigTransformer(TypeTransformer[DictConfig]):
    """Flytekit type transformer that (de-)serialises OmegaConf ``DictConfig`` objects through a generic struct."""

    def __init__(self, mode: typing.Optional[OmegaConfTransformerMode] = None):
        """
        Construct DictConfigTransformer.

        :param mode: Optional serialisation mode, forwarded to the shared config through the ``mode`` setter.
        """
        super().__init__(name="OmegaConf DictConfig", t=DictConfig)
        self.mode = mode

    @property
    def mode(self) -> OmegaConfTransformerMode:
        """Serialisation mode for omegaconf objects defined in shared config."""
        return SharedConfig.get_mode()

    @mode.setter
    def mode(self, new_mode: OmegaConfTransformerMode) -> None:
        """Updates the central shared config with a new serialisation mode."""
        SharedConfig.set_mode(new_mode)

    def get_literal_type(self, t: Type[DictConfig]) -> LiteralType:
        """
        Provide type hint for Flytekit type system.

        Every DictConfig is announced as a struct; the typing information of the individual nodes is not part of
        the literal type but is stored inside the literal produced by ``to_literal``.
        """
        return LiteralType(simple=SimpleType.STRUCT)

    @staticmethod
    def _import_attribute(qualified_name: str) -> typing.Any:
        """Import the module part of ``package.module.attribute`` and return the named attribute."""
        module_name, attribute_name = qualified_name.rsplit(".", 1)
        return getattr(importlib.import_module(module_name), attribute_name)

    @classmethod
    def _resolve_node_type(cls, type_desc: str) -> typing.Any:
        """
        Rebuild the python type from a stored type description.

        Handles plain descriptions of the form ``module.Class`` and subscripted ones of the form
        ``module.Origin[module.A, NoneType]``.
        """
        if not re.match(r".+\[.*]", type_desc):
            return cls._import_attribute(type_desc)

        # NOTE(review): only a single level of subscription is parsed here; nested generics are presumably never
        # written by the serialisation side - confirm against extract_node_type.
        origin_desc, parameters_desc = type_desc.split("[")[:2]
        origin = cls._import_attribute(origin_desc)
        parameters = tuple(
            type(None) if name == "NoneType" else cls._import_attribute(name)
            for name in parameters_desc[:-1].split(", ")  # [:-1] drops the closing bracket
        )
        return origin[parameters]

    def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        """
        Convert from given python type object ``DictConfig`` to the Literal representation.

        Since the DictConfig type does not offer additional type hints for its nodes, typing information is stored
        within the literal itself rather than the Flyte LiteralType. The resulting struct holds the dot-flattened
        entries ``types`` (type name per key, or ``MISSING``), ``values`` (serialised literal per key) and
        ``base_dataclass`` (qualified name of the type underlying the config).

        :raises AssertionError: if ``python_val`` is not a DictConfig or cannot be flattened.
        :raises ValueError: if none of the candidate transformers is able to serialise a node.
        """
        # instead of raising TypeError, raising AssertError so that flytekit can catch it in
        # https://github.com/flyteorg/flytekit/blob/60c982e4b065fdb3aba0b957e506f652a2674c00/flytekit/core/
        # type_engine.py#L1222
        assert isinstance(python_val, DictConfig), f"Invalid type {type(python_val)}, can only serialise DictConfigs"
        assert is_flatable(
            python_val
        ), f"{python_val} cannot be flattened as it contains non-string keys or keys containing dots."

        base_config = OmegaConf.get_type(python_val)
        type_map = {}
        value_map = {}
        for key in python_val.keys():
            if OmegaConf.is_missing(python_val, key):
                # MISSING nodes carry no value; only the marker is stored.
                type_map[key] = "MISSING"
                continue

            node_type, type_name = extract_node_type(python_val, key)
            type_map[key] = type_name

            # Try the candidate transformers in order and keep the first one that succeeds.
            transformation_logs = ""
            for transformer in iterate_get_transformers(node_type):
                try:
                    literal_type = transformer.get_literal_type(node_type)
                    value_map[key] = MessageToDict(
                        transformer.to_literal(ctx, python_val[key], node_type, literal_type).to_flyte_idl()
                    )
                    break
                except Exception:
                    # Failures are deliberately suppressed here and only reported if no transformer succeeds.
                    failure = (
                        f"Serialisation with transformer {type(transformer)} failed:\n"
                        f"{traceback.format_exc()}\n\n"
                    )
                    logger.debug(failure)
                    transformation_logs += failure

            if key not in value_map:
                raise ValueError(
                    f"Could not identify matching transformer for object {python_val[key]} of type "
                    f"{type(python_val[key])}. This may either indicate that no such transformer "
                    "exists or the appropriate transformer cannot serialise this object. Attempted the following "
                    f"transformers:\n{transformation_logs}"
                )

        wrapper = Struct()
        wrapper.update(
            flatten_dict.flatten(
                {
                    "types": type_map,
                    "values": value_map,
                    "base_dataclass": f"{base_config.__module__}.{base_config.__name__}",
                },
                reducer="dot",
                keep_empty_types=(dict,),
            )
        )
        return Literal(scalar=Scalar(generic=wrapper))

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[DictConfig]) -> DictConfig:
        """
        Re-hydrate the custom object from Flyte Literal value.

        :raises TypeTransformerFailedError: if the literal is empty or holds no scalar.
        :raises ValueError: if none of the candidate transformers is able to deserialise a node.
        :raises ModuleNotFoundError, AttributeError: in ``DataClass`` mode, if the underlying dataclass cannot be
            loaded.
        """
        if not lv or lv.scalar is None:
            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

        nested_dict = flatten_dict.unflatten(MessageToDict(lv.scalar.generic), splitter="dot")
        cfg_dict = {}
        for key, type_desc in nested_dict["types"].items():
            if type_desc == "MISSING":
                cfg_dict[key] = omegaconf.MISSING
                continue

            node_type = self._resolve_node_type(type_desc)

            # Try the candidate transformers in order and keep the first one that succeeds.
            transformation_logs = ""
            for transformer in iterate_get_transformers(node_type):
                try:
                    value_literal = Literal.from_flyte_idl(ParseDict(nested_dict["values"][key], PB_Literal()))
                    cfg_dict[key] = transformer.to_python_value(ctx, value_literal, node_type)
                    break
                except Exception:
                    # Failures are deliberately suppressed here and only reported if no transformer succeeds.
                    failure = (
                        f"Deserialisation with transformer {type(transformer)} failed:\n"
                        f"{traceback.format_exc()}\n\n"
                    )
                    logger.debug(failure)
                    transformation_logs += failure

            if key not in cfg_dict:
                raise ValueError(
                    f"Could not identify matching transformer for object {nested_dict['values'][key]} of "
                    f"proposed type {node_type}. This may either indicate that no such transformer "
                    "exists or the appropriate transformer cannot deserialise this object. Attempted the "
                    f"following transformers:\n{transformation_logs}"
                )

        if nested_dict["base_dataclass"] != "builtins.dict" and self.mode != OmegaConfTransformerMode.DictConfig:
            # Explicitly instantiate dataclass and create DictConfig from there in order to have typing information
            module_name, class_name = nested_dict["base_dataclass"].rsplit(".", 1)
            try:
                return OmegaConf.structured(getattr(importlib.import_module(module_name), class_name)(**cfg_dict))
            except (ModuleNotFoundError, AttributeError):
                logger.error(
                    f"Could not import {class_name} from module {module_name}. If you want to deserialise to "
                    f"DictConfig, set the mode to OmegaConfTransformerMode.DictConfig."
                )
                # Only the strict DataClass mode fails here; otherwise fall back to a plain DictConfig below.
                if self.mode == OmegaConfTransformerMode.DataClass:
                    raise
        return OmegaConf.create(cfg_dict)


# Make the transformer known to flytekit's type engine as soon as this module is imported.
TypeEngine.register(DictConfigTransformer())
66 changes: 66 additions & 0 deletions
66
plugins/flytekit-hydra/flytekitplugins/hydra/extended_enum_transformer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import enum | ||
import importlib | ||
from typing import Type, TypeVar | ||
|
||
from flytekit import FlyteContext | ||
from flytekit.core.type_engine import TypeTransformer, TypeTransformerFailedError | ||
from flytekit.extend import TypeEngine | ||
from flytekit.models.literals import Literal, LiteralMap, Primitive, Scalar | ||
from flytekit.models.types import LiteralType, SimpleType | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
class GenericEnumTransformer(TypeTransformer[enum.Enum]):
    """Flytekit type transformer for ``enum.Enum`` members whose values may be of any transformable type."""

    def __init__(self):
        """Transformer for arbitrary enum.Enum objects."""
        super().__init__(name="GenericEnumTransformer", t=enum.Enum)

    def get_literal_type(self, t: Type[T]) -> LiteralType:
        """
        Provide type hint for Flytekit type system.

        Every enum is announced as a struct, independent of the type of its values.
        """
        return LiteralType(simple=SimpleType.STRUCT)

    def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        """
        Transform python enum to Flyte literal for serialisation.

        The result is a literal map with the entries ``enum_type`` and ``type`` (both ``module/name`` strings,
        for the enum class and for the class of the member value) and ``value`` (the serialised member value).

        :raises TypeTransformerFailedError: if ``python_val`` is not an enum member.
        """
        if not isinstance(python_val, enum.Enum):
            raise TypeTransformerFailedError("Expected an enum")

        def describe(klass: type) -> Literal:
            # Encode a class as the string literal "<module>/<name>".
            return Literal(scalar=Scalar(primitive=Primitive(string_value=f"{klass.__module__}/{klass.__name__}")))

        member_value = python_val.value
        enum_cls = type(python_val)
        value_cls = type(member_value)
        value_transformer = TypeEngine.get_transformer(value_cls)
        value_literal_type = value_transformer.get_literal_type(value_cls)

        entries = {
            "enum_type": describe(enum_cls),
            "type": describe(value_cls),
            "value": value_transformer.to_literal(ctx, member_value, value_cls, value_literal_type),
        }
        return Literal(map=LiteralMap(literals=entries))

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
        """Transform Literal back to python object (Enum)."""
        entries = lv.map.literals

        # First restore the raw member value with the transformer matching its recorded type ...
        value_module, value_class = entries["type"].scalar.primitive.string_value.split("/")
        value_literal = entries["value"]
        value_cls = getattr(importlib.import_module(value_module), value_class)
        value_transformer = TypeEngine.get_transformer(value_cls)
        raw_value = value_transformer.to_python_value(ctx, value_literal, value_cls)

        # ... then look the member up by value on the recorded enum class.
        enum_module, enum_class = entries["enum_type"].scalar.primitive.string_value.split("/")
        enum_cls = getattr(importlib.import_module(enum_module), enum_class)
        return enum_cls(raw_value)


# Install this transformer as the type engine's enum transformer and register it for enum.Enum,
# overriding any transformer registered for that type before.
TypeEngine._ENUM_TRANSFORMER = GenericEnumTransformer()
TypeEngine.register_additional_type(GenericEnumTransformer(), enum.Enum, override=True)
Oops, something went wrong.