Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support encoding decimals as numbers #465

Merged
merged 1 commit into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions docs/source/supported-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ When decoding, both hyphenated and unhyphenated forms are supported.
-----------

`decimal.Decimal` values are encoded as their string representation in all
protocols. This ensures no precision loss during serialization, as would happen
with a float representation.
protocols by default. This ensures no precision loss during serialization, as
would happen with a float representation.

.. code-block:: python

Expand All @@ -441,6 +441,19 @@ with a float representation.
File "<stdin>", line 1, in <module>
msgspec.ValidationError: Invalid decimal string

For JSON and MessagePack you may instead encode decimal values the same as
numbers by creating an ``Encoder`` and specifying ``decimal_format='number'``.

.. code-block:: python

>>> encoder = msgspec.json.Encoder(decimal_format="number")

>>> encoder.encode(x)
b'1.2345'

This setting is not yet supported for YAML or TOML. If this option is
important for you, please `open an issue`_.

All protocols will also decode `decimal.Decimal` values from ``int`` or
``float`` inputs. For JSON the value is parsed directly from the serialized
bytes, avoiding any precision loss:
Expand Down Expand Up @@ -1123,8 +1136,7 @@ Union restrictions are as follows:
- Unions may contain at most one type that encodes to a string (`str`,
`enum.Enum`, `bytes`, `bytearray`, `datetime.datetime`, `datetime.date`,
`datetime.time`, `uuid.UUID`, `decimal.Decimal`). Note that this restriction
is fixable with some work, if this is a feature you need please `open an
issue <https://github.com/jcrist/msgspec/issues>`__.
is fixable with some work; if this is a feature you need, please `open an issue`_.

- Unions may contain at most one type that encodes to an object (`dict`,
`typing.TypedDict`, dataclasses_, attrs_, `Struct` with ``array_like=False``)
Expand Down Expand Up @@ -1339,3 +1351,4 @@ TOML_ types are decoded to Python types as follows:
.. _pyright: https://github.com/microsoft/pyright
.. _generic types:
.. _user-defined generic types: https://docs.python.org/3/library/typing.html#user-defined-generic-types
.. _open an issue: https://github.com/jcrist/msgspec/issues
119 changes: 103 additions & 16 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -8205,6 +8205,7 @@ static PyTypeObject Ext_Type = {
typedef struct EncoderState {
MsgspecState *mod; /* module reference */
PyObject *enc_hook; /* `enc_hook` callback */
bool decimal_as_string; /* true if decimals should encode as strings */
char* (*resize_buffer)(PyObject**, Py_ssize_t); /* callback for resizing buffer */

char *output_buffer_raw; /* raw pointer to output_buffer internal buffer */
Expand All @@ -8217,6 +8218,7 @@ typedef struct Encoder {
PyObject_HEAD
PyObject *enc_hook;
MsgspecState *mod;
bool decimal_as_string; /* true if decimals should encode as strings */
} Encoder;

static char*
Expand Down Expand Up @@ -8269,10 +8271,14 @@ ms_write(EncoderState *self, const char *s, Py_ssize_t n)
static int
Encoder_init(Encoder *self, PyObject *args, PyObject *kwds)
{
char *kwlist[] = {"enc_hook", NULL};
PyObject *enc_hook = NULL;
char *kwlist[] = {"enc_hook", "decimal_format", NULL};
PyObject *enc_hook = NULL, *decimal_format = NULL;

if (!PyArg_ParseTupleAndKeywords(args, kwds, "|$O", kwlist, &enc_hook)) {
if (
!PyArg_ParseTupleAndKeywords(
args, kwds, "|$OO", kwlist, &enc_hook, &decimal_format
)
) {
return -1;
}

Expand All @@ -8287,6 +8293,30 @@ Encoder_init(Encoder *self, PyObject *args, PyObject *kwds)
Py_INCREF(enc_hook);
}

if (decimal_format == NULL) {
self->decimal_as_string = true;
}
else {
bool ok = false;
if (PyUnicode_CheckExact(decimal_format)) {
if (PyUnicode_CompareWithASCIIString(decimal_format, "string") == 0) {
self->decimal_as_string = true;
ok = true;
}
else if (PyUnicode_CompareWithASCIIString(decimal_format, "number") == 0) {
self->decimal_as_string = false;
ok = true;
}
}
if (!ok) {
PyErr_Format(
PyExc_ValueError,
"`decimal_format` must be 'string' or 'number', got %R",
decimal_format
);
return -1;
}
}
self->mod = msgspec_get_global_state();
self->enc_hook = enc_hook;
return 0;
Expand Down Expand Up @@ -8374,6 +8404,7 @@ encoder_encode_into_common(
EncoderState state = {
.mod = self->mod,
.enc_hook = self->enc_hook,
.decimal_as_string = self->decimal_as_string,
.output_buffer = buf,
.output_buffer_raw = PyByteArray_AS_STRING(buf),
.output_len = offset,
Expand Down Expand Up @@ -8416,6 +8447,7 @@ encoder_encode_common(
EncoderState state = {
.mod = self->mod,
.enc_hook = self->enc_hook,
.decimal_as_string = self->decimal_as_string,
.output_len = 0,
.max_output_len = ENC_INIT_BUFSIZE,
.resize_buffer = &ms_resize_bytes
Expand Down Expand Up @@ -8469,6 +8501,7 @@ encode_common(
EncoderState state = {
.mod = mod,
.enc_hook = enc_hook,
.decimal_as_string = true,
.output_len = 0,
.max_output_len = ENC_INIT_BUFSIZE,
.resize_buffer = &ms_resize_bytes
Expand All @@ -8490,6 +8523,20 @@ static PyMemberDef Encoder_members[] = {
{NULL},
};

/* Getter for the read-only `decimal_format` attribute.
 *
 * Maps the encoder's `decimal_as_string` flag back to the keyword value it
 * corresponds to ("string" or "number") and returns it as an interned
 * Python str. `closure` is unused.
 */
static PyObject*
Encoder_decimal_format(Encoder *self, void *closure)
{
    const char *format_name = self->decimal_as_string ? "string" : "number";
    return PyUnicode_InternFromString(format_name);
}

static PyGetSetDef Encoder_getset[] = {
{"decimal_format", (getter) Encoder_decimal_format, NULL, NULL, NULL},
{NULL},
};

/*************************************************************************
* Shared Decoding Utilities *
*************************************************************************/
Expand Down Expand Up @@ -10234,7 +10281,7 @@ ms_decode_str_lax(
*************************************************************************/

PyDoc_STRVAR(Encoder__doc__,
"Encoder(*, enc_hook=None)\n"
"Encoder(*, enc_hook=None, decimal_format='string')\n"
"--\n"
"\n"
"A MessagePack encoder.\n"
Expand All @@ -10244,7 +10291,12 @@ PyDoc_STRVAR(Encoder__doc__,
"enc_hook : callable, optional\n"
" A callable to call for objects that aren't supported msgspec types. Takes\n"
" the unsupported object and should return a supported object, or raise a\n"
" ``NotImplementedError`` if unsupported."
" ``NotImplementedError`` if unsupported.\n"
"decimal_format : {'string', 'number'}, optional\n"
" The format to use for encoding `decimal.Decimal` objects. If 'string'\n"
" they're encoded as strings, if 'number', they're encoded as floats.\n"
" Defaults to 'string', which is the recommended value since 'number'\n"
" may result in precision loss when decoding."
);

enum mpack_code {
Expand Down Expand Up @@ -10955,10 +11007,20 @@ mpack_encode_uuid(EncoderState *self, PyObject *obj)
/* Encode a `decimal.Decimal` as MessagePack.
 *
 * NOTE: this span is a diff hunk rendered without +/- markers. The four
 * lines that use the local `str` are the pre-change body; everything from
 * `PyObject *temp;` up to (not including) `return out;` is its replacement.
 *
 * Post-change behavior: if `self->decimal_as_string` is set, the value is
 * converted with `PyObject_Str` and written via `mpack_encode_str`;
 * otherwise it is converted with `PyNumber_Float` and written via
 * `mpack_encode_float`. Returns -1 if the conversion fails, else whatever
 * the encode call returned.
 */
static int
mpack_encode_decimal(EncoderState *self, PyObject *obj)
{
PyObject *str = PyObject_Str(obj);
if (str == NULL) return -1;
int out = mpack_encode_str(self, str);
Py_DECREF(str);
PyObject *temp;
int out;

/* String mode is hinted as the likely branch. */
if (MS_LIKELY(self->decimal_as_string)) {
temp = PyObject_Str(obj);
if (temp == NULL) return -1;
out = mpack_encode_str(self, temp);
}
else {
/* Number mode goes through a Python float before encoding. */
temp = PyNumber_Float(obj);
if (temp == NULL) return -1;
out = mpack_encode_float(self, temp);
}
/* The temporary str/float is released whether or not encoding succeeded. */
Py_DECREF(temp);
return out;
}

Expand Down Expand Up @@ -11202,6 +11264,7 @@ static PyTypeObject Encoder_Type = {
.tp_init = (initproc)Encoder_init,
.tp_methods = Encoder_methods,
.tp_members = Encoder_members,
.tp_getset = Encoder_getset,
};

PyDoc_STRVAR(msgspec_msgpack_encode__doc__,
Expand Down Expand Up @@ -11239,7 +11302,7 @@ msgspec_msgpack_encode(PyObject *self, PyObject *const *args, Py_ssize_t nargs,
*************************************************************************/

PyDoc_STRVAR(JSONEncoder__doc__,
"Encoder(*, enc_hook=None)\n"
"Encoder(*, enc_hook=None, decimal_format='string')\n"
"--\n"
"\n"
"A JSON encoder.\n"
Expand All @@ -11249,7 +11312,13 @@ PyDoc_STRVAR(JSONEncoder__doc__,
"enc_hook : callable, optional\n"
" A callable to call for objects that aren't supported msgspec types. Takes\n"
" the unsupported object and should return a supported object, or raise a\n"
" ``NotImplementedError`` if unsupported."
" ``NotImplementedError`` if unsupported.\n"
"decimal_format : {'string', 'number'}, optional\n"
" The format to use for encoding `decimal.Decimal` objects. If 'string'\n"
" they're encoded as strings, if 'number', they're encoded as floats.\n"
" Defaults to 'string', which is the recommended value since 'number'\n"
" may result in precision loss when decoding for some JSON library\n"
" implementations."
);

static int json_encode_inline(EncoderState*, PyObject*);
Expand Down Expand Up @@ -11555,11 +11624,28 @@ json_encode_uuid(EncoderState *self, PyObject *obj)
/* Encode a `decimal.Decimal` as JSON.
 *
 * NOTE: this span is a diff hunk rendered without +/- markers. The first
 * five body lines (through the first `return out;`) are the pre-change
 * body, which always wrote the decimal as a JSON string via
 * `json_encode_str_nocheck`. Everything after them is the replacement.
 *
 * Post-change behavior: the decimal's `str()` text is copied directly into
 * the output buffer, surrounded by double quotes only when
 * `self->decimal_as_string` is set. Returns 0 on success, -1 if the
 * conversion or the buffer reservation fails.
 */
static int
json_encode_decimal(EncoderState *self, PyObject *obj)
{
PyObject *str = PyObject_Str(obj);
if (str == NULL) return -1;
int out = json_encode_str_nocheck(self, str);
Py_DECREF(str);
return out;
PyObject *temp = PyObject_Str(obj);
if (temp == NULL) return -1;

/* Presumably yields the string's raw character data and stores its byte
 * length in `size` -- helper is defined elsewhere; TODO confirm. */
Py_ssize_t size;
const char* buf = unicode_str_and_size_nocheck(temp, &size);

/* Bytes actually written: the text plus two quote characters in string
 * mode (the bool flag contributes 0 or 1 to the multiplication). */
Py_ssize_t required = size + (2 * self->decimal_as_string);
/* NOTE(review): the reservation below uses `size + 2` rather than
 * `required`, so number mode asks for two bytes more than it writes.
 * Looks harmless but inconsistent -- confirm whether `required` was
 * intended here. */
if (ms_ensure_space(self, size + 2) < 0) {
Py_DECREF(temp);
return -1;
}

/* NOTE(review): in number mode the str() text is copied verbatim with no
 * quoting or validation; presumably special values such as NaN/Infinity
 * would be emitted as-is -- confirm that this is intended. */
char *p = self->output_buffer_raw + self->output_len;
if (MS_LIKELY(self->decimal_as_string)) *p++ = '"';
memcpy(p, buf, size);
/* `p` already points past the opening quote (if any), so `p + size` is the
 * slot right after the copied text. */
if (MS_LIKELY(self->decimal_as_string)) *(p + size) = '"';

self->output_len += required;

Py_DECREF(temp);

return 0;
}

static int
Expand Down Expand Up @@ -12052,6 +12138,7 @@ static PyTypeObject JSONEncoder_Type = {
.tp_init = (initproc)Encoder_init,
.tp_methods = JSONEncoder_methods,
.tp_members = Encoder_members,
.tp_getset = Encoder_getset,
};

PyDoc_STRVAR(msgspec_json_encode__doc__,
Expand Down
6 changes: 4 additions & 2 deletions msgspec/json.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ from typing import (
Callable,
Dict,
Generic,
Literal,
Optional,
Tuple,
Type,
Expand All @@ -19,12 +20,13 @@ dec_hook_sig = Optional[Callable[[type, Any], Any]]

class Encoder:
enc_hook: enc_hook_sig
write_buffer_size: int
decimal_format: Literal["string", "number"]

def __init__(
self,
*,
enc_hook: enc_hook_sig = None,
write_buffer_size: int = ...,
decimal_format: Literal["string", "number"] = "string",
): ...
def encode(self, obj: Any) -> bytes: ...
def encode_into(
Expand Down
5 changes: 3 additions & 2 deletions msgspec/msgpack.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ from typing import (
Any,
Callable,
Generic,
Literal,
Optional,
Type,
TypeVar,
Expand Down Expand Up @@ -57,12 +58,12 @@ class Decoder(Generic[T]):

class Encoder:
enc_hook: enc_hook_sig
write_buffer_size: int
decimal_format: Literal["string", "number"]
def __init__(
self,
*,
enc_hook: enc_hook_sig = None,
write_buffer_size: int = ...,
decimal_format: Literal["string", "number"] = "string",
): ...
def encode(self, obj: Any) -> bytes: ...
def encode_into(
Expand Down
12 changes: 12 additions & 0 deletions tests/basic_typing_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,12 @@ def check_msgpack_Encoder_enc_hook() -> None:
msgspec.msgpack.Encoder(enc_hook=lambda x: None)


def check_msgpack_Encoder_decimal_format() -> None:
    """Both ``decimal_format`` literals are accepted and the attribute is typed as their union."""
    msgspec.msgpack.Encoder(decimal_format="number")
    string_encoder = msgspec.msgpack.Encoder(decimal_format="string")
    reveal_type(string_encoder.decimal_format)  # assert "string" in typ.lower() and "number" in typ.lower()


def check_msgpack_decode_dec_hook() -> None:
def dec_hook(typ: Type, obj: Any) -> Any:
return typ(obj)
Expand Down Expand Up @@ -770,6 +776,12 @@ def check_json_Encoder_enc_hook() -> None:
msgspec.json.Encoder(enc_hook=lambda x: None)


def check_json_Encoder_decimal_format() -> None:
    """Both ``decimal_format`` literals are accepted and the attribute is typed as their union."""
    msgspec.json.Encoder(decimal_format="number")
    string_encoder = msgspec.json.Encoder(decimal_format="string")
    reveal_type(string_encoder.decimal_format)  # assert "string" in typ.lower() and "number" in typ.lower()


def check_json_decode_dec_hook() -> None:
def dec_hook(typ: Type, obj: Any) -> Any:
return typ(obj)
Expand Down
40 changes: 40 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3184,6 +3184,46 @@ def test_decode_annotated_newtype_annotated(self, proto, Annotated):


class TestDecimal:
def test_encoder_decimal_format(self, proto):
    """``decimal_format`` defaults to "string" and echoes an explicit setting."""
    default_encoder = proto.Encoder()
    assert default_encoder.decimal_format == "string"

    for fmt in ("string", "number"):
        configured = proto.Encoder(decimal_format=fmt)
        assert configured.decimal_format == fmt

def test_encoder_invalid_decimal_format(self, proto):
    """An unknown string or a non-string ``decimal_format`` raises ``ValueError``."""
    rejected = (
        ("bad", "must be 'string' or 'number', got 'bad'"),
        (1, "must be 'string' or 'number', got 1"),
    )
    for bad_value, message in rejected:
        with pytest.raises(ValueError, match=message):
            proto.Encoder(decimal_format=bad_value)

def test_encoder_encode_decimal(self, proto):
    """A default ``Encoder`` encodes a decimal the same as its ``str()`` form."""
    encoder = proto.Encoder()
    value = decimal.Decimal("1.5")
    as_text = str(value)
    encoded_decimal = encoder.encode(value)
    encoded_text = encoder.encode(as_text)
    assert encoded_decimal == encoded_text

def test_Encoder_encode_decimal_string(self, proto):
    """With ``decimal_format="string"`` a decimal encodes like its ``str()`` form."""
    encoder = proto.Encoder(decimal_format="string")
    value = decimal.Decimal("1.5")
    expected = encoder.encode(str(value))

    # One-shot encode.
    assert encoder.encode(value) == expected

    # Encoding into a caller-provided buffer must give the same bytes.
    target = bytearray()
    encoder.encode_into(value, target)
    assert target == expected

def test_Encoder_encode_decimal_number(self, proto):
    """With ``decimal_format="number"`` a decimal encodes like the equivalent float."""
    encoder = proto.Encoder(decimal_format="number")
    value = decimal.Decimal("1.5")
    expected = encoder.encode(float(value))

    # One-shot encode.
    assert encoder.encode(value) == expected

    # Encoding into a caller-provided buffer must give the same bytes.
    target = bytearray()
    encoder.encode_into(value, target)
    assert target == expected

def test_encode_decimal(self, proto):
d = decimal.Decimal("1.5")
s = str(d)
Expand Down
Loading