Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nb::bind_map #114

Merged
merged 7 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/nanobind/nb_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined
template <typename T, typename SFINAE = int>
struct is_copy_constructible : std::is_copy_constructible<T> { };

template <typename T>
constexpr bool is_copy_constructible_v = is_copy_constructible<T>::value;

NAMESPACE_END(detail)

template <typename T, typename... Ts>
Expand Down Expand Up @@ -275,7 +278,7 @@ class class_ : public object {
if constexpr (!std::is_same_v<Alias, T>)
d.flags |= (uint32_t) detail::type_flags::is_trampoline;

if constexpr (detail::is_copy_constructible<T>::value) {
if constexpr (detail::is_copy_constructible_v<T>) {
d.flags |= (uint32_t) detail::type_flags::is_copy_constructible;

if constexpr (!std::is_trivially_copy_constructible_v<T>) {
Expand Down
120 changes: 120 additions & 0 deletions include/nanobind/stl/bind_map.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
nanobind/stl/bind_map.h: Automatic creation of bindings for map-style containers

All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include <nanobind/nanobind.h>
#include <nanobind/make_iterator.h>
#include <nanobind/stl/detail/traits.h>

NAMESPACE_BEGIN(NB_NAMESPACE)

template <typename Map, typename... Args>
class_<Map> bind_map(handle scope, const char *name, Args &&...args) {
using Key = typename Map::key_type;
using Value = typename Map::mapped_type;

auto cl = class_<Map>(scope, name, std::forward<Args>(args)...)
.def(init<>())

.def("__len__", &Map::size)

.def("__bool__",
[](const Map &m) { return !m.empty(); },
"Check whether the map is nonempty")

.def("__contains__",
[](const Map &m, const Key &k) { return m.find(k) != m.end(); })

.def("__contains__", // fallback for incompatible types
[](const Map &, handle) { return false; })

.def("__iter__",
[](Map &m) {
return make_key_iterator(type<Map>(), "KeyIterator",
m.begin(), m.end());
},
keep_alive<0, 1>())

.def("__getitem__",
[](Map &m, const Key &k) -> Value & {
auto it = m.find(k);
if (it == m.end())
throw key_error();
return it->second;
},
rv_policy::reference_internal
)

.def("__delitem__",
[](Map &m, const Key &k) {
auto it = m.find(k);
if (it == m.end())
throw key_error();
m.erase(it);
}
);

// Assignment operator for copy-assignable/copy-constructible types
if constexpr (detail::is_copy_assignable_v<Value> ||
detail::is_copy_constructible_v<Value>) {
cl.def("__setitem__", [](Map &m, const Key &k, const Value &v) {
if constexpr (detail::is_copy_assignable_v<Value>) {
m[k] = v;
} else {
auto r = m.emplace(k, v);
if (!r.second) {
// Value is not copy-assignable. Erase and retry
m.erase(r.first);
m.emplace(k, v);
}
}
});
}

// Item, value, and key views
struct KeyView { Map &map; };
struct ValueView { Map &map; };
struct ItemView { Map &map; };

class_<ItemView>(cl, "ItemView")
.def("__len__", [](ItemView &v) { return v.map.size(); })
.def("__iter__",
[](ItemView &v) {
return make_iterator(type<Map>(), "ItemIterator",
v.map.begin(), v.map.end());
},
keep_alive<0, 1>());

class_<KeyView>(cl, "KeyView")
.def("__contains__", [](KeyView &v, const Key &k) { return v.map.find(k) != v.map.end(); })
.def("__contains__", [](KeyView &, handle) { return false; })
.def("__len__", [](KeyView &v) { return v.map.size(); })
.def("__iter__",
[](KeyView &v) {
return make_key_iterator(type<Map>(), "KeyIterator",
v.map.begin(), v.map.end());
},
keep_alive<0, 1>());

class_<ValueView>(cl, "ValueView")
.def("__len__", [](ValueView &v) { return v.map.size(); })
.def("__iter__",
[](ValueView &v) {
return make_value_iterator(type<Map>(), "ValueIterator",
v.map.begin(), v.map.end());
},
keep_alive<0, 1>());

cl.def("keys", [](Map &m) { return new KeyView{m}; }, keep_alive<0, 1>());
cl.def("values", [](Map &m) { return new ValueView{m}; }, keep_alive<0, 1>());
cl.def("items", [](Map &m) { return new ItemView{m}; }, keep_alive<0, 1>());

return cl;
}

NAMESPACE_END(NB_NAMESPACE)
25 changes: 24 additions & 1 deletion include/nanobind/stl/detail/traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,36 @@ struct is_copy_constructible<
is_copy_constructible<typename T::value_type>::value;
};

// std::pair is copy-constructible <=> both constituents are copy-constructible
template <typename T1, typename T2>
struct is_copy_constructible<std::pair<T1, T2>> {
static constexpr bool value =
is_copy_constructible<T1>::value ||
is_copy_constructible<T1>::value &&
is_copy_constructible<T2>::value;
};

// Analogous template for checking copy-assignability
template <typename T, typename SFINAE = int>
struct is_copy_assignable : std::is_copy_assignable<T> { };

template <typename T>
struct is_copy_assignable<T,
enable_if_t<std::is_copy_assignable_v<T> &&
std::is_same_v<typename T::value_type &,
typename T::reference>>> {
static constexpr bool value = is_copy_assignable<typename T::value_type>::value;
};

template <typename T1, typename T2>
struct is_copy_assignable<std::pair<T1, T2>> {
static constexpr bool value =
is_copy_assignable<T1>::value &&
is_copy_assignable<T2>::value;
};

template <typename T>
constexpr bool is_copy_assignable_v = is_copy_assignable<T>::value;

NAMESPACE_END(detail)
NAMESPACE_END(NB_NAMESPACE)

2 changes: 2 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ nanobind_add_module(test_functions_ext test_functions.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_classes_ext test_classes.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_holders_ext test_holders.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_stl_ext test_stl.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_bind_map_ext test_stl_bind_map.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_enum_ext test_enum.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_tensor_ext test_tensor.cpp ${NB_EXTRA_ARGS})
nanobind_add_module(test_intrusive_ext test_intrusive.cpp object.cpp object.h ${NB_EXTRA_ARGS})
Expand Down Expand Up @@ -40,6 +41,7 @@ set(TEST_FILES
test_classes.py
test_holders.py
test_stl.py
test_stl_bind_map.py
test_enum.py
test_tensor.py
test_intrusive.py
Expand Down
76 changes: 76 additions & 0 deletions tests/test_stl_bind_map.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include <nanobind/stl/bind_map.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>

namespace nb = nanobind;

// testing for insertion of non-copyable class
class E_nc {
public:
explicit E_nc(int i) : value{i} {}
E_nc(const E_nc &) = delete;
E_nc &operator=(const E_nc &) = delete;
E_nc(E_nc &&) = default;
E_nc &operator=(E_nc &&) = default;

int value;
};

template <class Map>
Map *times_ten(int n) {
auto *m = new Map();
for (int i = 1; i <= n; i++) {
m->emplace(int(i), E_nc(10 * i));
}
return m;
}

template <class NestMap>
NestMap *times_hundred(int n) {
auto *m = new NestMap();
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
(*m)[i].emplace(int(j * 10), E_nc(100 * j));
}
}
return m;
}

NB_MODULE(test_bind_map_ext, m) {
// test_map_string_double
nb::bind_map<std::map<std::string, double>>(m, "MapStringDouble");
nb::bind_map<std::unordered_map<std::string, double>>(m, "UnorderedMapStringDouble");
// test_map_string_double_const
nb::bind_map<std::map<std::string, double const>>(m, "MapStringDoubleConst");
nb::bind_map<std::unordered_map<std::string, double const>>(m,
"UnorderedMapStringDoubleConst");

nb::class_<E_nc>(m, "ENC").def(nb::init<int>()).def_readwrite("value", &E_nc::value);

nb::bind_map<std::map<int, E_nc>>(m, "MapENC");
m.def("get_mnc", &times_ten<std::map<int, E_nc>>);
nb::bind_map<std::unordered_map<int, E_nc>>(m, "UmapENC");
m.def("get_umnc", &times_ten<std::unordered_map<int, E_nc>>);
// Issue #1885: binding nested std::map<X, Container<E>> with E non-copyable
nb::bind_map<std::map<int, std::vector<E_nc>>>(m, "MapVecENC");
m.def("get_nvnc", [](int n) {
auto *m = new std::map<int, std::vector<E_nc>>();
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
(*m)[i].emplace_back(j);
}
}
return m;
});

nb::bind_map<std::map<int, std::map<int, E_nc>>>(m, "MapMapENC");
m.def("get_nmnc", &times_hundred<std::map<int, std::map<int, E_nc>>>);
nb::bind_map<std::unordered_map<int, std::unordered_map<int, E_nc>>>(m, "UmapUmapENC");
m.def("get_numnc", &times_hundred<std::unordered_map<int, std::unordered_map<int, E_nc>>>);

}
Loading