diff --git a/docs/source/supported-types.rst b/docs/source/supported-types.rst index 7d6ead03..63df38d7 100644 --- a/docs/source/supported-types.rst +++ b/docs/source/supported-types.rst @@ -696,9 +696,9 @@ Dict subclasses (`collections.OrderedDict`, for example) are also supported for encoding only. To decode into a ``dict`` subclass you'll need to implement a ``dec_hook`` (see :doc:`extending`). -JSON and TOML only support key types that encode as strings or integers (for -example `str`, `int`, `enum.Enum`, `datetime.datetime`, `uuid.UUID`, ...). -MessagePack and YAML support any hashable for the key type. +JSON and TOML only support key types that encode as strings or numbers (for +example `str`, `int`, `float`, `enum.Enum`, `datetime.datetime`, `uuid.UUID`, +...). MessagePack and YAML support any hashable for the key type. An error is raised during decoding if the keys or values don't match their respective types (if specified). diff --git a/msgspec/_core.c b/msgspec/_core.c index c275eb95..1f5d5230 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -10774,7 +10774,7 @@ ms_decode_decimal_from_float(double val, PathNode *path, MsgspecState *mod) { /* For finite values, render as the nearest IEEE754 double in string * form, then call decimal.Decimal to parse */ char buf[24]; - int n = write_f64(val, buf); + int n = write_f64(val, buf, false); return ms_decode_decimal(buf, n, true, path, mod); } else { @@ -11036,7 +11036,7 @@ parse_number_nonfinite( PathNode *path, bool strict ) { - ssize_t size = pend - p; + size_t size = pend - p; double val; if (size == 3) { if ( @@ -12565,10 +12565,24 @@ static MS_NOINLINE int json_encode_float(EncoderState *self, PyObject *obj) { char buf[24]; double x = PyFloat_AS_DOUBLE(obj); - int n = write_f64(x, buf); + int n = write_f64(x, buf, false); return ms_write(self, buf, n); } +static MS_NOINLINE int +json_encode_float_as_str(EncoderState *self, PyObject *obj) { + char buf[24]; + double x = PyFloat_AS_DOUBLE(obj); + int size = write_f64(x, buf, true); + if (ms_ensure_space(self, size + 2) < 0) return -1; + char *p = self->output_buffer_raw + self->output_len; + *p++ = '"'; + memcpy(p, buf, size); + *(p + size) = '"'; + self->output_len += size + 2; + return 0; +} + static MS_INLINE int json_encode_cstr(EncoderState *self, const char *str, Py_ssize_t size) { if (ms_ensure_space(self, size + 2) < 0) return -1; @@ -12940,6 +12954,9 @@ json_encode_dict_key(EncoderState *self, PyObject *obj) { if (type == &PyLong_Type) { return json_encode_long_as_str(self, obj); } + else if (type == &PyFloat_Type) { + return json_encode_float_as_str(self, obj); + } else if (Py_TYPE(type) == self->mod->EnumMetaType) { return json_encode_enum(self, obj, true); } @@ -12967,7 +12984,7 @@ json_encode_dict_key(EncoderState *self, PyObject *obj) { else { PyErr_SetString( PyExc_TypeError, - "Only dicts with str-like or int-like keys are supported" + "Only dicts with str-like or number-like keys are supported" ); return -1; } @@ -18560,7 +18577,7 @@ to_builtins_dict(ToBuiltinsState *self, PyObject *obj) { new_key = to_builtins(self, key, true); if (new_key == NULL) goto cleanup; if (self->str_keys) { - if (PyLong_CheckExact(new_key)) { + if (PyLong_CheckExact(new_key) || PyFloat_CheckExact(new_key)) { PyObject *temp = PyObject_Str(new_key); if (temp == NULL) goto cleanup; Py_DECREF(new_key); @@ -18569,7 +18586,7 @@ to_builtins_dict(ToBuiltinsState *self, PyObject *obj) { else if (!PyUnicode_CheckExact(new_key)) { PyErr_SetString( PyExc_TypeError, - "Only dicts with `str` or `int` keys are supported" + "Only dicts with str-like or number-like keys are supported" ); goto cleanup; } diff --git a/msgspec/ryu.h b/msgspec/ryu.h index 6e7f270c..c6a5a7e0 100644 --- a/msgspec/ryu.h +++ b/msgspec/ryu.h @@ -919,16 +919,30 @@ write_exponent(int32_t k, char* buf) { /* Write a double to buf, requires 24 bytes of space */ static inline int -write_f64(double f, char* buf) { +write_f64(double f, char* buf, bool allow_nonfinite) { const uint64_t bits = double_to_bits(f); const int sign = ((bits >> (DOUBLE_MANTISSA_BITS + DOUBLE_EXPONENT_BITS)) & 1) != 0; const uint64_t ieee_mantissa = bits & ((1ull << DOUBLE_MANTISSA_BITS) - 1); const uint32_t ieee_exponent = (uint32_t) ((bits >> DOUBLE_MANTISSA_BITS) & ((1u << DOUBLE_EXPONENT_BITS) - 1)); /* Serialize all non-finite numbers as null */ - if (ieee_exponent == ((1 << DOUBLE_EXPONENT_BITS) - 1)) { - memcpy(buf, "null", 4); - return 4; + if (MS_UNLIKELY(ieee_exponent == ((1 << DOUBLE_EXPONENT_BITS) - 1))) { + if (MS_LIKELY(!allow_nonfinite)) { + memcpy(buf, "null", 4); + return 4; + } + else { + if (ieee_mantissa == 0) { + if (sign) { + memcpy(buf, "-inf", 4); + return 4; + } + memcpy(buf, "inf", 3); + return 3; + } + memcpy(buf, "nan", 3); + return 3; + } } if (sign) { diff --git a/tests/test_common.py b/tests/test_common.py index 3bd7a62d..72e11453 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -3608,12 +3608,15 @@ def test_encode_decimal(self, proto): s = str(d) assert proto.encode(d) == proto.encode(s) - def test_decode_decimal_str(self, proto): - d = decimal.Decimal("1.5") - msg = proto.encode(d) + @pytest.mark.parametrize( + "val", ["1.5", "InF", "-iNf", "iNfInItY", "-InFiNiTy", "NaN"] + ) + def test_decode_decimal_str(self, val, proto): + sol = decimal.Decimal(val) + msg = proto.encode(sol) res = proto.decode(msg, type=decimal.Decimal) + assert str(res) == str(sol) assert type(res) is decimal.Decimal - assert res == d def test_decode_decimal_str_invalid(self, proto): msg = proto.encode("1..5") diff --git a/tests/test_json.py b/tests/test_json.py index 634aff18..7ffa0b03 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1800,11 +1800,12 @@ class Test(NamedTuple): class TestDict: - def test_encode_dict_raises_non_string_or_int_keys(self): + def test_encode_dict_raises_non_string_or_numeric_keys(self): with pytest.raises( - TypeError, match="Only dicts with str-like or int-like keys are supported" + TypeError, + match="Only dicts with str-like or number-like keys are supported", ): - msgspec.json.encode({"a": 1, 2.5: "bad"}) + msgspec.json.encode({"a": 1, (1, 2): "bad"}) @pytest.mark.parametrize("x", [{}, {"a": 1}, {"a": 1, "b": 2}]) def test_roundtrip_dict(self, x): @@ -1896,11 +1897,13 @@ def test_decode_dict_string_cache_ascii_only(self): "key", [ 1, + 1.5, FruitInt.APPLE, uuid.uuid4(), datetime.datetime.now(), datetime.date.today(), datetime.datetime.now().time(), + datetime.timedelta(1.5), b"test", Decimal("1.5"), ], @@ -1973,6 +1976,19 @@ def test_decode_dict_int_literal_key(self): with pytest.raises(msgspec.ValidationError, match="Invalid enum value 3"): dec.decode(b'{"-1": 10, "3": 20}') + def test_encode_dict_float_key(self): + msg = { + 1.5: 1, + -1.5: 2, + 0.0: 3, + float("-inf"): 4, + float("inf"): 5, + float("nan"): 6, + } + sol = msgspec.json.encode({str(k): v for k, v in msg.items()}) + res = msgspec.json.encode(msg) + assert res == sol + def test_decode_dict_float_key(self): msg = {"1.5": 1, "inf": 2, "-inf": 3, "0": 4, "-1.5e12": 5, "123": 6} buf = msgspec.json.encode(msg) @@ -1980,6 +1996,13 @@ def test_decode_dict_float_key(self): res = msgspec.json.decode(buf, type=Dict[float, int]) assert res == sol + def test_decode_dict_int_or_float_key(self): + buf = b'{"1.5": "a", "123": "b"}' + sol = {1.5: "a", 123: "b"} + res = msgspec.json.decode(buf, type=Dict[Union[int, float], str]) + assert res == sol + assert type(list(res.keys())[-1]) is int + def test_encode_dict_str_subclass_key(self): class mystr(str): pass diff --git a/tests/test_to_builtins.py b/tests/test_to_builtins.py index 046395e3..628f2e23 100644 --- a/tests/test_to_builtins.py +++ b/tests/test_to_builtins.py @@ -296,17 +296,13 @@ def test_dict_str_keys(self): assert to_builtins({FruitInt.BANANA: 1}, str_keys=True) == {"2": 1} assert to_builtins({2: 1}, str_keys=True) == {"2": 1} - with pytest.raises( - TypeError, match="Only dicts with `str` or `int` keys are supported" - ): - to_builtins({(1, 2): 3}, str_keys=True) - def test_dict_sequence_keys(self): msg = {frozenset([1, 2]): 1} assert to_builtins(msg) == {(1, 2): 1} with pytest.raises( - TypeError, match="Only dicts with `str` or `int` keys are supported" + TypeError, + match="Only dicts with str-like or number-like keys are supported", ): to_builtins(msg, str_keys=True)