diff --git a/autoconf/dictable.py b/autoconf/dictable.py index a722847..a2a8dfd 100644 --- a/autoconf/dictable.py +++ b/autoconf/dictable.py @@ -10,7 +10,6 @@ logger = logging.getLogger(__name__) - np_type_map = { "bool": "bool_", } @@ -47,6 +46,22 @@ def is_array(obj) -> bool: return False +def compound_key_dict(obj): + """ + Converts a dictionary with compound keys to a dictionary with a single key. + """ + return { + "type": "compound_dict", + "arguments": [ + { + "key": to_dict(key), + "value": to_dict(value), + } + for key, value in obj.items() + ], + } + + def to_dict(obj, filter_args: Tuple[str, ...] = ()) -> dict: if isinstance(obj, (int, float, str, bool, type(None))): return obj @@ -101,7 +116,7 @@ def to_dict(obj, filter_args: Tuple[str, ...] = ()) -> dict: if isinstance(obj, tuple): return {"type": "tuple", "values": list(map(to_dict, obj))} if isinstance(obj, dict): - return { + result = { "type": "dict", "arguments": { key: to_dict(value) @@ -109,6 +124,12 @@ def to_dict(obj, filter_args: Tuple[str, ...] = ()) -> dict: if key not in filter_args }, } + try: + json.dumps(result) + return result + except TypeError: + return compound_key_dict(obj) + if obj.__class__.__name__ == "method": return to_dict(obj()) if obj.__class__.__module__ == "builtins": @@ -297,7 +318,15 @@ def from_dict(dictionary, **kwargs): if type_ == "tuple": return tuple(map(from_dict, dictionary["values"])) if type_ == "dict": - return {key: from_dict(value, **kwargs) for key, value in dictionary.items()} + return { + key: from_dict(value, **kwargs) + for key, value in dictionary["arguments"].items() + } + if type_ == "compound_dict": + return { + from_dict(item["key"], **kwargs): from_dict(item["value"], **kwargs) + for item in dictionary["arguments"] + } if type_ == "type": return get_class(dictionary["class_path"]) diff --git a/test_autoconf/test_dictable.py b/test_autoconf/test_dictable.py index 0a01273..5395542 100644 --- a/test_autoconf/test_dictable.py +++ b/test_autoconf/test_dictable.py @@ -192,3 +192,10 @@ def test_int64_slice(): ) ) ) + + +def test_compound_key(): + d = {(1, 2): 1} + + string = json.dumps(to_dict(d)) + assert d == from_dict(json.loads(string))