diff --git a/tests/test_items.py b/tests/test_items.py index 45aea25..485e47f 100644 --- a/tests/test_items.py +++ b/tests/test_items.py @@ -946,3 +946,21 @@ def test_copy_copy(): ) def test_escape_key(key_str, escaped): assert api.key(key_str).as_string() == escaped + + +def test_custom_encoders(): + import decimal + + @api.register_encoder + def encode_decimal(obj): + if isinstance(obj, decimal.Decimal): + return api.float_(str(obj)) + raise TypeError + + assert api.item(decimal.Decimal("1.23")).as_string() == "1.23" + + with pytest.raises(TypeError): + api.item(object()) + + assert api.dumps({"foo": decimal.Decimal("1.23")}) == "foo = 1.23\n" + api.unregister_encoder(encode_decimal) diff --git a/tomlkit/__init__.py b/tomlkit/__init__.py index acc7046..c2dd53f 100644 --- a/tomlkit/__init__.py +++ b/tomlkit/__init__.py @@ -18,9 +18,11 @@ from tomlkit.api import loads from tomlkit.api import nl from tomlkit.api import parse +from tomlkit.api import register_encoder from tomlkit.api import string from tomlkit.api import table from tomlkit.api import time +from tomlkit.api import unregister_encoder from tomlkit.api import value from tomlkit.api import ws @@ -52,4 +54,6 @@ "TOMLDocument", "value", "ws", + "register_encoder", + "unregister_encoder", ] diff --git a/tomlkit/api.py b/tomlkit/api.py index 8ec5653..686fd1c 100644 --- a/tomlkit/api.py +++ b/tomlkit/api.py @@ -1,14 +1,17 @@ from __future__ import annotations +import contextlib import datetime as _datetime from collections.abc import Mapping from typing import IO from typing import Iterable +from typing import TypeVar from tomlkit._utils import parse_rfc3339 from tomlkit.container import Container from tomlkit.exceptions import UnexpectedCharError +from tomlkit.items import CUSTOM_ENCODERS from tomlkit.items import AoT from tomlkit.items import Array from tomlkit.items import Bool @@ -16,6 +19,7 @@ from tomlkit.items import Date from tomlkit.items import DateTime from tomlkit.items import DottedKey +from tomlkit.items import Encoder from tomlkit.items import Float from tomlkit.items import InlineTable from tomlkit.items import Integer @@ -284,3 +288,21 @@ def nl() -> Whitespace: def comment(string: str) -> Comment: """Create a comment item.""" return Comment(Trivia(comment_ws=" ", comment="# " + string)) + + +E = TypeVar("E", bound=Encoder) + + +def register_encoder(encoder: E) -> E: + """Add a custom encoder, which should be a function that will be called + if the value can't otherwise be converted. It should takes a single value + and return a TOMLKit item or raise a ``TypeError``. + """ + CUSTOM_ENCODERS.append(encoder) + return encoder + + +def unregister_encoder(encoder: Encoder) -> None: + """Unregister a custom encoder.""" + with contextlib.suppress(ValueError): + CUSTOM_ENCODERS.remove(encoder) diff --git a/tomlkit/items.py b/tomlkit/items.py index 683c189..41dccc3 100644 --- a/tomlkit/items.py +++ b/tomlkit/items.py @@ -13,6 +13,7 @@ from enum import Enum from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Collection from typing import Iterable from typing import Iterator @@ -57,6 +58,15 @@ class _CustomDict(MutableMapping, dict): ItemT = TypeVar("ItemT", bound="Item") +Encoder = Callable[[Any], "Item"] +CUSTOM_ENCODERS: list[Encoder] = [] + + +class _ConvertError(TypeError, ValueError): + """An internal error raised when item() fails to convert a value. + It should be a TypeError, but due to historical reasons + it needs to subclass ValueError as well. + """ @overload @@ -218,8 +228,20 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I Trivia(), value.isoformat(), ) + else: + for encoder in CUSTOM_ENCODERS: + try: + rv = encoder(value) + except TypeError: + pass + else: + if not isinstance(rv, Item): + raise _ConvertError( + f"Custom encoder returned {type(rv)}, not a subclass of Item" + ) + return rv - raise ValueError(f"Invalid type {type(value)}") + raise _ConvertError(f"Invalid type {type(value)}") class StringType(Enum):