Skip to content

Commit

Permalink
Port nb::bind_map from pybind11 (#114)
Browse files Browse the repository at this point in the history
This commit adds a port of nanobind's `nb::bind_map<T>` feature to create
bindings of STL map types (`map`, `unordered_map`). 

The implementation contains the following simplifications:

1. The C++17 constexpr feature was used to considerably reduce the
   size of the implementation.

2. The key/value/item views are simple wrappers without the need for
   polymorphism or STL unique pointers. They are created once per map type.

The commit also includes a port of the associated pybind11 test suite parts.

Co-authored by: Nicholas Junge <[email protected]>
Co-authored-by: Wenzel Jakob <[email protected]>
  • Loading branch information
nicholasjng authored Jan 10, 2023
1 parent 83ec1ff commit 80df8d4
Show file tree
Hide file tree
Showing 6 changed files with 377 additions and 2 deletions.
5 changes: 4 additions & 1 deletion include/nanobind/nb_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined
template <typename T, typename SFINAE = int>
struct is_copy_constructible : std::is_copy_constructible<T> { };

template <typename T>
constexpr bool is_copy_constructible_v = is_copy_constructible<T>::value;

NAMESPACE_END(detail)

template <typename T, typename... Ts>
Expand Down Expand Up @@ -275,7 +278,7 @@ class class_ : public object {
if constexpr (!std::is_same_v<Alias, T>)
d.flags |= (uint32_t) detail::type_flags::is_trampoline;

if constexpr (detail::is_copy_constructible<T>::value) {
if constexpr (detail::is_copy_constructible_v<T>) {
d.flags |= (uint32_t) detail::type_flags::is_copy_constructible;

if constexpr (!std::is_trivially_copy_constructible_v<T>) {
Expand Down
120 changes: 120 additions & 0 deletions include/nanobind/stl/bind_map.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
nanobind/stl/bind_map.h: Automatic creation of bindings for map-style containers
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include <nanobind/nanobind.h>
#include <nanobind/make_iterator.h>
#include <nanobind/stl/detail/traits.h>

NAMESPACE_BEGIN(NB_NAMESPACE)

template <typename Map, typename... Args>
class_<Map> bind_map(handle scope, const char *name, Args &&...args) {
using Key = typename Map::key_type;
using Value = typename Map::mapped_type;

auto cl = class_<Map>(scope, name, std::forward<Args>(args)...)
.def(init<>())

.def("__len__", &Map::size)

.def("__bool__",
[](const Map &m) { return !m.empty(); },
"Check whether the map is nonempty")

.def("__contains__",
[](const Map &m, const Key &k) { return m.find(k) != m.end(); })

.def("__contains__", // fallback for incompatible types
[](const Map &, handle) { return false; })

.def("__iter__",
[](Map &m) {
return make_key_iterator(type<Map>(), "KeyIterator",
m.begin(), m.end());
},
keep_alive<0, 1>())

.def("__getitem__",
[](Map &m, const Key &k) -> Value & {
auto it = m.find(k);
if (it == m.end())
throw key_error();
return it->second;
},
rv_policy::reference_internal
)

.def("__delitem__",
[](Map &m, const Key &k) {
auto it = m.find(k);
if (it == m.end())
throw key_error();
m.erase(it);
}
);

// Assignment operator for copy-assignable/copy-constructible types
if constexpr (detail::is_copy_assignable_v<Value> ||
detail::is_copy_constructible_v<Value>) {
cl.def("__setitem__", [](Map &m, const Key &k, const Value &v) {
if constexpr (detail::is_copy_assignable_v<Value>) {
m[k] = v;
} else {
auto r = m.emplace(k, v);
if (!r.second) {
// Value is not copy-assignable. Erase and retry
m.erase(r.first);
m.emplace(k, v);
}
}
});
}

// Item, value, and key views
struct KeyView { Map &map; };
struct ValueView { Map &map; };
struct ItemView { Map &map; };

class_<ItemView>(cl, "ItemView")
.def("__len__", [](ItemView &v) { return v.map.size(); })
.def("__iter__",
[](ItemView &v) {
return make_iterator(type<Map>(), "ItemIterator",
v.map.begin(), v.map.end());
},
keep_alive<0, 1>());

class_<KeyView>(cl, "KeyView")
.def("__contains__", [](KeyView &v, const Key &k) { return v.map.find(k) != v.map.end(); })
.def("__contains__", [](KeyView &, handle) { return false; })
.def("__len__", [](KeyView &v) { return v.map.size(); })
.def("__iter__",
[](KeyView &v) {
return make_key_iterator(type<Map>(), "KeyIterator",
v.map.begin(), v.map.end());
},
keep_alive<0, 1>());

class_<ValueView>(cl, "ValueView")
.def("__len__", [](ValueView &v) { return v.map.size(); })
.def("__iter__",
[](ValueView &v) {
return make_value_iterator(type<Map>(), "ValueIterator",
v.map.begin(), v.map.end());
},
keep_alive<0, 1>());

cl.def("keys", [](Map &m) { return new KeyView{m}; }, keep_alive<0, 1>());
cl.def("values", [](Map &m) { return new ValueView{m}; }, keep_alive<0, 1>());
cl.def("items", [](Map &m) { return new ItemView{m}; }, keep_alive<0, 1>());

return cl;
}

NAMESPACE_END(NB_NAMESPACE)
25 changes: 24 additions & 1 deletion include/nanobind/stl/detail/traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,36 @@ struct is_copy_constructible<
is_copy_constructible<typename T::value_type>::value;
};

// std::pair is copy-constructible <=> both constituents are copy-constructible
template <typename T1, typename T2>
struct is_copy_constructible<std::pair<T1, T2>> {
static constexpr bool value =
is_copy_constructible<T1>::value ||
is_copy_constructible<T1>::value &&
is_copy_constructible<T2>::value;
};

// Analogous template for checking copy-assignability
template <typename T, typename SFINAE = int>
struct is_copy_assignable : std::is_copy_assignable<T> { };

template <typename T>
struct is_copy_assignable<T,
enable_if_t<std::is_copy_assignable_v<T> &&
std::is_same_v<typename T::value_type &,
typename T::reference>>> {
static constexpr bool value = is_copy_assignable<typename T::value_type>::value;
};

template <typename T1, typename T2>
struct is_copy_assignable<std::pair<T1, T2>> {
static constexpr bool value =
is_copy_assignable<T1>::value &&
is_copy_assignable<T2>::value;
};

template <typename T>
constexpr bool is_copy_assignable_v = is_copy_assignable<T>::value;

NAMESPACE_END(detail)
NAMESPACE_END(NB_NAMESPACE)

2 changes: 2 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ nanobind_add_module(test_functions_ext test_functions.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_classes_ext test_classes.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_holders_ext test_holders.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_stl_ext test_stl.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_bind_map_ext test_stl_bind_map.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_enum_ext test_enum.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_tensor_ext test_tensor.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_intrusive_ext test_intrusive.cpp object.cpp object.h ${NB_EXTRA_ARGS})
Expand Down Expand Up @@ -40,6 +41,7 @@ set(TEST_FILES
test_classes.py
test_holders.py
test_stl.py
test_stl_bind_map.py
test_enum.py
test_tensor.py
test_intrusive.py
Expand Down
76 changes: 76 additions & 0 deletions tests/test_stl_bind_map.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include <nanobind/stl/bind_map.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>

namespace nb = nanobind;

// testing for insertion of non-copyable class
class E_nc {
public:
explicit E_nc(int i) : value{i} {}
E_nc(const E_nc &) = delete;
E_nc &operator=(const E_nc &) = delete;
E_nc(E_nc &&) = default;
E_nc &operator=(E_nc &&) = default;

int value;
};

template <class Map>
Map *times_ten(int n) {
auto *m = new Map();
for (int i = 1; i <= n; i++) {
m->emplace(int(i), E_nc(10 * i));
}
return m;
}

template <class NestMap>
NestMap *times_hundred(int n) {
auto *m = new NestMap();
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
(*m)[i].emplace(int(j * 10), E_nc(100 * j));
}
}
return m;
}

NB_MODULE(test_bind_map_ext, m) {
// test_map_string_double
nb::bind_map<std::map<std::string, double>>(m, "MapStringDouble");
nb::bind_map<std::unordered_map<std::string, double>>(m, "UnorderedMapStringDouble");
// test_map_string_double_const
nb::bind_map<std::map<std::string, double const>>(m, "MapStringDoubleConst");
nb::bind_map<std::unordered_map<std::string, double const>>(m,
"UnorderedMapStringDoubleConst");

nb::class_<E_nc>(m, "ENC").def(nb::init<int>()).def_readwrite("value", &E_nc::value);

nb::bind_map<std::map<int, E_nc>>(m, "MapENC");
m.def("get_mnc", &times_ten<std::map<int, E_nc>>);
nb::bind_map<std::unordered_map<int, E_nc>>(m, "UmapENC");
m.def("get_umnc", &times_ten<std::unordered_map<int, E_nc>>);
// Issue #1885: binding nested std::map<X, Container<E>> with E non-copyable
nb::bind_map<std::map<int, std::vector<E_nc>>>(m, "MapVecENC");
m.def("get_nvnc", [](int n) {
auto *m = new std::map<int, std::vector<E_nc>>();
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
(*m)[i].emplace_back(j);
}
}
return m;
});

nb::bind_map<std::map<int, std::map<int, E_nc>>>(m, "MapMapENC");
m.def("get_nmnc", &times_hundred<std::map<int, std::map<int, E_nc>>>);
nb::bind_map<std::unordered_map<int, std::unordered_map<int, E_nc>>>(m, "UmapUmapENC");
m.def("get_numnc", &times_hundred<std::unordered_map<int, std::unordered_map<int, E_nc>>>);

}
Loading

0 comments on commit 80df8d4

Please sign in to comment.