Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#32 from Superjomn/fea/runtime-buffer
Browse files Browse the repository at this point in the history
init runtime buffer
  • Loading branch information
Superjomn authored Feb 20, 2020
2 parents 7cf7c03 + e5bcfed commit 41e59a5
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 0 deletions.
1 change: 1 addition & 0 deletions cinn/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
cc_library(runtime SRCS
intrinsic.cc
buffer.cc
DEPS common ir)
38 changes: 38 additions & 0 deletions cinn/runtime/buffer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "cinn/runtime/buffer.h"

namespace cinn {
namespace runtime {

Shape::Shape(const Shape &other) : data_(new value_type[other.ndims()]), ndims_(other.ndims()) {
if (ndims() > 0) {
memcpy(data_, other.data(), ndims_ * sizeof(value_type));
}
}

void Shape::Resize(int ndim) {
CHECK_GT(ndim, 0);
ndims_ = ndim;
if (data_) delete data_;
data_ = new value_type[ndim];
}

Shape::value_type &Shape::operator[](int i) {
CHECK_GT(ndims_, 0) << "shape is empty";
CHECK_LT(i, ndims_) << "index " << i << "out of range " << ndims_;
return data_[i];
}

Shape::value_type Shape::operator[](int i) const {
CHECK_GT(ndims_, 0) << "shape is empty";
CHECK_LT(i, ndims_) << "index " << i << "out of range " << ndims_;
return data_[i];
}

uint32_t Shape::num_elements() const {
uint32_t res = ndims_ > 0 ? 1 : 0;
for (int i = 0; i < ndims(); i++) res *= (*this)[i];
return res;
}

} // namespace runtime
} // namespace cinn
86 changes: 86 additions & 0 deletions cinn/runtime/buffer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#pragma once
#include <glog/logging.h>

#include <string>
/**
* runtime::Buffer is an encapsulation of memory operations.
*/
namespace cinn {
namespace runtime {

/**
* Shape of the buffers.
*/
struct Shape {
using value_type = int32_t;

Shape() = default;

Shape(const Shape& other);

bool defined() const { return data_; }

//! Get the number of dimensions.
uint32_t ndims() const { return ndims_; }

//! Get the mutable data.
value_type* data() { return data_; }
//! Get the immutable data.
const value_type* data() const { return data_; }

//! Resize the number of dimensions.
void Resize(int ndim);

//! Get the number of elements the shape defines.
uint32_t num_elements() const;

//! Get i-th element.
value_type& operator[](int i);
//! Get i-th element.
value_type operator[](int i) const;

private:
uint32_t ndims_{0};
int32_t* data_{};
};

/**
* A C++ wrapper for buffer.
*/
template <typename T>
class Buffer {
public:
Buffer(const Shape& shape) : shape_(shape) {}

//! Allocate the memory in host device.
void AllocHost() {
CHECK(shape_.defined());
data_ = new T[shape_.num_elements()];
CHECK(data_) << "alloc buffer failed";
}
//! Deallocate the memory in host device.
void DeallocHost() {
if (data_) delete data_;
data_ = nullptr;
}

T& operator()(int i0) {
CHECK_EQ(shape_.ndims(), 1);
return static_cast<T*>(data_)[i0];
}
T& operator()(int i0, int i1) {
CHECK_EQ(shape_.ndims(), 2);
return static_cast<T*>(data_)[i0 * shape_[0] + i1];
}
T& operator()(int i0, int i1, int i2) {
CHECK_EQ(shape_.ndims(), 3);
return static_cast<T*>(data_)[i0 * shape_[1] * shape_[2] + i1 * shape_[2] + i2];
}

private:
Shape shape_;
void* data_{};
};

} // namespace runtime
} // namespace cinn

0 comments on commit 41e59a5

Please sign in to comment.