Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implicit conversions to bool + np.bool_ conversion #925

Merged
merged 17 commits into from
Jul 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1049,11 +1049,37 @@ template <> class type_caster<std::nullptr_t> : public void_caster<std::nullptr_

template <> class type_caster<bool> {
public:
bool load(handle src, bool) {
bool load(handle src, bool convert) {
if (!src) return false;
else if (src.ptr() == Py_True) { value = true; return true; }
else if (src.ptr() == Py_False) { value = false; return true; }
else return false;
else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) {
// (allow non-implicit conversion for numpy booleans)

Py_ssize_t res = -1;
if (src.is_none()) {
res = 0; // None is implicitly converted to False
}
#if defined(PYPY_VERSION)
// On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists
else if (hasattr(src, PYBIND11_BOOL_ATTR)) {
res = PyObject_IsTrue(src.ptr());
}
#else
// Alternate approach for CPython: this does the same as the above, but optimized
// using the CPython API so as to avoid an unneeded attribute lookup.
else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) {
if (PYBIND11_NB_BOOL(tp_as_number)) {
res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr());
}
}
#endif
if (res == 0 || res == 1) {
value = (bool) res;
return true;
}
}
return false;
}
static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) {
return handle(src ? Py_True : Py_False).inc_ref();
Expand Down
5 changes: 5 additions & 0 deletions include/pybind11/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,11 @@
#define PYBIND11_SLICE_OBJECT PyObject
#define PYBIND11_FROM_STRING PyUnicode_FromString
#define PYBIND11_STR_TYPE ::pybind11::str
#define PYBIND11_BOOL_ATTR "__bool__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
#define PYBIND11_PLUGIN_IMPL(name) \
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()

#else
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
Expand All @@ -171,6 +174,8 @@
#define PYBIND11_SLICE_OBJECT PySliceObject
#define PYBIND11_FROM_STRING PyString_FromString
#define PYBIND11_STR_TYPE ::pybind11::bytes
#define PYBIND11_BOOL_ATTR "__nonzero__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
#define PYBIND11_PLUGIN_IMPL(name) \
static PyObject *pybind11_init_wrapper(); \
extern "C" PYBIND11_EXPORT void init##name() { \
Expand Down
4 changes: 4 additions & 0 deletions tests/test_builtin_casters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ TEST_SUBMODULE(builtin_casters, m) {
m.def("load_nullptr_t", [](std::nullptr_t) {}); // not useful, but it should still compile
m.def("cast_nullptr_t", []() { return std::nullptr_t{}; });

// test_bool_caster
m.def("bool_passthrough", [](bool arg) { return arg; });
m.def("bool_passthrough_noconvert", [](bool arg) { return arg; }, py::arg().noconvert());

// test_reference_wrapper
m.def("refwrap_builtin", [](std::reference_wrapper<int> p) { return 10 * p.get(); });
m.def("refwrap_usertype", [](std::reference_wrapper<UserType> p) { return p.get().value(); });
Expand Down
55 changes: 55 additions & 0 deletions tests/test_builtin_casters.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,58 @@ def test_complex_cast():
"""std::complex casts"""
assert m.complex_cast(1) == "1.0"
assert m.complex_cast(2j) == "(0.0, 2.0)"


def test_bool_caster():
"""Test bool caster implicit conversions."""
convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert

def require_implicit(v):
pytest.raises(TypeError, noconvert, v)

def cant_convert(v):
pytest.raises(TypeError, convert, v)

# straight up bool
assert convert(True) is True
assert convert(False) is False
assert noconvert(True) is True
assert noconvert(False) is False

# None requires implicit conversion
require_implicit(None)
assert convert(None) is False

class A(object):
def __init__(self, x):
self.x = x

def __nonzero__(self):
return self.x

def __bool__(self):
return self.x

class B(object):
pass

# Arbitrary objects are not accepted
cant_convert(object())
cant_convert(B())

# Objects with __nonzero__ / __bool__ defined can be converted
require_implicit(A(True))
assert convert(A(True)) is True
assert convert(A(False)) is False


@pytest.requires_numpy
def test_numpy_bool():
import numpy as np
convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert

# np.bool_ is not considered implicit
assert convert(np.bool_(True)) is True
assert convert(np.bool_(False)) is False
assert noconvert(np.bool_(True)) is True
assert noconvert(np.bool_(False)) is False