Skip to content

Commit

Permalink
expose NDCuVec.strides()
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Mar 13, 2024
1 parent e631d2a commit b6ed960
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
7 changes: 7 additions & 0 deletions cuvec/include/cuvec.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ template <class T> struct NDCuVec {
if (size != vec.size()) throw std::length_error("reshape: size mismatch");
this->shape = shape;
}
std::vector<size_t> strides() const {
const size_t ndim = this->shape.size();
std::vector<size_t> s(ndim);
s[ndim - 1] = sizeof(T);
for (int i = ndim - 2; i >= 0; i--) s[i] = this->shape[i + 1] * s[i + 1];
return s;
}
};

#endif // _CUVEC_H_
6 changes: 1 addition & 5 deletions cuvec/include/cuvec_pybind11.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@ PYBIND11_MAKE_OPAQUE(NDCuVec<double>);
pybind11::class_<NDCuVec<T>>(m, PYBIND11_TOSTRING(NDCuVec_##typechar), \
pybind11::buffer_protocol()) \
.def_buffer([](NDCuVec<T> &v) -> pybind11::buffer_info { \
size_t ndim = v.shape.size(); \
std::vector<size_t> strides(ndim); \
strides[ndim - 1] = sizeof(T); \
for (int i = ndim - 2; i >= 0; i--) strides[i] = v.shape[i + 1] * strides[i + 1]; \
return pybind11::buffer_info(v.vec.data(), sizeof(T), \
pybind11::format_descriptor<T>::format(), v.shape.size(), \
v.shape, strides); \
v.shape, v.strides()); \
}) \
.def(pybind11::init<>()) \
.def(pybind11::init<std::vector<size_t>>()) \
Expand Down

0 comments on commit b6ed960

Please sign in to comment.