Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom builtin_types in to_builtins #517

Merged
merged 1 commit into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 94 additions & 47 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -18459,6 +18459,7 @@ typedef struct {
PyObject *enc_hook;
bool str_keys;
uint32_t builtin_types;
PyObject *builtin_types_seq;
} ToBuiltinsState;

static PyObject * to_builtins(ToBuiltinsState *, PyObject *, bool);
Expand Down Expand Up @@ -18948,6 +18949,14 @@ to_builtins(ToBuiltinsState *self, PyObject *obj, bool is_key) {
}
}

if (self->builtin_types_seq != NULL) {
PyObject **items = PySequence_Fast_ITEMS(self->builtin_types_seq);
Py_ssize_t size = PySequence_Fast_GET_SIZE(self->builtin_types_seq);
for (Py_ssize_t i = 0; i < size; i++) {
if (((PyObject *)type) == *(items + i)) goto builtin;
}
}

if (self->enc_hook != NULL) {
PyObject *out = NULL;
PyObject *temp;
Expand All @@ -18971,51 +18980,76 @@ to_builtins(ToBuiltinsState *self, PyObject *obj, bool is_key) {
}

static int
ms_process_builtin_types(MsgspecState *mod, PyObject *builtin_types, uint32_t *mask) {
if (builtin_types != NULL && builtin_types != Py_None) {
PyObject *seq = PySequence_Fast(
builtin_types, "builtin_types must be an iterable of types"
);
if (seq == NULL) return -1;
Py_ssize_t size = PySequence_Fast_GET_SIZE(seq);
for (Py_ssize_t i = 0; i < size; i++) {
PyObject *type = PySequence_Fast_GET_ITEM(seq, i);
if (type == (PyObject *)(&PyBytes_Type)) {
*mask |= MS_BUILTIN_BYTES;
}
else if (type == (PyObject *)(&PyByteArray_Type)) {
*mask |= MS_BUILTIN_BYTEARRAY;
}
else if (type == (PyObject *)(&PyMemoryView_Type)) {
*mask |= MS_BUILTIN_MEMORYVIEW;
}
else if (type == (PyObject *)(PyDateTimeAPI->DateTimeType)) {
*mask |= MS_BUILTIN_DATETIME;
}
else if (type == (PyObject *)(PyDateTimeAPI->DateType)) {
*mask |= MS_BUILTIN_DATE;
}
else if (type == (PyObject *)(PyDateTimeAPI->TimeType)) {
*mask |= MS_BUILTIN_TIME;
}
else if (type == (PyObject *)(PyDateTimeAPI->DeltaType)) {
*mask |= MS_BUILTIN_TIMEDELTA;
}
else if (type == mod->UUIDType) {
*mask |= MS_BUILTIN_UUID;
}
else if (type == mod->DecimalType) {
*mask |= MS_BUILTIN_DECIMAL;
}
else {
PyErr_Format(PyExc_TypeError, "Cannot treat %R as a builtin type", type);
Py_DECREF(seq);
return -1;
}
ms_process_builtin_types(
    MsgspecState *mod,
    PyObject *builtin_types,
    uint32_t *mask,
    PyObject **custom_types
) {
    /* Translate the user supplied `builtin_types` iterable into encoder state.
     *
     * mod           - module state; supplies the UUID and Decimal type objects.
     * builtin_types - NULL, None, or an iterable of type objects.
     * mask          - OR-ed in place with one MS_BUILTIN_* flag per recognized
     *                 type found in the iterable.
     * custom_types  - optional out-parameter. When non-NULL and the iterable
     *                 holds at least one type object that is not in the
     *                 recognized list, it receives the reference returned by
     *                 PySequence_Fast; the caller is then responsible for
     *                 releasing it. It is not written in any other case.
     *                 When NULL, an unrecognized type is reported as an error.
     *
     * Returns 0 on success (NULL/None input included) and -1 with a Python
     * exception set otherwise. */
    const char *const not_types_msg = "builtin_types must be an iterable of types";

    if (builtin_types == NULL) return 0;
    if (builtin_types == Py_None) return 0;

    PyObject *fast = PySequence_Fast(builtin_types, not_types_msg);
    if (fast == NULL) return -1;

    PyObject **entries = PySequence_Fast_ITEMS(fast);
    const Py_ssize_t count = PySequence_Fast_GET_SIZE(fast);
    bool keep_fast = false;
    int status = -1;

    for (Py_ssize_t idx = 0; idx < count; idx++) {
        PyObject *candidate = entries[idx];

        /* Recognized types only flip a bit in the mask. */
        if (candidate == (PyObject *)(&PyBytes_Type)) {
            *mask |= MS_BUILTIN_BYTES;
            continue;
        }
        if (candidate == (PyObject *)(&PyByteArray_Type)) {
            *mask |= MS_BUILTIN_BYTEARRAY;
            continue;
        }
        if (candidate == (PyObject *)(&PyMemoryView_Type)) {
            *mask |= MS_BUILTIN_MEMORYVIEW;
            continue;
        }
        if (candidate == (PyObject *)(PyDateTimeAPI->DateTimeType)) {
            *mask |= MS_BUILTIN_DATETIME;
            continue;
        }
        if (candidate == (PyObject *)(PyDateTimeAPI->DateType)) {
            *mask |= MS_BUILTIN_DATE;
            continue;
        }
        if (candidate == (PyObject *)(PyDateTimeAPI->TimeType)) {
            *mask |= MS_BUILTIN_TIME;
            continue;
        }
        if (candidate == (PyObject *)(PyDateTimeAPI->DeltaType)) {
            *mask |= MS_BUILTIN_TIMEDELTA;
            continue;
        }
        if (candidate == mod->UUIDType) {
            *mask |= MS_BUILTIN_UUID;
            continue;
        }
        if (candidate == mod->DecimalType) {
            *mask |= MS_BUILTIN_DECIMAL;
            continue;
        }

        /* Anything else must at least be a type object. */
        if (!PyType_Check(candidate)) {
            PyErr_SetString(PyExc_TypeError, not_types_msg);
            goto cleanup;
        }
        /* A caller that passes no out-parameter cannot accept custom types. */
        if (custom_types == NULL) {
            PyErr_Format(PyExc_TypeError, "Cannot treat %R as a builtin type", candidate);
            goto cleanup;
        }
        keep_fast = true;
    }

    status = 0;
    if (keep_fast) {
        /* Hand the sequence reference over to the caller instead of dropping it. */
        *custom_types = fast;
        return status;
    }

cleanup:
    Py_DECREF(fast);
    return status;
}


Expand All @@ -19034,9 +19068,10 @@ PyDoc_STRVAR(msgspec_to_builtins__doc__,
" The object to convert.\n"
"builtin_types: Iterable[type], optional\n"
" An iterable of types to treat as additional builtin types. These types will\n"
" be passed through ``to_builtins`` unchanged. Currently only supports\n"
" `bytes`, `bytearray`, `memoryview`, `datetime.datetime`, `datetime.time`,\n"
" `datetime.date`, `datetime.timedelta`, `uuid.UUID`, and `decimal.Decimal`.\n"
" be passed through ``to_builtins`` unchanged. Currently supports `bytes`,\n"
" `bytearray`, `memoryview`, `datetime.datetime`, `datetime.time`,\n"
" `datetime.date`, `datetime.timedelta`, `uuid.UUID`, `decimal.Decimal`,\n"
" and custom types.\n"
"str_keys: bool, optional\n"
" Whether to convert all object keys to strings. Default is False.\n"
"enc_hook : callable, optional\n"
Expand Down Expand Up @@ -19094,6 +19129,7 @@ msgspec_to_builtins(PyObject *self, PyObject *args, PyObject *kwargs)
state.mod = msgspec_get_global_state();
state.str_keys = str_keys;
state.builtin_types = 0;
state.builtin_types_seq = NULL;

if (enc_hook == Py_None) {
enc_hook = NULL;
Expand All @@ -19103,9 +19139,20 @@ msgspec_to_builtins(PyObject *self, PyObject *args, PyObject *kwargs)
return NULL;
}
state.enc_hook = enc_hook;
if (ms_process_builtin_types(state.mod, builtin_types, &(state.builtin_types)) < 0) return NULL;
if (
ms_process_builtin_types(
state.mod,
builtin_types,
&(state.builtin_types),
&(state.builtin_types_seq)
) < 0
) {
return NULL;
}

return to_builtins(&state, obj, false);
PyObject *out = to_builtins(&state, obj, false);
Py_XDECREF(state.builtin_types_seq);
return out;
}

/*************************************************************************
Expand Down Expand Up @@ -20611,7 +20658,7 @@ msgspec_convert(PyObject *self, PyObject *args, PyObject *kwargs)
state.strict = strict;
if (strict) {
state.str_keys = str_keys;
if (ms_process_builtin_types(state.mod, builtin_types, &(state.builtin_types)) < 0) {
if (ms_process_builtin_types(state.mod, builtin_types, &(state.builtin_types), NULL) < 0) {
return NULL;
}
}
Expand Down
30 changes: 26 additions & 4 deletions tests/test_to_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,17 @@ def test_to_builtins_bad_calls(self):
):
to_builtins([1], builtin_types=1)

with pytest.raises(TypeError) as rec:
to_builtins([1], builtin_types=(int,))
assert "Cannot treat" in str(rec.value)
assert "int" in str(rec.value)
with pytest.raises(
TypeError, match="builtin_types must be an iterable of types"
):
to_builtins([1], builtin_types=(1,))

with pytest.raises(TypeError, match="enc_hook must be callable"):
to_builtins([1], enc_hook=1)

def test_to_builtins_builtin_types_explicit_none(self):
    """Passing ``builtin_types=None`` explicitly leaves an int unchanged."""
    converted = to_builtins(1, builtin_types=None)
    assert converted == 1

def test_to_builtins_enc_hook_explicit_none(self):
    """Passing ``enc_hook=None`` explicitly leaves an int unchanged."""
    converted = to_builtins(1, enc_hook=None)
    assert converted == 1

Expand Down Expand Up @@ -492,3 +495,22 @@ def test_custom(self):
to_builtins(Bad())

assert to_builtins(Bad(), enc_hook=lambda x: "bad") == "bad"

@pytest.mark.parametrize("col_type", [tuple, list, set])
def test_custom_builtin_types(self, col_type):
    """Instances of types listed in ``builtin_types`` come back as-is.

    The collection mixes two locally defined classes with ``bytes`` and is
    built as a tuple, list, and set in turn. The reference count of the
    collection is sampled before and after the calls and must be equal.
    """

    class C1:
        pass

    class C2:
        pass

    passthrough = col_type([C1, bytes, C2])
    refs_before = sys.getrefcount(passthrough)

    # Each listed type must be returned by identity, not copied or converted.
    for value in (C1(), C2(), b"test"):
        result = to_builtins(value, builtin_types=passthrough)
        assert result is value

    # ``Bad`` is not in the collection, so encoding it is still expected to fail.
    with pytest.raises(TypeError, match="Encoding objects of type Bad"):
        to_builtins(Bad(), builtin_types=passthrough)

    # Neither the successful calls nor the failing one may change the refcount.
    refs_after = sys.getrefcount(passthrough)
    assert refs_after == refs_before
Loading