-
Notifications
You must be signed in to change notification settings - Fork 208
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Port
nb::bind_map
from pybind11 (#114)
This commit adds a port of nanobind's `nb::bind_map<T>` feature to create bindings of STL map types (`map`, `unordered_map`). The implementation contains the following simplifications: 1. The C++17 constexpr feature was used to considerably reduce the size of the implementation. 2. The key/value/item views are simple wrappers without the need for polymorphism or STL unique pointers. They are created once per map type. The commit also includes a port of the associated pybind11 test suite parts. Co-authored by: Nicholas Junge <[email protected]> Co-authored by: Wenzel Jakob <[email protected]>
- Loading branch information
1 parent
83ec1ff
commit 2894c92
Showing
6 changed files
with
377 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
/* | ||
nanobind/stl/bind_map.h: Automatic creation of bindings for map-style containers | ||
All rights reserved. Use of this source code is governed by a | ||
BSD-style license that can be found in the LICENSE file. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <nanobind/nanobind.h> | ||
#include <nanobind/make_iterator.h> | ||
#include <nanobind/stl/detail/traits.h> | ||
|
||
NAMESPACE_BEGIN(NB_NAMESPACE) | ||
|
||
template <typename Map, typename... Args> | ||
class_<Map> bind_map(handle scope, const char *name, Args &&...args) { | ||
using Key = typename Map::key_type; | ||
using Value = typename Map::mapped_type; | ||
|
||
auto cl = class_<Map>(scope, name, std::forward<Args>(args)...) | ||
.def(init<>()) | ||
|
||
.def("__len__", &Map::size) | ||
|
||
.def("__bool__", | ||
[](const Map &m) { return !m.empty(); }, | ||
"Check whether the map is nonempty") | ||
|
||
.def("__contains__", | ||
[](const Map &m, const Key &k) { return m.find(k) != m.end(); }) | ||
|
||
.def("__contains__", // fallback for incompatible types | ||
[](const Map &, handle) { return false; }) | ||
|
||
.def("__iter__", | ||
[](Map &m) { | ||
return make_key_iterator(type<Map>(), "KeyIterator", | ||
m.begin(), m.end()); | ||
}, | ||
keep_alive<0, 1>()) | ||
|
||
.def("__getitem__", | ||
[](Map &m, const Key &k) -> Value & { | ||
auto it = m.find(k); | ||
if (it == m.end()) | ||
throw key_error(); | ||
return it->second; | ||
}, | ||
rv_policy::reference_internal | ||
) | ||
|
||
.def("__delitem__", | ||
[](Map &m, const Key &k) { | ||
auto it = m.find(k); | ||
if (it == m.end()) | ||
throw key_error(); | ||
m.erase(it); | ||
} | ||
); | ||
|
||
// Assignment operator for copy-assignable/copy-constructible types | ||
if constexpr (detail::is_copy_assignable_v<Value> || | ||
detail::is_copy_constructible_v<Value>) { | ||
cl.def("__setitem__", [](Map &m, const Key &k, const Value &v) { | ||
if constexpr (detail::is_copy_assignable_v<Value>) { | ||
m[k] = v; | ||
} else { | ||
auto r = m.emplace(k, v); | ||
if (!r.second) { | ||
// Value is not copy-assignable. Erase and retry | ||
m.erase(r.first); | ||
m.emplace(k, v); | ||
} | ||
} | ||
}); | ||
} | ||
|
||
// Item, value, and key views | ||
struct KeyView { Map ↦ }; | ||
struct ValueView { Map ↦ }; | ||
struct ItemView { Map ↦ }; | ||
|
||
class_<ItemView>(cl, "ItemView") | ||
.def("__len__", [](ItemView &v) { return v.map.size(); }) | ||
.def("__iter__", | ||
[](ItemView &v) { | ||
return make_iterator(type<Map>(), "ItemIterator", | ||
v.map.begin(), v.map.end()); | ||
}, | ||
keep_alive<0, 1>()); | ||
|
||
class_<KeyView>(cl, "KeyView") | ||
.def("__contains__", [](KeyView &v, const Key &k) { return v.map.find(k) != v.map.end(); }) | ||
.def("__contains__", [](KeyView &, handle) { return false; }) | ||
.def("__len__", [](KeyView &v) { return v.map.size(); }) | ||
.def("__iter__", | ||
[](KeyView &v) { | ||
return make_key_iterator(type<Map>(), "KeyIterator", | ||
v.map.begin(), v.map.end()); | ||
}, | ||
keep_alive<0, 1>()); | ||
|
||
class_<ValueView>(cl, "ValueView") | ||
.def("__len__", [](ValueView &v) { return v.map.size(); }) | ||
.def("__iter__", | ||
[](ValueView &v) { | ||
return make_value_iterator(type<Map>(), "ValueIterator", | ||
v.map.begin(), v.map.end()); | ||
}, | ||
keep_alive<0, 1>()); | ||
|
||
cl.def("keys", [](Map &m) { return new KeyView{m}; }, keep_alive<0, 1>()); | ||
cl.def("values", [](Map &m) { return new ValueView{m}; }, keep_alive<0, 1>()); | ||
cl.def("items", [](Map &m) { return new ItemView{m}; }, keep_alive<0, 1>()); | ||
|
||
return cl; | ||
} | ||
|
||
NAMESPACE_END(NB_NAMESPACE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
#include <map> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <vector> | ||
|
||
#include <nanobind/stl/bind_map.h> | ||
#include <nanobind/stl/string.h> | ||
#include <nanobind/stl/vector.h> | ||
|
||
namespace nb = nanobind; | ||
|
||
// testing for insertion of non-copyable class | ||
class E_nc { | ||
public: | ||
explicit E_nc(int i) : value{i} {} | ||
E_nc(const E_nc &) = delete; | ||
E_nc &operator=(const E_nc &) = delete; | ||
E_nc(E_nc &&) = default; | ||
E_nc &operator=(E_nc &&) = default; | ||
|
||
int value; | ||
}; | ||
|
||
template <class Map> | ||
Map *times_ten(int n) { | ||
auto *m = new Map(); | ||
for (int i = 1; i <= n; i++) { | ||
m->emplace(int(i), E_nc(10 * i)); | ||
} | ||
return m; | ||
} | ||
|
||
template <class NestMap> | ||
NestMap *times_hundred(int n) { | ||
auto *m = new NestMap(); | ||
for (int i = 1; i <= n; i++) { | ||
for (int j = 1; j <= n; j++) { | ||
(*m)[i].emplace(int(j * 10), E_nc(100 * j)); | ||
} | ||
} | ||
return m; | ||
} | ||
|
||
NB_MODULE(test_bind_map_ext, m) { | ||
// test_map_string_double | ||
nb::bind_map<std::map<std::string, double>>(m, "MapStringDouble"); | ||
nb::bind_map<std::unordered_map<std::string, double>>(m, "UnorderedMapStringDouble"); | ||
// test_map_string_double_const | ||
nb::bind_map<std::map<std::string, double const>>(m, "MapStringDoubleConst"); | ||
nb::bind_map<std::unordered_map<std::string, double const>>(m, | ||
"UnorderedMapStringDoubleConst"); | ||
|
||
nb::class_<E_nc>(m, "ENC").def(nb::init<int>()).def_readwrite("value", &E_nc::value); | ||
|
||
nb::bind_map<std::map<int, E_nc>>(m, "MapENC"); | ||
m.def("get_mnc", ×_ten<std::map<int, E_nc>>); | ||
nb::bind_map<std::unordered_map<int, E_nc>>(m, "UmapENC"); | ||
m.def("get_umnc", ×_ten<std::unordered_map<int, E_nc>>); | ||
// Issue #1885: binding nested std::map<X, Container<E>> with E non-copyable | ||
nb::bind_map<std::map<int, std::vector<E_nc>>>(m, "MapVecENC"); | ||
m.def("get_nvnc", [](int n) { | ||
auto *m = new std::map<int, std::vector<E_nc>>(); | ||
for (int i = 1; i <= n; i++) { | ||
for (int j = 1; j <= n; j++) { | ||
(*m)[i].emplace_back(j); | ||
} | ||
} | ||
return m; | ||
}); | ||
|
||
nb::bind_map<std::map<int, std::map<int, E_nc>>>(m, "MapMapENC"); | ||
m.def("get_nmnc", ×_hundred<std::map<int, std::map<int, E_nc>>>); | ||
nb::bind_map<std::unordered_map<int, std::unordered_map<int, E_nc>>>(m, "UmapUmapENC"); | ||
m.def("get_numnc", ×_hundred<std::unordered_map<int, std::unordered_map<int, E_nc>>>); | ||
|
||
} |
Oops, something went wrong.