diff --git a/cuvec/include/cuvec.cuh b/cuvec/include/cuvec.cuh index 721f82e..9b69804 100644 --- a/cuvec/include/cuvec.cuh +++ b/cuvec/include/cuvec.cuh @@ -112,6 +112,13 @@ template struct NDCuVec { if (size != vec.size()) throw std::length_error("reshape: size mismatch"); this->shape = shape; } + std::vector strides() const { + const size_t ndim = this->shape.size(); + std::vector s(ndim); + s[ndim - 1] = sizeof(T); + for (int i = ndim - 2; i >= 0; i--) s[i] = this->shape[i + 1] * s[i + 1]; + return s; + } }; #endif // _CUVEC_H_ diff --git a/cuvec/include/cuvec_pybind11.cuh b/cuvec/include/cuvec_pybind11.cuh index 9aedf47..999be18 100644 --- a/cuvec/include/cuvec_pybind11.cuh +++ b/cuvec/include/cuvec_pybind11.cuh @@ -32,13 +32,9 @@ PYBIND11_MAKE_OPAQUE(NDCuVec); pybind11::class_>(m, PYBIND11_TOSTRING(NDCuVec_##typechar), \ pybind11::buffer_protocol()) \ .def_buffer([](NDCuVec &v) -> pybind11::buffer_info { \ - size_t ndim = v.shape.size(); \ - std::vector strides(ndim); \ - strides[ndim - 1] = sizeof(T); \ - for (int i = ndim - 2; i >= 0; i--) strides[i] = v.shape[i + 1] * strides[i + 1]; \ return pybind11::buffer_info(v.vec.data(), sizeof(T), \ pybind11::format_descriptor::format(), v.shape.size(), \ - v.shape, strides); \ + v.shape, v.strides()); \ }) \ .def(pybind11::init<>()) \ .def(pybind11::init>()) \