diff --git a/cinn/runtime/CMakeLists.txt b/cinn/runtime/CMakeLists.txt index 51581bcdbd142..20f344b472f34 100644 --- a/cinn/runtime/CMakeLists.txt +++ b/cinn/runtime/CMakeLists.txt @@ -1,3 +1,4 @@ cc_library(runtime SRCS intrinsic.cc + buffer.cc DEPS common ir) diff --git a/cinn/runtime/buffer.cc b/cinn/runtime/buffer.cc new file mode 100644 index 0000000000000..c89a9b007090b --- /dev/null +++ b/cinn/runtime/buffer.cc @@ -0,0 +1,38 @@ +#include "cinn/runtime/buffer.h" + +namespace cinn { +namespace runtime { + +Shape::Shape(const Shape &other) : data_(new value_type[other.ndims()]), ndims_(other.ndims()) { + if (ndims() > 0) { + memcpy(data_, other.data(), ndims_ * sizeof(value_type)); + } +} + +void Shape::Resize(int ndim) { + CHECK_GT(ndim, 0); + ndims_ = ndim; + if (data_) delete data_; + data_ = new value_type[ndim]; +} + +Shape::value_type &Shape::operator[](int i) { + CHECK_GT(ndims_, 0) << "shape is empty"; + CHECK_LT(i, ndims_) << "index " << i << "out of range " << ndims_; + return data_[i]; +} + +Shape::value_type Shape::operator[](int i) const { + CHECK_GT(ndims_, 0) << "shape is empty"; + CHECK_LT(i, ndims_) << "index " << i << "out of range " << ndims_; + return data_[i]; +} + +uint32_t Shape::num_elements() const { + uint32_t res = ndims_ > 0 ? 1 : 0; + for (int i = 0; i < ndims(); i++) res *= (*this)[i]; + return res; +} + +} // namespace runtime +} // namespace cinn diff --git a/cinn/runtime/buffer.h b/cinn/runtime/buffer.h new file mode 100644 index 0000000000000..58dfcbcdf47fe --- /dev/null +++ b/cinn/runtime/buffer.h @@ -0,0 +1,86 @@ +#pragma once +#include + +#include +/** + * runtime::Buffer is an encapsulation of memory operations. + */ +namespace cinn { +namespace runtime { + +/** + * Shape of the buffers. + */ +struct Shape { + using value_type = int32_t; + + Shape() = default; + + Shape(const Shape& other); + + bool defined() const { return data_; } + + //! Get the number of dimensions. + uint32_t ndims() const { return ndims_; } + + //! Get the mutable data. + value_type* data() { return data_; } + //! Get the immutable data. + const value_type* data() const { return data_; } + + //! Resize the number of dimensions. + void Resize(int ndim); + + //! Get the number of elements the shape defines. + uint32_t num_elements() const; + + //! Get i-th element. + value_type& operator[](int i); + //! Get i-th element. + value_type operator[](int i) const; + + private: + uint32_t ndims_{0}; + int32_t* data_{}; +}; + +/** + * A C++ wrapper for buffer. + */ +template +class Buffer { + public: + Buffer(const Shape& shape) : shape_(shape) {} + + //! Allocate the memory in host device. + void AllocHost() { + CHECK(shape_.defined()); + data_ = new T[shape_.num_elements()]; + CHECK(data_) << "alloc buffer failed"; + } + //! Deallocate the memory in host device. + void DeallocHost() { + if (data_) delete data_; + data_ = nullptr; + } + + T& operator()(int i0) { + CHECK_EQ(shape_.ndims(), 1); + return static_cast(data_)[i0]; + } + T& operator()(int i0, int i1) { + CHECK_EQ(shape_.ndims(), 2); + return static_cast(data_)[i0 * shape_[0] + i1]; + } + T& operator()(int i0, int i1, int i2) { + CHECK_EQ(shape_.ndims(), 3); + return static_cast(data_)[i0 * shape_[1] * shape_[2] + i1 * shape_[2] + i2]; + } + + private: + Shape shape_; + void* data_{}; +}; + +} // namespace runtime +} // namespace cinn