Add anyset & frozenset, enable copying (cast) to std::set #3901

Merged (9 commits), May 5, 2022
8 changes: 8 additions & 0 deletions include/pybind11/cast.h
@@ -907,6 +907,14 @@ struct handle_type_name<kwargs> {

 template <typename type>
 struct pyobject_caster {
+    template <typename T = type, enable_if_t<std::is_same<T, handle>::value, int> = 0>
+    pyobject_caster() : value() {}
+
+    // `type` may not be default constructible (e.g. frozenset, anyset). Initializing `value`
+    // to a nil handle is safe since it will only be accessed if `load` succeeds.
+    template <typename T = type, enable_if_t<std::is_base_of<object, T>::value, int> = 0>
Collaborator commented:

I am not sure about the changes to the caster: without a default ctor on the type, is this ctor needed on pyobject_caster? At the very least, we should have a comment about why these are necessary. I am also worried that there are edge cases where this might break when it is used in downstream applications.

Thoughts, @rwgk @henryiii?

Collaborator commented:

IIUC we have a choice:

  • Either the change here.
  • Or we need to make anyset and frozenset default constructible.

#3901 (comment)

I agree, but I think it is not just "still safe" but actually safer: if the previously default-constructed object is accidentally used, the symptoms could be highly confusing (silent failures). A noisy failure (a segfault because of a nullptr) seems better.

I'll run our (Google's) global testing with this PR to see if there are breakages and report back here. I might have the results only tomorrow.

Collaborator commented:

No breakages! This PR is great. If anyone was previously relying on the default-constructed value, they will probably be glad to see that exposed as an obvious bug now.

+    pyobject_caster() : value(reinterpret_steal<type>(handle())) {}
+
     template <typename T = type, enable_if_t<std::is_same<T, handle>::value, int> = 0>
     bool load(handle src, bool /* convert */) {
         value = src;
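
The two constrained default constructors in this hunk rely on SFINAE: only one of them takes part in overload resolution for a given `type`. Below is a minimal standalone sketch of that pattern. `Handle`, `Object`, `NoDefault`, and `Caster` are hypothetical stand-ins, not pybind11's classes, and a plain explicit constructor is assumed in place of `reinterpret_steal<type>(handle())`.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for pybind11's handle and object (not the real types).
struct Handle {
    const void *ptr = nullptr;
};

struct Object : Handle {
    Object() = default;
    explicit Object(const Handle &h) : Handle(h) {}
};

// Like anyset/frozenset in the PR: derives from Object, has no default constructor.
struct NoDefault : Object {
    explicit NoDefault(const Handle &h) : Object(h) {}
};

template <typename Type>
struct Caster {
    // Participates only when Type is exactly Handle: plain value-initialization.
    template <typename T = Type,
              typename std::enable_if<std::is_same<T, Handle>::value, int>::type = 0>
    Caster() : value() {}

    // Participates for Object and its subclasses: builds the member from a nil
    // handle, so Type itself does not need a default constructor.
    template <typename T = Type,
              typename std::enable_if<std::is_base_of<Object, T>::value, int>::type = 0>
    Caster() : value(Handle{}) {}

    Type value;
};

static_assert(!std::is_default_constructible<NoDefault>::value,
              "NoDefault cannot be default-constructed");
static_assert(std::is_default_constructible<Caster<NoDefault>>::value,
              "Caster<NoDefault> still can");
```

In this sketch, `Caster<NoDefault>` ends up holding a null pointer, so touching `value` before a successful load fails loudly instead of silently yielding an empty container, which is the trade-off argued for in the thread above.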
30 changes: 21 additions & 9 deletions include/pybind11/pytypes.h
@@ -1784,25 +1784,37 @@ class kwargs : public dict {
     PYBIND11_OBJECT_DEFAULT(kwargs, dict, PyDict_Check)
 };

-class set : public object {
+class anyset : public object {
+protected:
Collaborator commented:

Okay, last concern. Why is the scope here protected? I don't see any other Pytypes doing that.

Collaborator commented:

Specifically, I am worried it might interfere with the static check_ method that isinstance is supposed to be using.

return T::check_(obj);

Collaborator commented:

Good question. It turns out the protected: here doesn't do anything, because the expansion of the PYBIND11_OBJECT macro starts with public:. See the clang -E output below.

(Note that there is an isinstance<anyset> in stl.h, i.e. this had to be the case somehow.)

I think it's best to not get fancy here and to simply remove the protected: line.

class anyset : public object {
protected:
    public: [[deprecated("Use reinterpret_borrow<" "anyset" ">() or reinterpret_steal<" "anyset" ">()")]] anyset(handle h, bool is_borrowed) : object(is_borrowed ? object(h, borrowed_t{}) : object(h, stolen_t{})) {} anyset(handle h, borrowed_t) : object(h, borrowed_t{}) {} anyset(handle h, stolen_t) : object(h, stolen_t{}) {} [[deprecated("Use py::isinstance<py::python_type>(obj) instead")]] bool check() const { return m_ptr != nullptr && ((_Py_IS_TYPE(((const PyObject*)(m_ptr)), &PySet_Type) || _Py_IS_TYPE(((const PyObject*)(m_ptr)), &PyFrozenSet_Type) || PyType_IsSubtype((((PyObject*)(m_ptr))->ob_type), &PySet_Type) || PyType_IsSubtype((((PyObject*)(m_ptr))->ob_type), &PyFrozenSet_Type)) != 0); } static bool check_(handle h) { return h.ptr() != nullptr && (_Py_IS_TYPE(((const PyObject*)(h.ptr())), &PySet_Type) || _Py_IS_TYPE(((const PyObject*)(h.ptr())), &PyFrozenSet_Type) || PyType_IsSubtype((((PyObject*)(h.ptr()))->ob_type), &PySet_Type) || PyType_IsSubtype((((PyObject*)(h.ptr()))->ob_type), &PyFrozenSet_Type)); } template <typename Policy_> anyset(const ::pybind11::detail::accessor<Policy_> &a) : anyset(object(a)) {} anyset(const object &o) : object(o) { if (m_ptr && !check_(m_ptr)) throw ::pybind11::type_error("Object of type '" + ::pybind11::detail::get_fully_qualified_tp_name((((PyObject*)(m_ptr))->ob_type)) + "' is not an instance of '" "anyset" "'"); } anyset(object &&o) : object(std::move(o)) { if (m_ptr && !check_(m_ptr)) throw ::pybind11::type_error("Object of type '" + ::pybind11::detail::get_fully_qualified_tp_name((((PyObject*)(m_ptr))->ob_type)) + "' is not an instance of '" "anyset" "'"); }

public:
    size_t size() const { return static_cast<size_t>(PySet_Size(m_ptr)); }
    bool empty() const { return size() == 0; }
    template <typename T>
    bool contains(T &&val) const {
        return PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 1;
    }
};

Contributor Author commented:

Oops, yes. Removed.
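
The point made in this thread, that an access specifier placed directly before a macro is overridden when the macro's expansion itself starts with `public:`, can be reproduced without pybind11. The sketch below is only an illustration: `DEMO_OBJECT`, `DemoAnySet`, and `demo_isinstance` are hypothetical names, and the integer tag stands in for a Python type check.

```cpp
#include <cassert>

// Hypothetical stand-in for PYBIND11_OBJECT: like the expansion shown in the
// clang -E output above, it begins with `public:`.
#define DEMO_OBJECT(Tag)                                   \
public:                                                    \
    static bool check_(int tag) { return tag == (Tag); }

class DemoAnySet {
protected: // no effect: the macro below immediately switches back to public
    DEMO_OBJECT(7)
};

// Same shape as `return T::check_(obj);` quoted in the review. This compiles
// only because check_ ends up public despite the `protected:` label.
template <typename T>
bool demo_isinstance(int tag) {
    return T::check_(tag);
}
```

Under these assumptions `demo_isinstance<DemoAnySet>(7)` compiles and returns true, which matches the observation that `isinstance<anyset>` already worked in stl.h with the stray `protected:` in place.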

+    PYBIND11_OBJECT(anyset, object, PyAnySet_Check)
+
+public:
+    size_t size() const { return static_cast<size_t>(PySet_Size(m_ptr)); }
+    bool empty() const { return size() == 0; }
+    template <typename T>
+    bool contains(T &&val) const {
+        return PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 1;
+    }
+};
+
+class set : public anyset {
 public:
-    PYBIND11_OBJECT_CVT(set, object, PySet_Check, PySet_New)
-    set() : object(PySet_New(nullptr), stolen_t{}) {
+    PYBIND11_OBJECT_CVT(set, anyset, PySet_Check, PySet_New)
+    set() : anyset(PySet_New(nullptr), stolen_t{}) {
         if (!m_ptr) {
             pybind11_fail("Could not allocate set object!");
         }
     }
-    size_t size() const { return (size_t) PySet_Size(m_ptr); }
-    bool empty() const { return size() == 0; }
     template <typename T>
     bool add(T &&val) /* py-non-const */ {
         return PySet_Add(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 0;
     }
     void clear() /* py-non-const */ { PySet_Clear(m_ptr); }
-    template <typename T>
-    bool contains(T &&val) const {
-        return PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 1;
-    }
 };
+
+class frozenset : public anyset {
+public:
+    PYBIND11_OBJECT_CVT(frozenset, anyset, PyFrozenSet_Check, PyFrozenSet_New)
+};

 class function : public object {
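
The class layout in this hunk moves the read-only operations (`size`, `empty`, `contains`) into a shared base and leaves the mutators (`add`, `clear`) on `set` only. As a rough illustration of that design, here is a toy model built on `std::set`. The `*Model` classes and `has_key` are hypothetical and are not pybind11 types; they only mirror the shape of the hierarchy.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <utility>

// Toy model only: wraps std::set<std::string>, not a Python object.
class AnySetModel {
public:
    std::size_t size() const { return items_.size(); }
    bool empty() const { return size() == 0; }
    bool contains(const std::string &key) const { return items_.count(key) == 1; }

protected:
    // Not directly constructible, loosely mirroring anyset having no default ctor.
    explicit AnySetModel(std::set<std::string> items) : items_(std::move(items)) {}
    std::set<std::string> items_;
};

// Mutable variant: adds the operations that change the contents.
class SetModel : public AnySetModel {
public:
    SetModel() : AnySetModel(std::set<std::string>{}) {}
    void add(const std::string &key) { items_.insert(key); }
    void clear() { items_.clear(); }
};

// Immutable variant: inherits only the read-only interface.
class FrozenSetModel : public AnySetModel {
public:
    explicit FrozenSetModel(std::set<std::string> items) : AnySetModel(std::move(items)) {}
};

// Code that only reads can take the base type and accept either variant.
bool has_key(const AnySetModel &any, const std::string &key) {
    return !any.empty() && any.contains(key);
}
```

A function written against the base, like `has_key` here, plays the same role as the `print_anyset` and `anyset_contains` bindings in the tests further down: it works for both the mutable and the frozen variant.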
4 changes: 2 additions & 2 deletions include/pybind11/stl.h
@@ -55,10 +55,10 @@ struct set_caster {
     using key_conv = make_caster<Key>;

     bool load(handle src, bool convert) {
-        if (!isinstance<pybind11::set>(src)) {
+        if (!isinstance<anyset>(src)) {
             return false;
         }
-        auto s = reinterpret_borrow<pybind11::set>(src);
+        auto s = reinterpret_borrow<anyset>(src);
         value.clear();
         for (auto entry : s) {
             key_conv conv;
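
The change to `set_caster::load` widens the type guard from "is a set" to "is a set or a frozenset" before copying elements into the C++ container. The following standalone sketch models that control flow without Python. Everything in it is an assumption made for illustration: `FakePyObject`, `Kind`, `IntKeyConv`, and `load_set` are hypothetical, and the loop body past `key_conv conv;` is not visible in the hunk, so the "fail on the first element that does not convert" step is a guess at the usual caster pattern.

```cpp
#include <cassert>
#include <cstddef>
#include <exception>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a Python object: a kind tag plus raw items.
enum class Kind { Set, FrozenSet, List };

struct FakePyObject {
    Kind kind;
    std::vector<std::string> items;
};

// Plays the role of isinstance<anyset>: accepts both set flavours.
bool is_anyset(const FakePyObject &src) {
    return src.kind == Kind::Set || src.kind == Kind::FrozenSet;
}

// Hypothetical per-element converter, playing the role of key_conv.
struct IntKeyConv {
    int value = 0;
    bool load(const std::string &text) {
        try {
            std::size_t used = 0;
            value = std::stoi(text, &used);
            return used == text.size(); // reject trailing garbage such as "4x"
        } catch (const std::exception &) {
            return false; // not a number, or out of range
        }
    }
};

// Guard, clear the destination, then convert element by element.
bool load_set(const FakePyObject &src, std::set<int> &value) {
    if (!is_anyset(src)) {
        return false;
    }
    value.clear();
    for (const auto &entry : src.items) {
        IntKeyConv conv;
        if (!conv.load(entry)) {
            return false; // assumed behaviour: abort on the first bad element
        }
        value.insert(conv.value);
    }
    return true;
}
```

With the guard written this way, a frozen input loads exactly like a mutable one, which is what the added `m.load_set(frozenset(s))` assertion in tests/test_stl.py exercises against the real caster.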
24 changes: 19 additions & 5 deletions tests/test_pytypes.cpp
@@ -75,22 +75,34 @@ TEST_SUBMODULE(pytypes, m) {
     m.def("get_none", [] { return py::none(); });
     m.def("print_none", [](const py::none &none) { py::print("none: {}"_s.format(none)); });

-    // test_set
+    // test_set, test_frozenset
     m.def("get_set", []() {
         py::set set;
         set.add(py::str("key1"));
         set.add("key2");
         set.add(std::string("key3"));
         return set;
     });
-    m.def("print_set", [](const py::set &set) {
+    m.def("get_frozenset", []() {
+        py::set set;
+        set.add(py::str("key1"));
+        set.add("key2");
+        set.add(std::string("key3"));
+        return py::frozenset(set);
+    });
+    m.def("print_anyset", [](const py::anyset &set) {
         for (auto item : set) {
             py::print("key:", item);
         }
     });
-    m.def("set_contains",
-          [](const py::set &set, const py::object &key) { return set.contains(key); });
-    m.def("set_contains", [](const py::set &set, const char *key) { return set.contains(key); });
+    m.def("anyset_size", [](const py::anyset &set) { return set.size(); });
+    m.def("anyset_empty", [](const py::anyset &set) { return set.empty(); });
+    m.def("anyset_contains",
+          [](const py::anyset &set, const py::object &key) { return set.contains(key); });
+    m.def("anyset_contains",
+          [](const py::anyset &set, const char *key) { return set.contains(key); });
+    m.def("set_add", [](py::set &set, const py::object &key) { set.add(key); });
+    m.def("set_clear", [](py::set &set) { set.clear(); });

     // test_dict
     m.def("get_dict", []() { return py::dict("key"_a = "value"); });
@@ -310,6 +322,7 @@ TEST_SUBMODULE(pytypes, m) {
                         "list"_a = py::list(d["list"]),
                         "dict"_a = py::dict(d["dict"]),
                         "set"_a = py::set(d["set"]),
+                        "frozenset"_a = py::frozenset(d["frozenset"]),
                         "memoryview"_a = py::memoryview(d["memoryview"]));
     });

@@ -325,6 +338,7 @@ TEST_SUBMODULE(pytypes, m) {
                         "list"_a = d["list"].cast<py::list>(),
                         "dict"_a = d["dict"].cast<py::dict>(),
                         "set"_a = d["set"].cast<py::set>(),
+                        "frozenset"_a = d["frozenset"].cast<py::frozenset>(),
                         "memoryview"_a = d["memoryview"].cast<py::memoryview>());
     });

47 changes: 40 additions & 7 deletions tests/test_pytypes.py
@@ -66,11 +66,12 @@ def test_none(capture, doc):

 def test_set(capture, doc):
     s = m.get_set()
+    assert isinstance(s, set)
     assert s == {"key1", "key2", "key3"}

+    s.add("key4")
     with capture:
-        s.add("key4")
-        m.print_set(s)
+        m.print_anyset(s)
     assert (
         capture.unordered
         == """
@@ -81,12 +82,43 @@ def test_set(capture, doc):
     """
     )

-    assert not m.set_contains(set(), 42)
-    assert m.set_contains({42}, 42)
-    assert m.set_contains({"foo"}, "foo")
+    m.set_add(s, "key5")
+    assert m.anyset_size(s) == 5

-    assert doc(m.get_list) == "get_list() -> list"
-    assert doc(m.print_list) == "print_list(arg0: list) -> None"
+    m.set_clear(s)
+    assert m.anyset_empty(s)
+
+    assert not m.anyset_contains(set(), 42)
+    assert m.anyset_contains({42}, 42)
+    assert m.anyset_contains({"foo"}, "foo")
+
+    assert doc(m.get_set) == "get_set() -> set"
+    assert doc(m.print_anyset) == "print_anyset(arg0: anyset) -> None"


+def test_frozenset(capture, doc):
+    s = m.get_frozenset()
+    assert isinstance(s, frozenset)
+    assert s == frozenset({"key1", "key2", "key3"})
+
+    with capture:
+        m.print_anyset(s)
+    assert (
+        capture.unordered
+        == """
+    key: key1
+    key: key2
+    key: key3
+    """
+    )
+    assert m.anyset_size(s) == 3
+    assert not m.anyset_empty(s)
+
+    assert not m.anyset_contains(frozenset(), 42)
+    assert m.anyset_contains(frozenset({42}), 42)
+    assert m.anyset_contains(frozenset({"foo"}), "foo")
+
+    assert doc(m.get_frozenset) == "get_frozenset() -> frozenset"
+
+
 def test_dict(capture, doc):
@@ -302,6 +334,7 @@ def test_constructors():
         list: range(3),
         dict: [("two", 2), ("one", 1), ("three", 3)],
         set: [4, 4, 5, 6, 6, 6],
+        frozenset: [4, 4, 5, 6, 6, 6],
         memoryview: b"abc",
     }
     inputs = {k.__name__: v for k, v in data.items()}
1 change: 1 addition & 0 deletions tests/test_stl.py
@@ -73,6 +73,7 @@ def test_set(doc):
     assert s == {"key1", "key2"}
     s.add("key3")
     assert m.load_set(s)
+    assert m.load_set(frozenset(s))

     assert doc(m.cast_set) == "cast_set() -> Set[str]"
     assert doc(m.load_set) == "load_set(arg0: Set[str]) -> bool"