Skip to content

Commit

Permalink
refactor transformer mode
Browse files Browse the repository at this point in the history
Signed-off-by: mg515 <[email protected]>
  • Loading branch information
mg515 committed Jul 30, 2024
1 parent 1e15009 commit e6e1896
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 65 deletions.
10 changes: 3 additions & 7 deletions plugins/flytekit-omegaconf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,10 @@ dataclass definition is not available. This is the default mode.
To set the mode either initialise the transformer with the `mode` argument or set the mode of the config directly:

```python
from flytekitplugins.omegaconf.config import SharedConfig, OmegaConfTransformerMode
from flytekitplugins.omegaconf import DictConfigTransformer
from flytekitplugins.omegaconf import set_transformer_mode, OmegaConfTransformerMode

# Set the mode directly on the transformer
transformer_slim = DictConfigTransformer(mode=OmegaConfTransformerMode.DictConfig)

# Set the mode directly in the config
SharedConfig.set_mode(OmegaConfTransformerMode.DictConfig)
# Set the mode using the new function
set_transformer_mode(OmegaConfTransformerMode.DictConfig)
```

```note
Expand Down
31 changes: 31 additions & 0 deletions plugins/flytekit-omegaconf/flytekitplugins/omegaconf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,33 @@
from contextlib import contextmanager

from flytekitplugins.omegaconf.config import OmegaConfTransformerMode
from flytekitplugins.omegaconf.dictconfig_transformer import DictConfigTransformer # noqa: F401
from flytekitplugins.omegaconf.listconfig_transformer import ListConfigTransformer # noqa: F401

_TRANSFORMER_MODE = OmegaConfTransformerMode.Auto


def set_transformer_mode(mode: OmegaConfTransformerMode) -> None:
"""Set the global serialization mode for OmegaConf objects."""
global _TRANSFORMER_MODE
_TRANSFORMER_MODE = mode


def get_transformer_mode() -> OmegaConfTransformerMode:
"""Get the global serialization mode for OmegaConf objects."""
return _TRANSFORMER_MODE


@contextmanager
def local_transformer_mode(mode: OmegaConfTransformerMode):
"""Context manager to set a local serialization mode for OmegaConf objects."""
global _TRANSFORMER_MODE
previous_mode = _TRANSFORMER_MODE
set_transformer_mode(mode)
try:
yield
finally:
set_transformer_mode(previous_mode)


__all__ = ["set_transformer_mode", "get_transformer_mode", "local_transformer_mode", "OmegaConfTransformerMode"]
15 changes: 0 additions & 15 deletions plugins/flytekit-omegaconf/flytekitplugins/omegaconf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,3 @@ class OmegaConfTransformerMode(Enum):
DictConfig = "DictConfig"
DataClass = "DataClass"
Auto = "Auto"


class SharedConfig:
_mode: OmegaConfTransformerMode = OmegaConfTransformerMode.Auto

@classmethod
def get_mode(cls) -> OmegaConfTransformerMode:
"""Get the current mode for serialising omegaconf objects."""
return cls._mode

@classmethod
def set_mode(cls, new_mode: OmegaConfTransformerMode) -> None:
"""Set the current mode for serialising omegaconf objects."""
if isinstance(new_mode, OmegaConfTransformerMode):
cls._mode = new_mode
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import importlib
import re
import typing
from typing import Type, TypeVar

import flatten_dict
import flytekitplugins.omegaconf
from flyteidl.core.literals_pb2 import Literal as PB_Literal
from flytekitplugins.omegaconf.config import OmegaConfTransformerMode, SharedConfig
from flytekitplugins.omegaconf.config import OmegaConfTransformerMode
from flytekitplugins.omegaconf.type_information import extract_node_type
from google.protobuf.json_format import MessageToDict, ParseDict
from google.protobuf.struct_pb2 import Struct
Expand Down Expand Up @@ -38,20 +38,9 @@ def is_flattenable(d: DictConfig) -> bool:


class DictConfigTransformer(TypeTransformer[DictConfig]):
def __init__(self, mode: typing.Optional[OmegaConfTransformerMode] = None):
def __init__(self):
"""Construct DictConfigTransformer."""
super().__init__(name="OmegaConf DictConfig", t=DictConfig)
self.mode = mode

@property
def mode(self) -> OmegaConfTransformerMode:
"""Serialisation mode for omegaconf objects defined in shared config."""
return SharedConfig.get_mode()

@mode.setter
def mode(self, new_mode: OmegaConfTransformerMode) -> None:
"""Updates the central shared config with a new serialisation mode."""
SharedConfig.set_mode(new_mode)

def get_literal_type(self, t: Type[DictConfig]) -> LiteralType:
"""
Expand Down Expand Up @@ -138,7 +127,10 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
value_literal = Literal.from_flyte_idl(ParseDict(nested_dict["values"][key], PB_Literal()))
cfg_dict[key] = transformer.to_python_value(ctx, value_literal, node_type)

if nested_dict["base_dataclass"] != "builtins.dict" and self.mode != OmegaConfTransformerMode.DictConfig:
if (
nested_dict["base_dataclass"] != "builtins.dict"
and flytekitplugins.omegaconf.get_transformer_mode() != OmegaConfTransformerMode.DictConfig
):
# Explicitly instantiate dataclass and create DictConfig from there in order to have typing information
module_name, class_name = nested_dict["base_dataclass"].rsplit(".", 1)
try:
Expand All @@ -150,7 +142,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"Could not import module {module_name}. If you want to deserialise to DictConfig, "
f"set the mode to DictConfigTransformerMode.DictConfig."
)
if self.mode == OmegaConfTransformerMode.DataClass:
if flytekitplugins.omegaconf.get_transformer_mode() == OmegaConfTransformerMode.DataClass:
raise e
return OmegaConf.create(cfg_dict)
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import importlib
from typing import Optional, Type, TypeVar
from typing import Type, TypeVar

from flyteidl.core.literals_pb2 import Literal as PB_Literal
from flytekitplugins.omegaconf.config import OmegaConfTransformerMode, SharedConfig
from flytekitplugins.omegaconf.type_information import extract_node_type
from google.protobuf.json_format import MessageToDict, ParseDict
from google.protobuf.struct_pb2 import Struct
Expand All @@ -19,20 +18,9 @@


class ListConfigTransformer(TypeTransformer[ListConfig]):
def __init__(self, mode: Optional[OmegaConfTransformerMode] = None):
def __init__(self):
"""Construct ListConfigTransformer."""
super().__init__(name="OmegaConf ListConfig", t=ListConfig)
self.mode = mode

@property
def mode(self) -> OmegaConfTransformerMode:
"""Serialisation mode for omegaconf objects defined in shared config."""
return SharedConfig.get_mode()

@mode.setter
def mode(self, new_mode: OmegaConfTransformerMode) -> None:
"""Updates the central shared config with a new serialisation mode."""
SharedConfig.set_mode(new_mode)

def get_literal_type(self, t: Type[ListConfig]) -> LiteralType:
"""
Expand Down
28 changes: 15 additions & 13 deletions plugins/flytekit-omegaconf/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import flytekitplugins.omegaconf # noqa
import pytest
from flyteidl.core.literals_pb2 import Literal, Scalar
from flytekitplugins.omegaconf import DictConfigTransformer
from flytekitplugins.omegaconf.config import OmegaConfTransformerMode, SharedConfig
from flytekitplugins.omegaconf.config import OmegaConfTransformerMode
from flytekitplugins.omegaconf.dictconfig_transformer import DictConfigTransformer
from google.protobuf.struct_pb2 import Struct
from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf, ValidationError
from pytest import mark, param
Expand Down Expand Up @@ -107,17 +107,18 @@ def test_plugin_mode() -> None:
ctx = FlyteContext.current_context()
expected = TypeEngine.to_literal_type(DictConfig)

transformer = DictConfigTransformer(mode=OmegaConfTransformerMode.DictConfig)
literal_slim = transformer.to_literal(ctx, obj, DictConfig, expected)
reconstructed_slim = transformer.to_python_value(ctx, literal_slim, DictConfig)
with flytekitplugins.omegaconf.local_transformer_mode(OmegaConfTransformerMode.DictConfig):
transformer = DictConfigTransformer()
literal_slim = transformer.to_literal(ctx, obj, DictConfig, expected)
reconstructed_slim = transformer.to_python_value(ctx, literal_slim, DictConfig)

SharedConfig.set_mode(OmegaConfTransformerMode.DataClass)
literal_full = transformer.to_literal(ctx, obj, DictConfig, expected)
reconstructed_full = transformer.to_python_value(ctx, literal_full, DictConfig)
with flytekitplugins.omegaconf.local_transformer_mode(OmegaConfTransformerMode.DataClass):
literal_full = transformer.to_literal(ctx, obj, DictConfig, expected)
reconstructed_full = transformer.to_python_value(ctx, literal_full, DictConfig)

SharedConfig.set_mode(OmegaConfTransformerMode.Auto)
literal_semi = transformer.to_literal(ctx, obj, DictConfig, expected)
reconstructed_semi = transformer.to_python_value(ctx, literal_semi, DictConfig)
with flytekitplugins.omegaconf.local_transformer_mode(OmegaConfTransformerMode.Auto):
literal_semi = transformer.to_literal(ctx, obj, DictConfig, expected)
reconstructed_semi = transformer.to_python_value(ctx, literal_semi, DictConfig)

assert literal_slim == literal_full == literal_semi
assert reconstructed_slim == reconstructed_full == reconstructed_semi # comparison by value should pass
Expand Down Expand Up @@ -177,8 +178,9 @@ def test_fallback() -> None:
literal2 = Literal(scalar=Scalar(generic=struct2))

ctx = FlyteContext.current_context()
transformer = DictConfigTransformer(mode=OmegaConfTransformerMode.DictConfig)
SharedConfig.set_mode(OmegaConfTransformerMode.Auto)
flytekitplugins.omegaconf.set_transformer_mode(OmegaConfTransformerMode.DictConfig)
transformer = DictConfigTransformer()
flytekitplugins.omegaconf.set_transformer_mode(OmegaConfTransformerMode.Auto)

reconstructed = transformer.to_python_value(ctx, literal, DictConfig)
assert obj == reconstructed
Expand Down

0 comments on commit e6e1896

Please sign in to comment.