From 1a820a4ab952b53d01e9644419c9499b15641517 Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Tue, 27 Jun 2017 17:29:46 +0100 Subject: [PATCH] Add tests for generic bool conversions --- tests/test_numpy_dtypes.cpp | 3 ++- tests/test_numpy_dtypes.py | 51 ++++++++++++++++++++++++++++++++++--- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index 07fc78b003a..3ca7696c428 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -467,7 +467,8 @@ test_initializer numpy_dtypes([](py::module &m) { m.def("f_packed", [](PackedStruct s) { return s.uint_ * 10; }); m.def("f_nested", [](NestedStruct s) { return s.a.uint_ * 10; }); m.def("register_dtype", []() { PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_); }); - m.def("negate_bool", [](bool arg) { return !arg; }); + m.def("bool_passthrough", [](bool arg) { return arg; }); + m.def("bool_passthrough_noconvert", [](bool arg) { return arg; }, py::arg().noconvert()); }); #undef PYBIND11_PACKED diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index 2523543f808..77421a4149e 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -323,6 +323,51 @@ def test_compare_buffer_info(): @pytest.requires_numpy def test_numpy_bool(): - from pybind11_tests import negate_bool - assert negate_bool(np.bool_(True)) is False - assert negate_bool(np.bool_(False)) is True + from pybind11_tests import bool_passthrough as convert, bool_passthrough_noconvert as noconvert + + require_implicit = lambda v: pytest.raises(TypeError, noconvert, v) + cant_convert = lambda v: pytest.raises(TypeError, convert, v) + + # straight up bool + assert convert(True) is True + assert convert(False) is False + assert noconvert(True) is True + assert noconvert(False) is False + + # np.bool_ is not considered implicit + assert convert(np.bool_(True)) is True + assert convert(np.bool_(False)) is False + assert noconvert(np.bool_(True)) is True + assert noconvert(np.bool_(False)) is False + + # None requires implicit conversion + require_implicit(None) + assert convert(None) is False + + # Sequence types check for lengths (same as in PyObject_IsTrue) + require_implicit([]) + require_implicit(()) + require_implicit('') + assert convert([]) is False + assert convert([1]) is True + assert convert(()) is False + assert convert((0,)) is True + assert convert('') is False + assert convert('foo') is True + + class A(object): + def __init__(self, x): self.x = x + def __nonzero__(self): return self.x + __bool__ = __nonzero__ + + class B(object): + pass + + # Arbitrary objects are not accepted + cant_convert(object()) + cant_convert(B()) + + # Objects with __nonzero__ / __bool__ defined can be converted + require_implicit(A(True)) + assert convert(A(True)) is True + assert convert(A(False)) is False