From 80df8d4fcfb739924868ff4dda1fcb6ad265a3b6 Mon Sep 17 00:00:00 2001 From: Nicholas Junge Date: Tue, 10 Jan 2023 13:18:05 +0100 Subject: [PATCH] Port `nb::bind_map` from pybind11 (#114) This commit adds a port of nanobind's `nb::bind_map` feature to create bindings of STL map types (`map`, `unordered_map`). The implementation contains the following simplifications: 1. The C++17 constexpr feature was used to considerably reduce the size of the implementation. 2. The key/value/item views are simple wrappers without the need for polymorphism or STL unique pointers. They are created once per map type. The commit also includes a port of the associated pybind11 test suite parts. Co-authored by: Nicholas Junge Co-authored-by: Wenzel Jakob --- include/nanobind/nb_class.h | 5 +- include/nanobind/stl/bind_map.h | 120 +++++++++++++++++++++ include/nanobind/stl/detail/traits.h | 25 ++++- tests/CMakeLists.txt | 2 + tests/test_stl_bind_map.cpp | 76 ++++++++++++++ tests/test_stl_bind_map.py | 151 +++++++++++++++++++++++++++ 6 files changed, 377 insertions(+), 2 deletions(-) create mode 100644 include/nanobind/stl/bind_map.h create mode 100644 tests/test_stl_bind_map.cpp create mode 100644 tests/test_stl_bind_map.py diff --git a/include/nanobind/nb_class.h b/include/nanobind/nb_class.h index 30233327f..9e454456a 100644 --- a/include/nanobind/nb_class.h +++ b/include/nanobind/nb_class.h @@ -239,6 +239,9 @@ template struct is_copy_constructible : std::is_copy_constructible { }; +template +constexpr bool is_copy_constructible_v = is_copy_constructible::value; + NAMESPACE_END(detail) template @@ -275,7 +278,7 @@ class class_ : public object { if constexpr (!std::is_same_v) d.flags |= (uint32_t) detail::type_flags::is_trampoline; - if constexpr (detail::is_copy_constructible::value) { + if constexpr (detail::is_copy_constructible_v) { d.flags |= (uint32_t) detail::type_flags::is_copy_constructible; if constexpr (!std::is_trivially_copy_constructible_v) { diff --git a/include/nanobind/stl/bind_map.h b/include/nanobind/stl/bind_map.h new file mode 100644 index 000000000..7dfb0a102 --- /dev/null +++ b/include/nanobind/stl/bind_map.h @@ -0,0 +1,120 @@ +/* + nanobind/stl/bind_map.h: Automatic creation of bindings for map-style containers + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include +#include +#include + +NAMESPACE_BEGIN(NB_NAMESPACE) + +template +class_ bind_map(handle scope, const char *name, Args &&...args) { + using Key = typename Map::key_type; + using Value = typename Map::mapped_type; + + auto cl = class_(scope, name, std::forward(args)...) + .def(init<>()) + + .def("__len__", &Map::size) + + .def("__bool__", + [](const Map &m) { return !m.empty(); }, + "Check whether the map is nonempty") + + .def("__contains__", + [](const Map &m, const Key &k) { return m.find(k) != m.end(); }) + + .def("__contains__", // fallback for incompatible types + [](const Map &, handle) { return false; }) + + .def("__iter__", + [](Map &m) { + return make_key_iterator(type(), "KeyIterator", + m.begin(), m.end()); + }, + keep_alive<0, 1>()) + + .def("__getitem__", + [](Map &m, const Key &k) -> Value & { + auto it = m.find(k); + if (it == m.end()) + throw key_error(); + return it->second; + }, + rv_policy::reference_internal + ) + + .def("__delitem__", + [](Map &m, const Key &k) { + auto it = m.find(k); + if (it == m.end()) + throw key_error(); + m.erase(it); + } + ); + + // Assignment operator for copy-assignable/copy-constructible types + if constexpr (detail::is_copy_assignable_v || + detail::is_copy_constructible_v) { + cl.def("__setitem__", [](Map &m, const Key &k, const Value &v) { + if constexpr (detail::is_copy_assignable_v) { + m[k] = v; + } else { + auto r = m.emplace(k, v); + if (!r.second) { + // Value is not copy-assignable. Erase and retry + m.erase(r.first); + m.emplace(k, v); + } + } + }); + } + + // Item, value, and key views + struct KeyView { Map ↦ }; + struct ValueView { Map ↦ }; + struct ItemView { Map ↦ }; + + class_(cl, "ItemView") + .def("__len__", [](ItemView &v) { return v.map.size(); }) + .def("__iter__", + [](ItemView &v) { + return make_iterator(type(), "ItemIterator", + v.map.begin(), v.map.end()); + }, + keep_alive<0, 1>()); + + class_(cl, "KeyView") + .def("__contains__", [](KeyView &v, const Key &k) { return v.map.find(k) != v.map.end(); }) + .def("__contains__", [](KeyView &, handle) { return false; }) + .def("__len__", [](KeyView &v) { return v.map.size(); }) + .def("__iter__", + [](KeyView &v) { + return make_key_iterator(type(), "KeyIterator", + v.map.begin(), v.map.end()); + }, + keep_alive<0, 1>()); + + class_(cl, "ValueView") + .def("__len__", [](ValueView &v) { return v.map.size(); }) + .def("__iter__", + [](ValueView &v) { + return make_value_iterator(type(), "ValueIterator", + v.map.begin(), v.map.end()); + }, + keep_alive<0, 1>()); + + cl.def("keys", [](Map &m) { return new KeyView{m}; }, keep_alive<0, 1>()); + cl.def("values", [](Map &m) { return new ValueView{m}; }, keep_alive<0, 1>()); + cl.def("items", [](Map &m) { return new ItemView{m}; }, keep_alive<0, 1>()); + + return cl; +} + +NAMESPACE_END(NB_NAMESPACE) diff --git a/include/nanobind/stl/detail/traits.h b/include/nanobind/stl/detail/traits.h index 36efcb139..d8ad3abec 100644 --- a/include/nanobind/stl/detail/traits.h +++ b/include/nanobind/stl/detail/traits.h @@ -31,13 +31,36 @@ struct is_copy_constructible< is_copy_constructible::value; }; +// std::pair is copy-constructible <=> both constituents are copy-constructible template struct is_copy_constructible> { static constexpr bool value = - is_copy_constructible::value || + is_copy_constructible::value && is_copy_constructible::value; }; +// Analogous template for checking copy-assignability +template +struct is_copy_assignable : std::is_copy_assignable { }; + +template +struct is_copy_assignable && + std::is_same_v>> { + static constexpr bool value = is_copy_assignable::value; +}; + +template +struct is_copy_assignable> { + static constexpr bool value = + is_copy_assignable::value && + is_copy_assignable::value; +}; + +template +constexpr bool is_copy_assignable_v = is_copy_assignable::value; + NAMESPACE_END(detail) NAMESPACE_END(NB_NAMESPACE) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 4e50777f1..b544bd626 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -13,6 +13,7 @@ nanobind_add_module(test_functions_ext test_functions.cpp ${NB_EXTRA_ARGS}) nanobind_add_module(test_classes_ext test_classes.cpp ${NB_EXTRA_ARGS}) nanobind_add_module(test_holders_ext test_holders.cpp ${NB_EXTRA_ARGS}) nanobind_add_module(test_stl_ext test_stl.cpp ${NB_EXTRA_ARGS}) +nanobind_add_module(test_bind_map_ext test_stl_bind_map.cpp ${NB_EXTRA_ARGS}) nanobind_add_module(test_enum_ext test_enum.cpp ${NB_EXTRA_ARGS}) nanobind_add_module(test_tensor_ext test_tensor.cpp ${NB_EXTRA_ARGS}) nanobind_add_module(test_intrusive_ext test_intrusive.cpp object.cpp object.h ${NB_EXTRA_ARGS}) @@ -40,6 +41,7 @@ set(TEST_FILES test_classes.py test_holders.py test_stl.py + test_stl_bind_map.py test_enum.py test_tensor.py test_intrusive.py diff --git a/tests/test_stl_bind_map.cpp b/tests/test_stl_bind_map.cpp new file mode 100644 index 000000000..b5293153f --- /dev/null +++ b/tests/test_stl_bind_map.cpp @@ -0,0 +1,76 @@ +#include +#include +#include +#include + +#include +#include +#include + +namespace nb = nanobind; + +// testing for insertion of non-copyable class +class E_nc { +public: + explicit E_nc(int i) : value{i} {} + E_nc(const E_nc &) = delete; + E_nc &operator=(const E_nc &) = delete; + E_nc(E_nc &&) = default; + E_nc &operator=(E_nc &&) = default; + + int value; +}; + +template +Map *times_ten(int n) { + auto *m = new Map(); + for (int i = 1; i <= n; i++) { + m->emplace(int(i), E_nc(10 * i)); + } + return m; +} + +template +NestMap *times_hundred(int n) { + auto *m = new NestMap(); + for (int i = 1; i <= n; i++) { + for (int j = 1; j <= n; j++) { + (*m)[i].emplace(int(j * 10), E_nc(100 * j)); + } + } + return m; +} + +NB_MODULE(test_bind_map_ext, m) { + // test_map_string_double + nb::bind_map>(m, "MapStringDouble"); + nb::bind_map>(m, "UnorderedMapStringDouble"); + // test_map_string_double_const + nb::bind_map>(m, "MapStringDoubleConst"); + nb::bind_map>(m, + "UnorderedMapStringDoubleConst"); + + nb::class_(m, "ENC").def(nb::init()).def_readwrite("value", &E_nc::value); + + nb::bind_map>(m, "MapENC"); + m.def("get_mnc", ×_ten>); + nb::bind_map>(m, "UmapENC"); + m.def("get_umnc", ×_ten>); + // Issue #1885: binding nested std::map> with E non-copyable + nb::bind_map>>(m, "MapVecENC"); + m.def("get_nvnc", [](int n) { + auto *m = new std::map>(); + for (int i = 1; i <= n; i++) { + for (int j = 1; j <= n; j++) { + (*m)[i].emplace_back(j); + } + } + return m; + }); + + nb::bind_map>>(m, "MapMapENC"); + m.def("get_nmnc", ×_hundred>>); + nb::bind_map>>(m, "UmapUmapENC"); + m.def("get_numnc", ×_hundred>>); + +} diff --git a/tests/test_stl_bind_map.py b/tests/test_stl_bind_map.py new file mode 100644 index 000000000..db484e9c3 --- /dev/null +++ b/tests/test_stl_bind_map.py @@ -0,0 +1,151 @@ +import pytest + +import test_bind_map_ext as t + + +def test_map_string_double(): + mm = t.MapStringDouble() + mm["a"] = 1 + mm["b"] = 2.5 + + assert list(mm) == ["a", "b"] + assert "b" in mm + assert "c" not in mm + assert 123 not in mm + + # Check that keys, values, items are views, not merely iterable + keys = mm.keys() + values = mm.values() + items = mm.items() + assert list(keys) == ["a", "b"] + assert len(keys) == 2 + assert "a" in keys + assert "c" not in keys + assert 123 not in keys + assert list(items) == [("a", 1), ("b", 2.5)] + assert len(items) == 2 + assert ("b", 2.5) in items + assert "hello" not in items + assert ("b", 2.5, None) not in items + assert list(values) == [1, 2.5] + assert len(values) == 2 + assert 1 in values + assert 2 not in values + # Check that views update when the map is updated + mm["c"] = -1 + assert list(keys) == ["a", "b", "c"] + assert list(values) == [1, 2.5, -1] + assert list(items) == [("a", 1), ("b", 2.5), ("c", -1)] + + um = t.UnorderedMapStringDouble() + um["ua"] = 1.1 + um["ub"] = 2.6 + + assert sorted(list(um)) == ["ua", "ub"] + assert list(um.keys()) == list(um) + assert sorted(list(um.items())) == [("ua", 1.1), ("ub", 2.6)] + assert list(zip(um.keys(), um.values())) == list(um.items()) + assert "UnorderedMapStringDouble" in str(um) + + assert type(keys).__qualname__ == 'MapStringDouble.KeyView' + assert type(values).__qualname__ == 'MapStringDouble.ValueView' + assert type(items).__qualname__ == 'MapStringDouble.ItemView' + + +def test_map_string_double_const(): + mc = t.MapStringDoubleConst() + mc["a"] = 10 + mc["b"] = 20.5 + + umc = t.UnorderedMapStringDoubleConst() + umc["a"] = 11 + umc["b"] = 21.5 + + str(umc) + + +def test_maps_with_noncopyable_values(): + # std::map + mnc = t.get_mnc(5) + for i in range(1, 6): + assert mnc[i].value == 10 * i + + vsum = 0 + for k, v in mnc.items(): + assert v.value == 10 * k + vsum += v.value + + assert vsum == 150 + + # std::unordered_map + mnc = t.get_umnc(5) + for i in range(1, 6): + assert mnc[i].value == 10 * i + + vsum = 0 + for k, v in mnc.items(): + assert v.value == 10 * k + vsum += v.value + + assert vsum == 150 + + # nested std::map + nvnc = t.get_nvnc(5) + for i in range(1, 6): + for j in range(0, 5): + assert nvnc[i][j].value == j + 1 + + # Note: maps do not have .values() + for _, v in nvnc.items(): + for i, j in enumerate(v, start=1): + assert j.value == i + + # nested std::map + nmnc = t.get_nmnc(5) + for i in range(1, 6): + for j in range(10, 60, 10): + assert nmnc[i][j].value == 10 * j + + vsum = 0 + for _, v_o in nmnc.items(): + for k_i, v_i in v_o.items(): + assert v_i.value == 10 * k_i + vsum += v_i.value + + assert vsum == 7500 + + # nested std::unordered_map + numnc = t.get_numnc(5) + for i in range(1, 6): + for j in range(10, 60, 10): + assert numnc[i][j].value == 10 * j + + vsum = 0 + for _, v_o in numnc.items(): + for k_i, v_i in v_o.items(): + assert v_i.value == 10 * k_i + vsum += v_i.value + + assert vsum == 7500 + + +def test_map_delitem(): + mm = t.MapStringDouble() + mm["a"] = 1 + mm["b"] = 2.5 + + assert list(mm) == ["a", "b"] + assert list(mm.items()) == [("a", 1), ("b", 2.5)] + del mm["a"] + assert list(mm) == ["b"] + assert list(mm.items()) == [("b", 2.5)] + + um = t.UnorderedMapStringDouble() + um["ua"] = 1.1 + um["ub"] = 2.6 + + assert sorted(list(um)) == ["ua", "ub"] + assert sorted(list(um.items())) == [("ua", 1.1), ("ub", 2.6)] + del um["ua"] + assert sorted(list(um)) == ["ub"] + assert sorted(list(um.items())) == [("ub", 2.6)]