diff --git a/msgspec/_core.c b/msgspec/_core.c
index 4b964d38..221d9a57 100644
--- a/msgspec/_core.c
+++ b/msgspec/_core.c
@@ -10174,6 +10174,7 @@ enum mpack_code {
 };
 
 static int mpack_encode_inline(EncoderState *self, PyObject *obj);
+static int mpack_encode_dict_key_inline(EncoderState *self, PyObject *obj);
 static int mpack_encode(EncoderState *self, PyObject *obj);
 
 static int
@@ -10532,7 +10533,7 @@ mpack_encode_dict(EncoderState *self, PyObject *obj)
     if (len == 0) return 0;
     if (Py_EnterRecursiveCall(" while serializing an object")) return -1;
     while (PyDict_Next(obj, &pos, &key, &val)) {
-        if (mpack_encode_inline(self, key) < 0) goto error;
+        if (mpack_encode_dict_key_inline(self, key) < 0) goto error;
         if (mpack_encode_inline(self, val) < 0) goto error;
     }
     status = 0;
@@ -11023,6 +11024,31 @@ mpack_encode_inline(EncoderState *self, PyObject *obj)
     }
 }
 
+static MS_INLINE int
+mpack_encode_dict_key_inline(EncoderState *self, PyObject *obj)
+{
+    PyTypeObject *type = Py_TYPE(obj);
+
+    if (PyUnicode_Check(obj)) {
+        return mpack_encode_str(self, obj);
+    }
+    else if (type == &PyLong_Type) {
+        return mpack_encode_long(self, obj);
+    }
+    else if (type == &PyFloat_Type) {
+        return mpack_encode_float(self, obj);
+    }
+    else if (PyList_Check(obj)) {
+        return mpack_encode_list(self, obj);
+    }
+    else if (PyDict_Check(obj)) {
+        return mpack_encode_dict(self, obj);
+    }
+    else {
+        return mpack_encode_uncommon(self, type, obj);
+    }
+}
+
 static int
 mpack_encode(EncoderState *self, PyObject *obj) {
     return mpack_encode_inline(self, obj);
@@ -11555,7 +11581,7 @@ json_encode_dict(EncoderState *self, PyObject *obj)
     if (ms_write(self, "{", 1) < 0) return -1;
     if (Py_EnterRecursiveCall(" while serializing an object")) return -1;
     while (PyDict_Next(obj, &pos, &key, &val)) {
-        if (MS_LIKELY(PyUnicode_CheckExact(key))) {
+        if (MS_LIKELY(PyUnicode_Check(key))) {
             if (json_encode_str(self, key) < 0) goto cleanup;
         }
         else {
@@ -17460,6 +17486,10 @@ to_builtins(ToBuiltinsState *self, PyObject *obj, bool is_key) {
     else if (Py_TYPE(type) == self->mod->EnumMetaType) {
         return to_builtins_enum(self, obj);
     }
+    else if (is_key && PyUnicode_Check(obj)) {
+        /* Copy to an exact `str`; PyObject_Str would honor a `__str__` override */
+        return PyUnicode_FromObject(obj);
+    }
     else if (PyType_IsSubtype(type, (PyTypeObject *)(self->mod->UUIDType))) {
         if (self->builtin_types & MS_BUILTIN_UUID) goto builtin;
         return to_builtins_uuid(self, obj);
diff --git a/tests/test_json.py b/tests/test_json.py
index b6344495..a81a3f9b 100644
--- a/tests/test_json.py
+++ b/tests/test_json.py
@@ -1840,6 +1840,13 @@ def test_decode_dict_int_literal_key(self):
         with pytest.raises(msgspec.ValidationError, match="Invalid enum value 3"):
             dec.decode(b'{"-1": 10, "3": 20}')
 
+    def test_encode_dict_str_subclass_key(self):
+        class mystr(str):
+            pass
+
+        msg = msgspec.json.encode({mystr("test"): 1})
+        assert msg == b'{"test":1}'
+
     @pytest.mark.parametrize(
         "s, error",
         [
diff --git a/tests/test_msgpack.py b/tests/test_msgpack.py
index 6727a961..5e6e5c01 100644
--- a/tests/test_msgpack.py
+++ b/tests/test_msgpack.py
@@ -1026,6 +1026,14 @@ def test_dict_any_key(self):
         ):
             dec.decode(enc.encode({1: 2}))
 
+    def test_dict_str_subclass_key(self):
+        class mystr(str):
+            pass
+
+        msg1 = msgspec.msgpack.encode({mystr("test"): 1})
+        msg2 = msgspec.msgpack.encode({"test": 1})
+        assert msg1 == msg2
+
     def test_dict_typed(self):
         enc = msgspec.msgpack.Encoder()
         dec = msgspec.msgpack.Decoder(Dict[str, int])
diff --git a/tests/test_to_builtins.py b/tests/test_to_builtins.py
index 004ef0d7..ae4769d4 100644
--- a/tests/test_to_builtins.py
+++ b/tests/test_to_builtins.py
@@ -262,6 +262,14 @@ class in_type(dict):
         res = to_builtins(in_type())
         assert res == {}
 
+    def test_dict_str_subclass_key(self):
+        class mystr(str):
+            pass
+
+        msg = to_builtins({mystr("test"): 1})
+        assert msg == {"test": 1}
+        assert type(list(msg.keys())[0]) is str
+
     def test_dict_unsupported_key(self):
         msg = {Bad(): 1}
         with pytest.raises(TypeError, match="Encoding objects of type Bad"):