From 701eab24140db379addd63af4c91137c53551faa Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Mon, 3 Jul 2023 22:28:27 -0500 Subject: [PATCH] Support encoding decimals as numbers Previously we always encoded `decimal.Decimal` values as strings. While this is the best format for these types (prevents precision loss, widely supported), sometimes there is a need to encode `Decimal` objects the same as numbers. This adds a new `decimal_format` option to all `Encoder` classes. This value defaults to `"string"` (to encode decimals as strings), but may be set to `"number"` to encode them the same as the protocol's numeric types. We opt not to support this configuration in the top-level `encode` functions for now, since this is a less-common setting. --- docs/source/supported-types.rst | 21 ++++-- msgspec/_core.c | 119 +++++++++++++++++++++++++++----- msgspec/json.pyi | 6 +- msgspec/msgpack.pyi | 5 +- tests/basic_typing_examples.py | 12 ++++ tests/test_common.py | 40 +++++++++++ tests/test_json.py | 5 ++ 7 files changed, 184 insertions(+), 24 deletions(-) diff --git a/docs/source/supported-types.rst b/docs/source/supported-types.rst index 3cdd1157..fde393e1 100644 --- a/docs/source/supported-types.rst +++ b/docs/source/supported-types.rst @@ -419,8 +419,8 @@ When decoding, both hyphenated and unhyphenated forms are supported. ----------- `decimal.Decimal` values are encoded as their string representation in all -protocols. This ensures no precision loss during serialization, as would happen -with a float representation. +protocols by default. This ensures no precision loss during serialization, as +would happen with a float representation. .. code-block:: python @@ -441,6 +441,19 @@ with a float representation. File "<stdin>", line 1, in <module> msgspec.ValidationError: Invalid decimal string +For JSON and MessagePack you may instead encode decimal values the same as +numbers by creating an ``Encoder`` and specifying ``decimal_format='number'``. + +.. 
code-block:: python + + >>> encoder = msgspec.json.Encoder(decimal_format="number") + + >>> encoder.encode(x) + b'1.2345' + +This setting is not yet supported for YAML or TOML - if this option is +important for you please `open an issue`_. + All protocols will also decode `decimal.Decimal` values from ``int`` or ``float`` inputs. For JSON the value is parsed directly from the serialized bytes, avoiding any precision loss: @@ -1123,8 +1136,7 @@ Union restrictions are as follows: - Unions may contain at most one type that encodes to a string (`str`, `enum.Enum`, `bytes`, `bytearray`, `datetime.datetime`, `datetime.date`, `datetime.time`, `uuid.UUID`, `decimal.Decimal`). Note that this restriction - is fixable with some work, if this is a feature you need please `open an - issue <https://github.com/jcrist/msgspec/issues>`__. + is fixable with some work, if this is a feature you need please `open an issue`_. - Unions may contain at most one type that encodes to an object (`dict`, `typing.TypedDict`, dataclasses_, attrs_, `Struct` with ``array_like=False``) @@ -1339,3 +1351,4 @@ TOML_ types are decoded to Python types as follows: .. _pyright: https://github.com/microsoft/pyright .. _generic types: .. _user-defined generic types: https://docs.python.org/3/library/typing.html#user-defined-generic-types +.. 
_open an issue: https://github.com/jcrist/msgspec/issues diff --git a/msgspec/_core.c b/msgspec/_core.c index f48739fe..3d8ab68a 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -8205,6 +8205,7 @@ static PyTypeObject Ext_Type = { typedef struct EncoderState { MsgspecState *mod; /* module reference */ PyObject *enc_hook; /* `enc_hook` callback */ + bool decimal_as_string; /* true if decimals should encode as strings */ char* (*resize_buffer)(PyObject**, Py_ssize_t); /* callback for resizing buffer */ char *output_buffer_raw; /* raw pointer to output_buffer internal buffer */ @@ -8217,6 +8218,7 @@ typedef struct Encoder { PyObject_HEAD PyObject *enc_hook; MsgspecState *mod; + bool decimal_as_string; /* true if decimals should encode as strings */ } Encoder; static char* @@ -8269,10 +8271,14 @@ ms_write(EncoderState *self, const char *s, Py_ssize_t n) static int Encoder_init(Encoder *self, PyObject *args, PyObject *kwds) { - char *kwlist[] = {"enc_hook", NULL}; - PyObject *enc_hook = NULL; + char *kwlist[] = {"enc_hook", "decimal_format", NULL}; + PyObject *enc_hook = NULL, *decimal_format = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|$O", kwlist, &enc_hook)) { + if ( + !PyArg_ParseTupleAndKeywords( + args, kwds, "|$OO", kwlist, &enc_hook, &decimal_format + ) + ) { return -1; } @@ -8287,6 +8293,30 @@ Encoder_init(Encoder *self, PyObject *args, PyObject *kwds) Py_INCREF(enc_hook); } + if (decimal_format == NULL) { + self->decimal_as_string = true; + } + else { + bool ok = false; + if (PyUnicode_CheckExact(decimal_format)) { + if (PyUnicode_CompareWithASCIIString(decimal_format, "string") == 0) { + self->decimal_as_string = true; + ok = true; + } + else if (PyUnicode_CompareWithASCIIString(decimal_format, "number") == 0) { + self->decimal_as_string = false; + ok = true; + } + } + if (!ok) { + PyErr_Format( + PyExc_ValueError, + "`decimal_format` must be 'string' or 'number', got %R", + decimal_format + ); + return -1; + } + } self->mod = 
msgspec_get_global_state(); self->enc_hook = enc_hook; return 0; @@ -8374,6 +8404,7 @@ encoder_encode_into_common( EncoderState state = { .mod = self->mod, .enc_hook = self->enc_hook, + .decimal_as_string = self->decimal_as_string, .output_buffer = buf, .output_buffer_raw = PyByteArray_AS_STRING(buf), .output_len = offset, @@ -8416,6 +8447,7 @@ encoder_encode_common( EncoderState state = { .mod = self->mod, .enc_hook = self->enc_hook, + .decimal_as_string = self->decimal_as_string, .output_len = 0, .max_output_len = ENC_INIT_BUFSIZE, .resize_buffer = &ms_resize_bytes @@ -8469,6 +8501,7 @@ encode_common( EncoderState state = { .mod = mod, .enc_hook = enc_hook, + .decimal_as_string = true, .output_len = 0, .max_output_len = ENC_INIT_BUFSIZE, .resize_buffer = &ms_resize_bytes @@ -8490,6 +8523,20 @@ static PyMemberDef Encoder_members[] = { {NULL}, }; +static PyObject* +Encoder_decimal_format(Encoder *self, void *closure) +{ + if (self->decimal_as_string) { + return PyUnicode_InternFromString("string"); + } + return PyUnicode_InternFromString("number"); +} + +static PyGetSetDef Encoder_getset[] = { + {"decimal_format", (getter) Encoder_decimal_format, NULL, NULL, NULL}, + {NULL}, +}; + /************************************************************************* * Shared Decoding Utilities * *************************************************************************/ @@ -10234,7 +10281,7 @@ ms_decode_str_lax( *************************************************************************/ PyDoc_STRVAR(Encoder__doc__, -"Encoder(*, enc_hook=None)\n" +"Encoder(*, enc_hook=None, decimal_format='string')\n" "--\n" "\n" "A MessagePack encoder.\n" @@ -10244,7 +10291,12 @@ PyDoc_STRVAR(Encoder__doc__, "enc_hook : callable, optional\n" " A callable to call for objects that aren't supported msgspec types. Takes\n" " the unsupported object and should return a supported object, or raise a\n" -" ``NotImplementedError`` if unsupported." 
+" ``NotImplementedError`` if unsupported.\n" +"decimal_format : {'string', 'number'}, optional\n" +" The format to use for encoding `decimal.Decimal` objects. If 'string'\n" +" they're encoded as strings, if 'number', they're encoded as floats.\n" +" Defaults to 'string', which is the recommended value since 'number'\n" +" may result in precision loss when decoding." ); enum mpack_code { @@ -10955,10 +11007,20 @@ mpack_encode_uuid(EncoderState *self, PyObject *obj) static int mpack_encode_decimal(EncoderState *self, PyObject *obj) { - PyObject *str = PyObject_Str(obj); - if (str == NULL) return -1; - int out = mpack_encode_str(self, str); - Py_DECREF(str); + PyObject *temp; + int out; + + if (MS_LIKELY(self->decimal_as_string)) { + temp = PyObject_Str(obj); + if (temp == NULL) return -1; + out = mpack_encode_str(self, temp); + } + else { + temp = PyNumber_Float(obj); + if (temp == NULL) return -1; + out = mpack_encode_float(self, temp); + } + Py_DECREF(temp); return out; } @@ -11202,6 +11264,7 @@ static PyTypeObject Encoder_Type = { .tp_init = (initproc)Encoder_init, .tp_methods = Encoder_methods, .tp_members = Encoder_members, + .tp_getset = Encoder_getset, }; PyDoc_STRVAR(msgspec_msgpack_encode__doc__, @@ -11239,7 +11302,7 @@ msgspec_msgpack_encode(PyObject *self, PyObject *const *args, Py_ssize_t nargs, *************************************************************************/ PyDoc_STRVAR(JSONEncoder__doc__, -"Encoder(*, enc_hook=None)\n" +"Encoder(*, enc_hook=None, decimal_format='string')\n" "--\n" "\n" "A JSON encoder.\n" @@ -11249,7 +11312,13 @@ PyDoc_STRVAR(JSONEncoder__doc__, "enc_hook : callable, optional\n" " A callable to call for objects that aren't supported msgspec types. Takes\n" " the unsupported object and should return a supported object, or raise a\n" -" ``NotImplementedError`` if unsupported." 
+" ``NotImplementedError`` if unsupported.\n" +"decimal_format : {'string', 'number'}, optional\n" +" The format to use for encoding `decimal.Decimal` objects. If 'string'\n" +" they're encoded as strings, if 'number', they're encoded as floats.\n" +" Defaults to 'string', which is the recommended value since 'number'\n" +" may result in precision loss when decoding for some JSON library\n" +" implementations." ); static int json_encode_inline(EncoderState*, PyObject*); @@ -11555,11 +11624,28 @@ json_encode_uuid(EncoderState *self, PyObject *obj) static int json_encode_decimal(EncoderState *self, PyObject *obj) { - PyObject *str = PyObject_Str(obj); - if (str == NULL) return -1; - int out = json_encode_str_nocheck(self, str); - Py_DECREF(str); - return out; + PyObject *temp = PyObject_Str(obj); + if (temp == NULL) return -1; + + Py_ssize_t size; + const char* buf = unicode_str_and_size_nocheck(temp, &size); + + Py_ssize_t required = size + (2 * self->decimal_as_string); + if (ms_ensure_space(self, size + 2) < 0) { + Py_DECREF(temp); + return -1; + } + + char *p = self->output_buffer_raw + self->output_len; + if (MS_LIKELY(self->decimal_as_string)) *p++ = '"'; + memcpy(p, buf, size); + if (MS_LIKELY(self->decimal_as_string)) *(p + size) = '"'; + + self->output_len += required; + + Py_DECREF(temp); + + return 0; } static int @@ -12052,6 +12138,7 @@ static PyTypeObject JSONEncoder_Type = { .tp_init = (initproc)Encoder_init, .tp_methods = JSONEncoder_methods, .tp_members = Encoder_members, + .tp_getset = Encoder_getset, }; PyDoc_STRVAR(msgspec_json_encode__doc__, diff --git a/msgspec/json.pyi b/msgspec/json.pyi index 07c6f9f4..58f9111d 100644 --- a/msgspec/json.pyi +++ b/msgspec/json.pyi @@ -4,6 +4,7 @@ from typing import ( Callable, Dict, Generic, + Literal, Optional, Tuple, Type, @@ -19,12 +20,13 @@ dec_hook_sig = Optional[Callable[[type, Any], Any]] class Encoder: enc_hook: enc_hook_sig - write_buffer_size: int + decimal_format: Literal["string", "number"] + def 
__init__( self, *, enc_hook: enc_hook_sig = None, - write_buffer_size: int = ..., + decimal_format: Literal["string", "number"] = "string", ): ... def encode(self, obj: Any) -> bytes: ... def encode_into( diff --git a/msgspec/msgpack.pyi b/msgspec/msgpack.pyi index 17341a12..d492153d 100644 --- a/msgspec/msgpack.pyi +++ b/msgspec/msgpack.pyi @@ -2,6 +2,7 @@ from typing import ( Any, Callable, Generic, + Literal, Optional, Type, TypeVar, @@ -57,12 +58,12 @@ class Decoder(Generic[T]): class Encoder: enc_hook: enc_hook_sig - write_buffer_size: int + decimal_format: Literal["string", "number"] def __init__( self, *, enc_hook: enc_hook_sig = None, - write_buffer_size: int = ..., + decimal_format: Literal["string", "number"] = "string", ): ... def encode(self, obj: Any) -> bytes: ... def encode_into( diff --git a/tests/basic_typing_examples.py b/tests/basic_typing_examples.py index 77c29006..24331a07 100644 --- a/tests/basic_typing_examples.py +++ b/tests/basic_typing_examples.py @@ -641,6 +641,12 @@ def check_msgpack_Encoder_enc_hook() -> None: msgspec.msgpack.Encoder(enc_hook=lambda x: None) +def check_msgpack_Encoder_decimal_format() -> None: + enc = msgspec.msgpack.Encoder(decimal_format="string") + msgspec.msgpack.Encoder(decimal_format="number") + reveal_type(enc.decimal_format) # assert "string" in typ.lower() and "number" in typ.lower() + + def check_msgpack_decode_dec_hook() -> None: def dec_hook(typ: Type, obj: Any) -> Any: return typ(obj) @@ -770,6 +776,12 @@ def check_json_Encoder_enc_hook() -> None: msgspec.json.Encoder(enc_hook=lambda x: None) +def check_json_Encoder_decimal_format() -> None: + enc = msgspec.json.Encoder(decimal_format="string") + msgspec.json.Encoder(decimal_format="number") + reveal_type(enc.decimal_format) # assert "string" in typ.lower() and "number" in typ.lower() + + def check_json_decode_dec_hook() -> None: def dec_hook(typ: Type, obj: Any) -> Any: return typ(obj) diff --git a/tests/test_common.py b/tests/test_common.py index 
774e57a5..edb03029 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -3184,6 +3184,46 @@ def test_decode_annotated_newtype_annotated(self, proto, Annotated): class TestDecimal: + def test_encoder_decimal_format(self, proto): + assert proto.Encoder().decimal_format == "string" + assert proto.Encoder(decimal_format="string").decimal_format == "string" + assert proto.Encoder(decimal_format="number").decimal_format == "number" + + def test_encoder_invalid_decimal_format(self, proto): + with pytest.raises(ValueError, match="must be 'string' or 'number', got 'bad'"): + proto.Encoder(decimal_format="bad") + + with pytest.raises(ValueError, match="must be 'string' or 'number', got 1"): + proto.Encoder(decimal_format=1) + + def test_encoder_encode_decimal(self, proto): + enc = proto.Encoder() + d = decimal.Decimal("1.5") + s = str(d) + assert enc.encode(d) == enc.encode(s) + + def test_Encoder_encode_decimal_string(self, proto): + enc = proto.Encoder(decimal_format="string") + d = decimal.Decimal("1.5") + sol = enc.encode(str(d)) + + assert enc.encode(d) == sol + + buf = bytearray() + enc.encode_into(d, buf) + assert buf == sol + + def test_Encoder_encode_decimal_number(self, proto): + enc = proto.Encoder(decimal_format="number") + d = decimal.Decimal("1.5") + sol = enc.encode(float(d)) + + assert enc.encode(d) == sol + + buf = bytearray() + enc.encode_into(d, buf) + assert buf == sol + def test_encode_decimal(self, proto): d = decimal.Decimal("1.5") s = str(d) diff --git a/tests/test_json.py b/tests/test_json.py index 2cf6695a..b2e82e77 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1440,6 +1440,11 @@ class TestDecimal: """Most decimal tests are in test_common.py, the ones here are for json specific behaviors""" + def test_decimal_to_number_keeps_precision(self): + enc = msgspec.json.Encoder(decimal_format="number") + msg = enc.encode(Decimal("1.3000")) + assert msg == b"1.3000" + @pytest.mark.parametrize( "msg", [