Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flytekit-omegaconf plugin #2299

Merged
merged 28 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
d0149a9
add flytekit-hydra
mg515 Mar 27, 2024
f65136a
fix small typo readme
mg515 Mar 27, 2024
76297da
ruff ruff
mg515 Mar 28, 2024
20e2281
lint more
mg515 Mar 28, 2024
9574b44
rename plugin into flytekit-omegaconf
mg515 Apr 4, 2024
a568a82
lint sort imports
mg515 Apr 8, 2024
61fcbab
use flytekit logger
mg515 Apr 9, 2024
4395540
use flytekit logger #2
mg515 Apr 9, 2024
8864f70
fix typing info in is_flatable
mg515 Apr 9, 2024
0785c83
use default_factory instead of mutable default value
mg515 Apr 9, 2024
167a28c
add python3.11 and python3.12 to setup.py
mg515 Apr 9, 2024
12a779e
make fmt
mg515 Apr 9, 2024
a783c8d
define error message only once
mg515 Apr 9, 2024
5abe25d
add docstring
mg515 Apr 12, 2024
2feac8e
remove GenericEnumTransformer and tests
mg515 May 12, 2024
debe2f7
fallback to TypeEngine.get_transformer(node_type) to find suitable tr…
mg515 May 12, 2024
4fdb7e5
explicit valueerrors instead of asserts
mg515 May 12, 2024
08f9a80
minor style improvements
mg515 May 12, 2024
40af6f3
remove obsolete warnings
mg515 Jul 30, 2024
9c6db69
import flytekit logger instead of instantiating our own
mg515 Jul 30, 2024
1e15009
docstrings in reST format
mg515 Jul 30, 2024
e6e1896
refactor transformer mode
mg515 Jul 30, 2024
f8e945c
improve docs
mg515 Jul 30, 2024
34fd0a4
refactor dictconfig class into smaller methods
mg515 Jul 30, 2024
212ee86
add unit tests for dictconfig transformer
mg515 Jul 30, 2024
0a745d0
refactor of parse_type_description()
mg515 Jul 30, 2024
75dabed
add omegaconf plugin to pythonbuild.yaml
mg515 Aug 1, 2024
257e32d
Merge remote-tracking branch 'origin' into plugin-flytekit-hydra
eapolinario Aug 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions plugins/flytekit-omegaconf/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Flytekit OmegaConf Plugin

Flytekit python natively supports serialization of many data types for exchanging information between tasks.
The Flytekit OmegaConf Plugin extends these by the `DictConfig` type from the
[OmegaConf package](https://omegaconf.readthedocs.io/) as well as related types
that are being used by the [hydra package](https://hydra.cc/) for configuration management.

## Task example
```
from dataclasses import dataclass
import flytekitplugins.omegaconf # noqa F401
from flytekit import task, workflow
from omegaconf import DictConfig

@dataclass
class MySimpleConf:
_target_: str = "lightning_module.MyEncoderModule"
learning_rate: float = 0.0001

@task
def my_task(cfg: DictConfig) -> None:
fg91 marked this conversation as resolved.
Show resolved Hide resolved
print(f"Doing things with {cfg.learning_rate=}")


@workflow
def pipeline(cfg: DictConfig) -> None:
my_task(cfg=cfg)


if __name__ == "__main__":
from omegaconf import OmegaConf

cfg = OmegaConf.structured(MySimpleConf)
pipeline(cfg=cfg)
```

## Transformer configuration

The transformer can be set to one of three modes:

`Dataclass` - This mode should be used with a StructuredConfig and will reconstruct the config from the matching dataclass
during deserialisation in order to make typing information from the dataclass and continued validation thereof available.
This requires the dataclass definition to be available via python import in the Flyte execution environment in which
objects are (de-)serialised.

`DictConfig` - This mode will deserialize the config into a DictConfig object. In particular, dataclasses are translated
into DictConfig objects and only primitive types are being checked. The definition of underlying dataclasses for
structured configs is only required during the initial serialization for this mode.

`Auto` - This mode will try to deserialize according to the Dataclass mode and fall back to the DictConfig mode if the
dataclass definition is not available. This is the default mode.

To set the mode either initialise the transformer with the `mode` argument or set the mode of the config directly:

```python
from flytekitplugins.omegaconf.config import SharedConfig, OmegaConfTransformerMode
from flytekitplugins.omegaconf import DictConfigTransformer

# Set the mode directly on the transformer
Copy link
Member

@fg91 fg91 May 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it untypical for flytekit that users import type transformers to configure their behaviour. Would the user have to register transformer_slim in this example with the TypeEngine again?

Do I understand it correctly that these are two mechanisms that achieve exactly the same thing? If yes, what do you think of only providing one mechanism?

I personally find the SharedConfig singleton a bit awkward as well :S

As an analogy, setting the start method for multiprocessing comes to my mind:

import multiprocessing
multiprocessing.set_start_method(...)

Do you think it would be possible to simply do:

import flytekitplugins.omegaconf  # noqa F401

flytekitplugins.omegaconf.set_mode(...)

This would feel more idiomatic to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand it correctly that these are two mechanisms that achieve exactly the same thing?

I think the second example configures the mode of the (dict) transformer globally, whereas the first one only configures the specific instance.

Either way, I can't think of a use-case where this would really be useful, I think we can configure it globally across all transformers, the way you suggested. Any objections @SebS94 ?
Will come back to you once I implement it (ran out of time now).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in e6e1896

We provide global mode and a context manager for ability to control it locally still.

transformer_slim = DictConfigTransformer(mode=OmegaConfTransformerMode.DictConfig)

# Set the mode directly in the config
SharedConfig.set_mode(OmegaConfTransformerMode.DictConfig)
```

```note
Since the DictConfig is flattened and keys transformed into dot notation, the keys of the DictConfig must not contain
dots.
fg91 marked this conversation as resolved.
Show resolved Hide resolved
```

```note
Warning: This plugin overwrites the default serializer for Enum-objects to also allow for non-string-valued enum definitions.
mg515 marked this conversation as resolved.
Show resolved Hide resolved
Please check carefully if existing workflows are compatible with the new version.
```

```note
Warning: This plugin attempts serialisation of objects with different transformers. In the process exceptions during
serialisation are suppressed.
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from flytekitplugins.omegaconf.dictconfig_transformer import DictConfigTransformer # noqa: F401
from flytekitplugins.omegaconf.listconfig_transformer import ListConfigTransformer # noqa: F401
30 changes: 30 additions & 0 deletions plugins/flytekit-omegaconf/flytekitplugins/omegaconf/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from enum import Enum


class OmegaConfTransformerMode(Enum):
"""
Operation Mode indicating whether a (potentially unannotated) DictConfig object or a structured config using the
underlying dataclass is returned.

Note: We define a single shared config across all transformers as recursive calls should refer to the same config
Note: The latter requires the use of structured configs.
"""

DictConfig = "DictConfig"
DataClass = "DataClass"
Auto = "Auto"


class SharedConfig:
_mode: OmegaConfTransformerMode = OmegaConfTransformerMode.Auto

@classmethod
def get_mode(cls) -> OmegaConfTransformerMode:
"""Get the current mode for serialising omegaconf objects."""
return cls._mode

@classmethod
def set_mode(cls, new_mode: OmegaConfTransformerMode) -> None:
"""Set the current mode for serialising omegaconf objects."""
if isinstance(new_mode, OmegaConfTransformerMode):
cls._mode = new_mode
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import importlib
import logging
import re
import typing
from typing import Type, TypeVar

import flatten_dict
from flyteidl.core.literals_pb2 import Literal as PB_Literal
from flytekitplugins.omegaconf.config import OmegaConfTransformerMode, SharedConfig
from flytekitplugins.omegaconf.type_information import extract_node_type
from google.protobuf.json_format import MessageToDict, ParseDict
from google.protobuf.struct_pb2 import Struct

import omegaconf
from flytekit import FlyteContext
from flytekit.core.type_engine import TypeTransformerFailedError
from flytekit.extend import TypeEngine, TypeTransformer
from flytekit.models.literals import Literal, Scalar
from flytekit.models.types import LiteralType, SimpleType
from omegaconf import DictConfig, OmegaConf

logger = logging.getLogger("flytekit")
mg515 marked this conversation as resolved.
Show resolved Hide resolved

T = TypeVar("T")
NoneType = type(None)


def is_flattenable(d: DictConfig) -> bool:
"""Check if a DictConfig can be properly flattened and unflattened, i.e. keys do not contain dots."""
return all(
isinstance(k, str) # keys are strings ...
and "." not in k # ... and do not contain dots
and (
OmegaConf.is_missing(d, k) # values are either MISSING ...
or not isinstance(d[k], DictConfig) # ... not nested Dictionaries ...
or is_flattenable(d[k])
) # or flatable themselves
for k in d.keys()
)


class DictConfigTransformer(TypeTransformer[DictConfig]):
def __init__(self, mode: typing.Optional[OmegaConfTransformerMode] = None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed here, shell we remove the mode here and only provide a global mode?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bump.

Copy link
Contributor Author

@mg515 mg515 Jul 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in e6e1896

We provide global mode and a context manager for ability to control it locally still.
Please resolve the conversation if you think this is a suitable solution.

"""Construct DictConfigTransformer."""
super().__init__(name="OmegaConf DictConfig", t=DictConfig)
self.mode = mode

@property
def mode(self) -> OmegaConfTransformerMode:
"""Serialisation mode for omegaconf objects defined in shared config."""
return SharedConfig.get_mode()

@mode.setter
def mode(self, new_mode: OmegaConfTransformerMode) -> None:
"""Updates the central shared config with a new serialisation mode."""
SharedConfig.set_mode(new_mode)

def get_literal_type(self, t: Type[DictConfig]) -> LiteralType:
"""
Provide type hint for Flytekit type system.

To support the multivariate typing of nodes in a DictConfig, we encode them as binaries (no introspection)
with multiple files.
"""
return LiteralType(simple=SimpleType.STRUCT)

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
"""
Convert from given python type object ``DictConfig`` to the Literal representation.

Since the DictConfig type does not offer additional type hints for its nodes, typing information is stored
within the literal itself rather than the Flyte LiteralType.
"""
# instead of raising TypeError, raising ValueError so that flytekit can catch it in
# https://github.com/flyteorg/flytekit/blob/60c982e4b065fdb3aba0b957e506f652a2674c00/flytekit/core/
# type_engine.py#L1222
if not isinstance(python_val, DictConfig):
raise ValueError(f"Invalid type {type(python_val)}, can only serialise DictConfigs")
if not is_flattenable(python_val):
raise ValueError(
f"{python_val} cannot be flattened as it contains non-string keys or keys containing dots."
)

base_config = OmegaConf.get_type(python_val)
type_map = {}
value_map = {}
for key in python_val.keys():
if OmegaConf.is_missing(python_val, key):
type_map[key] = "MISSING"
else:
node_type, type_name = extract_node_type(python_val, key)
type_map[key] = type_name

transformation_logs = ""
mg515 marked this conversation as resolved.
Show resolved Hide resolved
transformer = TypeEngine.get_transformer(node_type)
literal_type = transformer.get_literal_type(node_type)

value_map[key] = MessageToDict(
transformer.to_literal(ctx, python_val[key], node_type, literal_type).to_flyte_idl()
)

if key not in value_map.keys():
raise ValueError(
f"Could not identify matching transformer for object {python_val[key]} of type "
f"{type(python_val[key])}. This may either indicate that no such transformer "
"exists or the appropriate transformer cannot serialise this object. Attempted the following "
f"transformers:\n{transformation_logs}"
)

wrapper = Struct()
wrapper.update(
flatten_dict.flatten(
{
"types": type_map,
"values": value_map,
"base_dataclass": f"{base_config.__module__}.{base_config.__name__}",
},
reducer="dot",
keep_empty_types=(dict,),
)
)
return Literal(scalar=Scalar(generic=wrapper))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[DictConfig]) -> DictConfig:
"""Re-hydrate the custom object from Flyte Literal value."""
if lv and lv.scalar is not None:
nested_dict = flatten_dict.unflatten(MessageToDict(lv.scalar.generic), splitter="dot")
cfg_dict = {}
for key, type_desc in nested_dict["types"].items():
if type_desc == "MISSING":
cfg_dict[key] = omegaconf.MISSING
else:
if re.match(r".+\[.*]", type_desc):
origin_module, origin_type = type_desc.split("[")[0].rsplit(".", 1)
origin = importlib.import_module(origin_module).__getattribute__(origin_type)
sub_types = type_desc.split("[")[1][:-1].split(", ")
for i, t in enumerate(sub_types):
if t != "NoneType":
module_name, class_name = t.rsplit(".", 1)
sub_types[i] = importlib.import_module(module_name).__getattribute__(class_name)
else:
sub_types[i] = type(None)
node_type = origin[tuple(sub_types)]
else:
module_name, class_name = type_desc.rsplit(".", 1)
node_type = importlib.import_module(module_name).__getattribute__(class_name)

transformation_logs = ""
transformer = TypeEngine.get_transformer(node_type)
value_literal = Literal.from_flyte_idl(ParseDict(nested_dict["values"][key], PB_Literal()))
cfg_dict[key] = transformer.to_python_value(ctx, value_literal, node_type)

if key not in cfg_dict.keys():
raise ValueError(
f"Could not identify matching transformer for object {nested_dict['values'][key]} of "
f"proposed type {node_type}. This may either indicate that no such transformer "
"exists or the appropriate transformer cannot deserialise this object. Attempted the "
f"following transformers:\n{transformation_logs}"
mg515 marked this conversation as resolved.
Show resolved Hide resolved
)
if nested_dict["base_dataclass"] != "builtins.dict" and self.mode != OmegaConfTransformerMode.DictConfig:
# Explicitly instantiate dataclass and create DictConfig from there in order to have typing information
module_name, class_name = nested_dict["base_dataclass"].rsplit(".", 1)
try:
return OmegaConf.structured(
importlib.import_module(module_name).__getattribute__(class_name)(**cfg_dict)
)
except (ModuleNotFoundError, AttributeError) as e:
logger.error(
f"Could not import module {module_name}. If you want to deserialise to DictConfig, "
f"set the mode to DictConfigTransformerMode.DictConfig."
)
if self.mode == OmegaConfTransformerMode.DataClass:
raise e
return OmegaConf.create(cfg_dict)
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")


TypeEngine.register(DictConfigTransformer())
Loading
Loading