diff --git a/docs/guides/new_op/index_cn.rst b/docs/guides/custom_op/index_cn.rst similarity index 87% rename from docs/guides/new_op/index_cn.rst rename to docs/guides/custom_op/index_cn.rst index 0c58c07093f..87950c1343c 100644 --- a/docs/guides/new_op/index_cn.rst +++ b/docs/guides/custom_op/index_cn.rst @@ -7,7 +7,7 @@ 1. C++算子:编写方法较为简洁,不涉及框架内部概念,无需重新编译飞桨框架,以外接模块的方式使用的算子 2. Python算子:使用Python编写实现前向(forward)和反向(backward)方法,在模型组网中使用的自定义API -- `自定义C++算子 <./new_custom_op_cn.html>`_ +- `自定义C++算子 <./new_cpp_op_cn.html>`_ - `自定义Python算子 <./new_python_op_cn.html>`_ @@ -15,6 +15,6 @@ .. toctree:: :hidden: - new_custom_op_cn.md + new_cpp_op_cn.md new_python_op_cn.md diff --git a/docs/guides/new_op/new_custom_op_cn.md b/docs/guides/custom_op/new_cpp_op_cn.md similarity index 78% rename from docs/guides/new_op/new_custom_op_cn.md rename to docs/guides/custom_op/new_cpp_op_cn.md index b64ca9f0375..0ed20bd46ee 100644 --- a/docs/guides/new_op/new_custom_op_cn.md +++ b/docs/guides/custom_op/new_cpp_op_cn.md @@ -17,9 +17,9 @@ 随后即可在模型中使用,下面通过实现一个 `relu` 运算,介绍具体的实现、编译与应用流程。 > 注意事项: -> - 在使用本机制实现自定义算子之前,请确保已经正确安装了 `PaddlePaddle 2.1` 及以上版本 +> - 在使用本机制实现自定义算子之前,请确保已经正确安装了 `PaddlePaddle 2.3` 及以上版本 > - 该机制已支持 `Linux` 、 `Mac` 和 `Windows` 平台。 -> - 本自定义外部算子机制仅保证源码级别的兼容,不保证二进制级别的兼容,例如,基于飞桨旧版本2.0编写的自定义算子源码实现,在飞桨2.1或者后续版本中编译链接使用没有问题,但基于飞桨旧版本2.0编译得到的自定义算子动态库文件(*.so, *.dylib, *.dll),在2.1或者后续发布的版本中可能会加载失败。 +> - 本自定义外部算子机制仅保证源码级别的兼容,不保证二进制级别的兼容,例如,基于飞桨2.3版本编写的自定义算子源码实现,在飞桨2.3或者后续版本中编译链接使用没有问题,但基于飞桨2.3之前的版本编译得到的自定义算子动态库文件(*.so, *.dylib, *.dll),在2.3或者后续发布的版本中可能会加载失败。 ## 自定义算子C++实现 @@ -66,50 +66,131 @@ std::vector OpFucntion(const paddle::Tensor& x, ..., int attr, . > 注:其他类型的数值作为函数输入参数或者返回值将无法编译通过 -#### 设备与数据类型支持 +#### 设备类型 -对于基础的设备和数据类型支持情况,我们定义了两个简单的枚举类: +设备类型使用 `Place` 表示,`Place` 含有内存类型AllocationType与设备ID信息,是 `Tensor` 的基础描述信息之一。 -- 设备表示:`enum class PlaceType { kUNK = -1, kCPU, kGPU };` -- 数据类型表示:`enum class DataType {BOOL, INT8, UINT8, INT16, INT32, INT64, FLOAT16, FLOAT32, FLOAT64, COMPLEX64, COMPLEX128};` +其中设备类型是枚举类型: -> 注:目前仅支持以上设备与数据类型,其他类型会视需求在后续版本支持 +```c++ +enum class AllocationType : int8_t { + UNDEFINED = 0, + CPU = 1, + GPU = 2, + GPUPINNED = 3, + ... +}; +``` + +设备ID是一个int8_t的数值,用于表示当前使用的设备卡号。 + +一些Place使用示例如下: + +```c++ +auto cpu_place = paddle::CPUPlace(); +auto gpu_place = paddle::GPUPlace(); // 默认设备ID为0,一般在自定义算子内使用默认的构造方式即可 +auto gpu_place = paddle::GPUPlace(1); // GPU 1号卡 +``` + +此外,Place还有两个常用的方法: + +- GetType():获取Place的内存类型AllocationType +- GetDeviceId():获取Place的设备ID + +使用示例如下: + +```c++ +auto gpu_place = paddle::GPUPlace(); +auto alloc_type = gpu_place.GetType(); // paddle::AllocationType::GPU +auto dev_id = gpu_place.GetDeviceId(); // 0 +``` + +详细的Place定义请参考 [paddle/phi/common/place.h](https://github.com/PaddlePaddle/Paddle/blob/release/2.3/paddle/phi/common/place.h)。 + +> 注:目前自定义算子仅在CPU与GPU上进行了验证,其他类型会视需求在后续版本支持 + +#### 数据类型 + +数据类型使用 `DataType` 表示,同样是 `Tensor` 的基础描述信息之一,目前主要支持的类型如下: + +```c++ +enum class DataType { + UNDEFINED = 0, + BOOL, + INT8, + UINT8, + INT16, + INT32, + UINT32, + INT64, + UINT64, + BFLOAT16, + FLOAT16, + UINT16, + FLOAT32, + FLOAT64, + COMPLEX64, + COMPLEX128, + ... +} +``` + +详细的DataType定义请参考 [paddle/phi/common/data_type.h](https://github.com/PaddlePaddle/Paddle/blob/release/2.3/paddle/phi/common/data_type.h)。 #### Tensor API -对于 `paddle::Tensor` ,我们目前提供了一些基础的API,包括: - -- 构造API: - - `Tensor(const PlaceType& place, const std::vector& shape)` - - 输入参数 `place` 和 `shape` ,返回一个 `Tensor` 对象 -- 设备相关API: - - `const PlaceType& place() const`:获取 `Tensor` 所在的设备 -- 数据类型相关API: - - `DataType type() const`:获取 `Tensor` 的数据类型 -- 长度与维度相关API: - - `int64_t size() const`:获取 `Tensor` 的数据长度 +(1) `Tensor` 构造 + +对于 `paddle::Tensor` 的构造,我们推荐使用相应的初始化paddle API,包括: + +```c++ +PADDLE_API Tensor empty(const IntArray& shape, DataType dtype=DataType::FLOAT32, const Place& place=CPUPlace()); +PADDLE_API Tensor full(const IntArray& shape, const Scalar& value, DataType dtype=DataType::FLOAT32, const Place& place=CPUPlace()); + +PADDLE_API Tensor empty_like(const Tensor& x, DataType dtype=DataType::UNDEFINED, const Place& place={}); +PADDLE_API Tensor full_like(const Tensor& x, const Scalar& value, DataType dtype=DataType::UNDEFINED, const Place& place={}); +``` + +使用示例如下: + +```c++ +auto tensor = paddle::empty({3, 4}); // default: float32, cpu +auto tensor = paddle::full({3, 4}, 1.0); // default: float32, cpu +auto gpu_tensor = paddle::empty({3, 4}, paddle::DataType::FLOAT64, paddle::GPUPlace()); +auto gpu_tensor = paddle::full({3, 4}, 1.0, paddle::DataType::FLOAT64, paddle::GPUPlace()); +``` + +(2) `Tensor` 成员方法 + +此外 `paddle::Tensor` 自身目前提供了一些基础的功能API,在定义算子最后那个常用的包括: + +- 设备、数据类型获取API: + - `const Place& place() const`:获取 `Tensor` 所在的设备 + - `DataType dtype() const`:获取 `Tensor` 的数据类型 +- 长度与维度获取API: + - `int64_t numel() const`:获取 `Tensor` 的数据长度 - `std::vector shape() const`:获取 `Tensor` 的维度信息 - - `void reshape(const std::vector& shape)`: - - 输入参数 `shape` ,修改 `Tensor` 记录的维度信息,此处不会重新分配存储 - 数据访问API: - - `is_initialized() const`: 确认 `Tensor` 是否已被初始化 - - `template T* data() const`: - - 模板类方法,获取数据内存的起始地址(只读访问) - - `template T* mutable_data(const PlaceType& place)`: - - 模板类方法,输入参数 `place` ,根据 `Tensor.shape` 在指定设备上申请内存,并返回内存的起始地址 - - `Tensor slice(const int64_t begin_idx, const int64_t end_idx) const`: - - 输入参数起始行 `begin_idx` 和终止行 `end_idx`,返回当前 `Tensor` 从起始行(含)到终止行(不含)的一个视图 - > 注:本API仅支持对当前 `Tensor` 的第一个维度(即 axis = 0)进行切分 + - `template const T* data() const`:模板类方法,获取数据内存的起始地址(只读) + - `template T* data()`:模板类方法,获取数据内存的起始地址(读写) +- 状态或属性判断API: + - `bool defined() const`: 确认 `Tensor` 是否有效 + - `bool initialized() const`: 确认 `Tensor` 是否已被初始化 + - `bool is_cpu() const`:确认 `Tensor` 是否在CPU上 + - `bool is_gpu() const`:确认 `Tensor` 是否在GPU上 - 工具类API: - - `template Tensor copy_to(const PlaceType& place) const`: + - `Tensor copy_to(const Place& place, bool blocking) const`: - 模板类方法,输入参数 `place`,将当前 `Tensor` 拷贝到指定设备上并返回 - - `Tensor cast(const DataType& target_type) const`: + - `Tensor cast(DataType target_type) const`: - 输入参数 `target_type` ,将当前 `Tensor` 转换为指定数据类型的 `Tensor` 并返回 + - `Tensor slice(const int64_t begin_idx, const int64_t end_idx) const`: + - 输入参数起始行 begin_idx 和终止行 end_idx,返回当前 Tensor 从起始行(含)到终止行(不含)的一个视图 + - 目前仅支持对当前 Tensor 的第一个维度(即 axis = 0)进行切分 - `cudaStream_t stream() const`: - 用于获取当前 `Tensor` 所处的CUDA Stream(仅在GPU编译版本中生效) - 仅能够获取函数输入 `Tensor` 的stream -> 注:后续会继续扩展其他API,API的声明详见 [Paddle Extension Headers in 2.1](https://github.com/PaddlePaddle/Paddle/tree/release/2.1/paddle/fluid/extension/include) 。 +后续我们会继续扩展其他Tensor API,详细的Tensor定义请参考 [paddle/phi/api/include/tensor.h](https://github.com/PaddlePaddle/Paddle/blob/release/2.3/paddle/phi/api/include/tensor.h) 。 #### Exception API @@ -145,10 +226,135 @@ PD_THROW("PD_THROW returns ", false) // [/User/custom_op/custom_relu_op.cc:82] ``` -对函数写法以及基础API的定义有了初步认识后,下面结合具体的示例进行介绍。 +#### 类Python的C++运算API + +自paddle 2.3版本开始,我们提供定义与用法与相应Python API类似的C++ API,其API命名、参数顺序及类型均和相应的paddle Python API对齐,可以通过查找相应Python API的官方文档了解其用法,并在自定义算子开发时使用。通过调用这些接口,可以省去封装基础运算的时间,从而提高开发效率。 + +在2.3版本支持的C++ API列表如下,可以通过 `paddle::xxx` 进行调用: + +```c++ +PADDLE_API Tensor abs(const Tensor& x); +PADDLE_API Tensor acos(const Tensor& x); +PADDLE_API Tensor acosh(const Tensor& x); +PADDLE_API Tensor add(const Tensor& x, const Tensor& y); +PADDLE_API Tensor allclose(const Tensor& x, const Tensor& y, const Scalar& rtol, const Scalar& atol, bool equal_nan); +PADDLE_API std::tuple argsort(const Tensor& x, int axis, bool descending); +PADDLE_API Tensor asin(const Tensor& x); +PADDLE_API Tensor asinh(const Tensor& x); +PADDLE_API Tensor atan(const Tensor& x); +PADDLE_API Tensor atan2(const Tensor& x, const Tensor& y); +PADDLE_API Tensor atanh(const Tensor& x); +PADDLE_API Tensor bernoulli(const Tensor& x); +PADDLE_API Tensor ceil(const Tensor& x); +PADDLE_API Tensor cholesky(const Tensor& x, bool upper); +PADDLE_API Tensor cholesky_solve(const Tensor& x, const Tensor& y, bool upper); +PADDLE_API Tensor clip(const Tensor& x, const Scalar& min, const Scalar& max); +PADDLE_API Tensor concat(const std::vector& x, const Scalar& axis); +PADDLE_API Tensor conj(const Tensor& x); +PADDLE_API Tensor cos(const Tensor& x); +PADDLE_API Tensor cosh(const Tensor& x); +PADDLE_API Tensor cross(const Tensor& x, const Tensor& y, int axis=9); +PADDLE_API Tensor det(const Tensor& x); +PADDLE_API Tensor diag(const Tensor& x, int offset, float padding_value); +PADDLE_API Tensor diagonal(const Tensor& x, int offset, int axis1, int axis2); +PADDLE_API Tensor digamma(const Tensor& x); +PADDLE_API Tensor dist(const Tensor& x, const Tensor& y, float p); +PADDLE_API Tensor divide(const Tensor& x, const Tensor& y); +PADDLE_API Tensor dot(const Tensor& x, const Tensor& y); +PADDLE_API Tensor elu(const Tensor& x, float alpha); +PADDLE_API Tensor empty(const IntArray& shape, DataType dtype=DataType::FLOAT32, const Place& place=CPUPlace()); +PADDLE_API Tensor empty_like(const Tensor& x, DataType dtype=DataType::UNDEFINED, const Place& place={}); +PADDLE_API Tensor equal_all(const Tensor& x, const Tensor& y); +PADDLE_API Tensor erf(const Tensor& x); +PADDLE_API Tensor erfinv(const Tensor& x); +PADDLE_API Tensor exp(const Tensor& x); +PADDLE_API Tensor expand(const Tensor& x, const IntArray& shape); +PADDLE_API Tensor expm1(const Tensor& x); +PADDLE_API std::tuple flatten(const Tensor& x, int start_axis, int stop_axis); +PADDLE_API Tensor flip(const Tensor& x, const std::vector& axis); +PADDLE_API Tensor floor(const Tensor& x); +PADDLE_API Tensor floor_divide(const Tensor& x, const Tensor& y); +PADDLE_API Tensor full(const IntArray& shape, const Scalar& value, DataType dtype=DataType::FLOAT32, const Place& place=CPUPlace()); +PADDLE_API Tensor gather(const Tensor& x, const Tensor& index, const Scalar& axis=0); +PADDLE_API Tensor gather_nd(const Tensor& x, const Tensor& index); +PADDLE_API Tensor gelu(const Tensor& x, bool approximate); +PADDLE_API Tensor gumbel_softmax(const Tensor& x, float temperature, bool hard, int axis); +PADDLE_API Tensor imag(const Tensor& x); +PADDLE_API Tensor increment(const Tensor& x, float value); +PADDLE_API Tensor index_sample(const Tensor& x, const Tensor& index); +PADDLE_API Tensor is_empty(const Tensor& x); +PADDLE_API Tensor isclose(const Tensor& x, const Tensor& y, const Scalar& rtol, const Scalar& atol, bool equal_nan); +PADDLE_API Tensor isfinite(const Tensor& x); +PADDLE_API Tensor isinf(const Tensor& x); +PADDLE_API Tensor isnan(const Tensor& x); +PADDLE_API Tensor kron(const Tensor& x, const Tensor& y); +PADDLE_API std::tuple kthvalue(const Tensor& x, int k, int axis, bool keepdim); +PADDLE_API Tensor label_smooth(const Tensor& label, paddle::optional prior_dist, float epsilon); +PADDLE_API Tensor lerp(const Tensor& x, const Tensor& y, const Tensor& weight); +PADDLE_API Tensor lgamma(const Tensor& x); +PADDLE_API Tensor log(const Tensor& x); +PADDLE_API Tensor log10(const Tensor& x); +PADDLE_API Tensor log1p(const Tensor& x); +PADDLE_API Tensor log2(const Tensor& x); +PADDLE_API Tensor logit(const Tensor& x, float eps=1e-6f); +PADDLE_API Tensor masked_select(const Tensor& x, const Tensor& mask); +PADDLE_API Tensor matmul(const Tensor& x, const Tensor& y, bool transpose_x=false, bool transpose_y=false); +PADDLE_API Tensor matrix_power(const Tensor& x, int n); +PADDLE_API Tensor maximum(const Tensor& x, const Tensor& y); +PADDLE_API Tensor maxout(const Tensor& x, int groups, int axis); +PADDLE_API Tensor minimum(const Tensor& x, const Tensor& y); +PADDLE_API std::tuple mode(const Tensor& x, int axis, bool keepdim); +PADDLE_API Tensor multi_dot(const std::vector& x); +PADDLE_API Tensor multinomial(const Tensor& x, int num_samples, bool replacement); +PADDLE_API Tensor multiply(const Tensor& x, const Tensor& y); +PADDLE_API Tensor mv(const Tensor& x, const Tensor& vec); +PADDLE_API std::tuple nll_loss(const Tensor& input, const Tensor& label, paddle::optional weight, int64_t ignore_index, const std::string& reduction); +PADDLE_API Tensor one_hot(const Tensor& x, const Scalar& num_classes); +PADDLE_API Tensor pixel_shuffle(const Tensor& x, int upscale_factor, const std::string& data_format); +PADDLE_API Tensor poisson(const Tensor& x); +PADDLE_API std::tuple qr(const Tensor& x, const std::string& mode); +PADDLE_API Tensor real(const Tensor& x); +PADDLE_API Tensor reciprocal(const Tensor& x); +PADDLE_API Tensor relu(const Tensor& x); +PADDLE_API Tensor reshape(const Tensor& x, const IntArray& shape); +PADDLE_API Tensor roll(const Tensor& x, const IntArray& shifts, const std::vector& axis); +PADDLE_API Tensor round(const Tensor& x); +PADDLE_API Tensor rsqrt(const Tensor& x); +PADDLE_API Tensor scatter(const Tensor& x, const Tensor& index, const Tensor& updates, bool overwrite); +PADDLE_API Tensor scatter_nd_add(const Tensor& x, const Tensor& index, const Tensor& updates); +PADDLE_API Tensor selu(const Tensor& x, float scale, float alpha); +PADDLE_API Tensor sign(const Tensor& x); +PADDLE_API Tensor silu(const Tensor& x); +PADDLE_API Tensor sin(const Tensor& x); +PADDLE_API Tensor sinh(const Tensor& x); +PADDLE_API std::vector split(const Tensor& x, const IntArray& num_or_sections, const Scalar& axis); +PADDLE_API Tensor sqrt(const Tensor& x); +PADDLE_API Tensor square(const Tensor& x); +PADDLE_API Tensor stack(const std::vector& x, int axis); +PADDLE_API Tensor strided_slice(const Tensor& x, const std::vector& axes, const IntArray& starts, const IntArray& ends, const IntArray& strides); +PADDLE_API Tensor subtract(const Tensor& x, const Tensor& y); +PADDLE_API Tensor tanh(const Tensor& x); +PADDLE_API Tensor thresholded_relu(const Tensor& x, float threshold); +PADDLE_API Tensor tile(const Tensor& x, const IntArray& repeat_times); +PADDLE_API Tensor trace(const Tensor& x, int offset, int axis1, int axis2); +PADDLE_API Tensor triangular_solve(const Tensor& x, const Tensor& y, bool upper, bool transpose, bool unitriangular); +PADDLE_API std::vector unbind(const Tensor& input, int axis); +PADDLE_API std::tuple unique(const Tensor& x, bool return_index, bool return_inverse, bool return_counts, const std::vector& axis, DataType dtype=DataType::INT64); +PADDLE_API std::tuple unsqueeze(const Tensor& x, const IntArray& axis); +PADDLE_API Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y); +``` + +> 注:后续我们会提供更方便的查阅C++ API文档的入口。 + +在2.3版本,我们共支持了大约250个类似的C++ API,能够覆盖大部分的基础运算,但是除前述的109个C++ API之外,剩余的C++ API由于一些历史原因,其参数列表尚未和相应的Python API对齐,因此目前剩余这些API只能作为experimental的API使用,需要通过 `paddle::experimental::xxx` 进行调用,且这些experimental API在下个版本可能会有不兼容的升级,如果不介意随下一版本升级的话,可以使用,追求稳定的话则不建议使用。 + +如有需要,目前支持的全量API列表(包含experimental API)请参考paddle安装路径下的api.h头文件,以Python3.7为例,其路径是 `python3.7/site-packages/paddle/include/paddle/phi/api/include/api.h`。 + ### 运算函数实现 +对函数写法以及基础API的定义有了初步认识后,下面结合具体的示例进行介绍。 + #### CPU实现 以 `relu` 算子为例,一个支持 `float32` 类型的CPU `relu` 算子运算函数可以实现如下: @@ -160,18 +366,18 @@ PD_THROW("PD_THROW returns ", false) #include -#define CHECK_INPUT(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.") +#define CHECK_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.") std::vector ReluCPUForward(const paddle::Tensor& x) { CHECK_INPUT(x); - auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); + auto out = paddle::empty_like(x); - auto x_numel = x.size(); + auto x_numel = x.numel(); auto* x_data = x.data(); - auto* out_data = out.mutable_data(x.place()); + auto* out_data = out.data(); - for (int i = 0; i < x_numel; ++i) { + for (int64_t i = 0; i < x_numel; ++i) { out_data[i] = std::max(static_cast(0.), x_data[i]); } @@ -185,14 +391,14 @@ std::vector ReluCPUBackward(const paddle::Tensor& x, CHECK_INPUT(out); CHECK_INPUT(grad_out); - auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); + auto grad_x = paddle::empty_like(x); - auto out_numel = out.size(); + auto out_numel = out.numel(); auto* out_data = out.data(); auto* grad_out_data = grad_out.data(); - auto* grad_x_data = grad_x.mutable_data(x.place()); + auto* grad_x_data = grad_x.data(); - for (int i = 0; i < out_numel; ++i) { + for (int64_t i = 0; i < out_numel; ++i) { grad_x_data[i] = grad_out_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.); } @@ -203,15 +409,13 @@ std::vector ReluCPUBackward(const paddle::Tensor& x, 主要逻辑包括: -1. 创建指定 `place` 和 `shape` 的输出 `Tensor` -2. 获取输入 `Tensor` 的数据区起始地址,为输出 `Tensor` 申请内存并返回数据区起始地址 +1. 创建输出 `Tensor` +2. 获取输入和输出 `Tensor` 的数据区起始地址 3. 计算得到输出 `Tensor` 的数值,返回结果 -> 注:目前尚不支持输入 `Tensor` 的 `inplace` 改动,将会在后续版本支持 - 前述 `relu` 示例实现仅支持 `float32` 类型的计算,如果仅有一种数据类型的支持需求,用以上写法即可。 -如果需要同时支持多种数据类型,例如同时支持 `float32` 与 `float64` 的计算,可以使用相应的dispatch宏进行声明,示例如下: +如果需要同时支持多种数据类型,例如同时支持 `float32` 与 `float64` 的计算,可以使用相应的DIAPATCH宏进行声明,示例如下: - relu_cpu.cc @@ -220,13 +424,13 @@ std::vector ReluCPUBackward(const paddle::Tensor& x, #include -#define CHECK_INPUT(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.") +#define CHECK_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.") template void relu_cpu_forward_kernel(const data_t* x_data, data_t* out_data, int64_t x_numel) { - for (int i = 0; i < x_numel; ++i) { + for (int64_t i = 0; i < x_numel; ++i) { out_data[i] = std::max(static_cast(0.), x_data[i]); } } @@ -236,7 +440,7 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data, const data_t* out_data, data_t* grad_x_data, int64_t out_numel) { - for (int i = 0; i < out_numel; ++i) { + for (int64_t i = 0; i < out_numel; ++i) { grad_x_data[i] = grad_out_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.); } @@ -245,12 +449,12 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data, std::vector ReluCPUForward(const paddle::Tensor& x) { CHECK_INPUT(x); - auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); + auto out = paddle::empty_like(x); PD_DISPATCH_FLOATING_TYPES( x.type(), "relu_cpu_forward_kernel", ([&] { relu_cpu_forward_kernel( - x.data(), out.mutable_data(x.place()), x.size()); + x.data(), out.data(), x.numel()); })); return {out}; @@ -263,14 +467,14 @@ std::vector ReluCPUBackward(const paddle::Tensor& x, CHECK_INPUT(out); CHECK_INPUT(grad_out); - auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); + auto grad_x = paddle::empty_like(x); PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward_kernel", ([&] { relu_cpu_backward_kernel( grad_out.data(), out.data(), - grad_x.mutable_data(x.place()), - out.size()); + grad_x.data(), + out.numel()); })); return {grad_x}; @@ -287,11 +491,11 @@ std::vector ReluCPUBackward(const paddle::Tensor& x, switch(x.type()) { case paddle::DataType::FLOAT32: relu_cpu_forward_kernel( - x.data(), out.mutable_data(x.place()), x.size()); + x.data(), out.data(), x.numel()); break; case paddle::DataType::FLOAT64: relu_cpu_forward_kernel( - x.data(), out.mutable_data(x.place()), x.size()); + x.data(), out.data(), x.numel()); break; default: PD_THROW( @@ -323,9 +527,9 @@ switch(x.type()) { template __global__ void relu_cuda_forward_kernel(const data_t* x, data_t* y, - int num) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; - for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { y[i] = max(x[i], static_cast(0.)); } } @@ -334,23 +538,23 @@ template __global__ void relu_cuda_backward_kernel(const data_t* dy, const data_t* y, data_t* dx, - int num) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; - for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.); } } std::vector relu_cuda_forward(const paddle::Tensor& x) { - auto out = paddle::Tensor(paddle::PlaceType::kGPU, x.shape()); + auto out = paddle::empty_like(x); - int numel = x.size(); - int block = 512; - int grid = (numel + block - 1) / block; + int64_t numel = x.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; PD_DISPATCH_FLOATING_TYPES( x.type(), "relu_cuda_forward_kernel", ([&] { relu_cuda_forward_kernel<<>>( - x.data(), out.mutable_data(x.place()), numel); + x.data(), out.data(), numel); })); return {out}; @@ -359,17 +563,17 @@ std::vector relu_cuda_forward(const paddle::Tensor& x) { std::vector relu_cuda_backward(const paddle::Tensor& x, const paddle::Tensor& out, const paddle::Tensor& grad_out) { - auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, x.shape()); + auto grad_x = paddle::empty_like(x); - int numel = out.size(); - int block = 512; - int grid = (numel + block - 1) / block; + int64_t numel = out.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; PD_DISPATCH_FLOATING_TYPES( out.type(), "relu_cuda_backward_kernel", ([&] { relu_cuda_backward_kernel<<>>( grad_out.data(), out.data(), - grad_x.mutable_data(x.place()), + grad_x.data(), numel); })); @@ -383,7 +587,7 @@ std::vector relu_cuda_backward(const paddle::Tensor& x, #include -#define CHECK_INPUT(x) PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.") +#define CHECK_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") std::vector relu_cuda_forward(const paddle::Tensor& x); std::vector relu_cuda_backward(const paddle::Tensor& x, @@ -409,7 +613,7 @@ std::vector ReluCUDABackward(const paddle::Tensor& x, 在 `.cu` 文件中实现对应的CUDA kernel和计算函数,在 `.cc` 文件中声明调用即可。 -注意这里的 `CHECK_INPUT` 也改为检查输入 `Tensor` 是否在GPU上,如果后续仍然在CPU上执行,将会报错如下,可以看到报错提示与 `CHECK_INPUT` 缩写提示一致。至于错误类型,`PaddlePaddle` 将外部扩展自定义算子视为第三方模块,错误类型统一为 `OSError: (External)` ,与其他第三方库报错类型一致。 +注意这里的 `CHECK_INPUT` 也改为检查输入 `Tensor` 是否在GPU上,如果后续仍然在CPU上执行,将会报错如下,可以看到报错提示与 `CHECK_INPUT` 缩写提示一致。至于错误类型,`PaddlePaddle` 将外部扩展自定义算子视为第三方模块,错误类型统一为 `OSError: (External)` ,与其他第三方库报错类型一致。报错示例如下: ``` Traceback (most recent call last): @@ -440,13 +644,13 @@ OSError: (External) x must be a GPU Tensor. #include -#define CHECK_CPU_INPUT(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.") +#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.") template void relu_cpu_forward_kernel(const data_t* x_data, data_t* out_data, int64_t x_numel) { - for (int i = 0; i < x_numel; ++i) { + for (int64_t i = 0; i < x_numel; ++i) { out_data[i] = std::max(static_cast(0.), x_data[i]); } } @@ -456,7 +660,7 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data, const data_t* out_data, data_t* grad_x_data, int64_t out_numel) { - for (int i = 0; i < out_numel; ++i) { + for (int64_t i = 0; i < out_numel; ++i) { grad_x_data[i] = grad_out_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.); } @@ -465,12 +669,12 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data, std::vector relu_cpu_forward(const paddle::Tensor& x) { CHECK_CPU_INPUT(x); - auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); + auto out = paddle::empty_like(x); PD_DISPATCH_FLOATING_TYPES( x.type(), "relu_cpu_forward_kernel", ([&] { relu_cpu_forward_kernel( - x.data(), out.mutable_data(x.place()), x.size()); + x.data(), out.data(), x.numel()); })); return {out}; @@ -483,14 +687,14 @@ std::vector relu_cpu_backward(const paddle::Tensor& x, CHECK_CPU_INPUT(out); CHECK_CPU_INPUT(grad_out); - auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); + auto grad_x = paddle::empty_like(x); PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward_kernel", ([&] { relu_cpu_backward_kernel( grad_out.data(), out.data(), - grad_x.mutable_data(x.place()), - out.size()); + grad_x.data(), + out.numel()); })); return {grad_x}; @@ -508,10 +712,10 @@ std::vector relu_cuda_backward(const paddle::Tensor& x, #endif std::vector ReluForward(const paddle::Tensor& x) { - if (x.place() == paddle::PlaceType::kCPU) { + if (x.is_cpu()) { return relu_cpu_forward(x); #ifdef PADDLE_WITH_CUDA - } else if (x.place() == paddle::PlaceType::kGPU) { + } else if (x.is_gpu()) { return relu_cuda_forward(x); #endif } else { @@ -522,10 +726,10 @@ std::vector ReluForward(const paddle::Tensor& x) { std::vector ReluBackward(const paddle::Tensor& x, const paddle::Tensor& out, const paddle::Tensor& grad_out) { - if (x.place() == paddle::PlaceType::kCPU) { + if (x.is_cpu()) { return relu_cpu_backward(x, out, grad_out); #ifdef PADDLE_WITH_CUDA - } else if (x.place() == paddle::PlaceType::kGPU) { + } else if (x.is_gpu()) { return relu_cuda_backward(x, out, grad_out); #endif } else { @@ -538,14 +742,14 @@ std::vector ReluBackward(const paddle::Tensor& x, ```c++ #include "paddle/extension.h" -#define CHECK_CUDA_INPUT(x) PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.") +#define CHECK_CUDA_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") template __global__ void relu_cuda_forward_kernel(const data_t* x, data_t* y, - int num) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; - for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { y[i] = max(x[i], static_cast(0.)); } } @@ -554,9 +758,9 @@ template __global__ void relu_cuda_backward_kernel(const data_t* dy, const data_t* y, data_t* dx, - int num) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; - for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.); } } @@ -564,15 +768,15 @@ __global__ void relu_cuda_backward_kernel(const data_t* dy, std::vector relu_cuda_forward(const paddle::Tensor& x) { CHECK_CUDA_INPUT(x); - auto out = paddle::Tensor(paddle::PlaceType::kGPU, x.shape()); + auto out = paddle::empty_like(x); - int numel = x.size(); - int block = 512; - int grid = (numel + block - 1) / block; + int64_t numel = x.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; PD_DISPATCH_FLOATING_TYPES( x.type(), "relu_cuda_forward_kernel", ([&] { relu_cuda_forward_kernel<<>>( - x.data(), out.mutable_data(x.place()), numel); + x.data(), out.data(), numel); })); return {out}; @@ -585,17 +789,17 @@ std::vector relu_cuda_backward(const paddle::Tensor& x, CHECK_CUDA_INPUT(out); CHECK_CUDA_INPUT(grad_out); - auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, x.shape()); + auto grad_x = paddle::empty_like(x); - int numel = out.size(); - int block = 512; - int grid = (numel + block - 1) / block; + int64_t numel = out.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; PD_DISPATCH_FLOATING_TYPES( out.type(), "relu_cuda_backward_kernel", ([&] { relu_cuda_backward_kernel<<>>( grad_out.data(), out.data(), - grad_x.mutable_data(x.place()), + grad_x.data(), numel); })); @@ -673,7 +877,9 @@ std::vector ConcatInferDtypeStaticAxis( - `PD_BUILD_OP` :用于构建前向算子 - `PD_BUILD_GRAD_OP` :用于构建前向算子对应的反向算子 -- `PD_BUILD_DOUBLE_GRAD_OP` :用于构建前反向算子对应的二次求导算子 +- `PD_BUILD_DOUBLE_GRAD_OP` :用于构建前反向算子对应的二阶反向算子 + +> 注:二阶以上的反向算子构建暂不支持。 对于 `relu` CPU示例来说,构建算子描述如下: @@ -695,7 +901,7 @@ PD_BUILD_GRAD_OP(custom_relu) 这里写法上需要注意以下几点: - `PD_BUILD_OP` 系列宏后面的括号内为算子名,也是后面在python端使用的接口名,注意前后不需要引号,注意该算子名不能与 `PaddlePaddle` 内已有算子名重名,比如 `relu` 为 `PaddlePaddle` 内已有算子,如果直接使用relu作为算子名将无法注册成功,所以此处增加了前缀 `custom_` -- `PD_BUILD_OP`、 `PD_BUILD_GRAD_OP` 和 `PD_BUILD_DOUBLE_GRAD_OP` 构建同一个算子的前向、反向、二次反向实现,宏后面使用的算子名需要保持一致,比如该示例中均使用 `custom_relu` +- `PD_BUILD_OP`、 `PD_BUILD_GRAD_OP` 和 `PD_BUILD_DOUBLE_GRAD_OP` 构建同一个算子的前向、反向、二阶反向实现,宏后面使用的算子名需要保持一致,比如该示例中均使用 `custom_relu` - `PD_BUILD_OP`、 `PD_BUILD_GRAD_OP` 和 `PD_BUILD_DOUBLE_GRAD_OP` 必须顺次调用,不允许在未调用 `PD_BUILD_OP` 构建前向算子的情况下,直接调用 `PD_BUILD_GRAD_OP` 构建反向算子 - Inputs与Outputs的输入参数为 `std::vector` ,依次是前面算子运算函数的输入输出 `Tensor` 的name,需要按顺序一一对应,此处的name与函数输入参数的变量名没有强关联,比如函数输入参数是 `const paddle::Tensor& x` ,Inputs中的name可以是 `Input, x, X, In` 等等 - `PD_BUILD_OP` 与 `PD_BUILD_GRAD_OP` 中的Inputs与Outputs的name有强关联,对于前向算子的某个输入,如果反向算子仍然要复用,那么其name一定要保持一致,因为内部执行时,会以name作为key去查找对应的变量,比如这里前向算子的 `X, Out` 与反向算子的 `X, Out` 指代同一个 `Tensor` @@ -1164,7 +1370,7 @@ for epoch_id in range(EPOCH_NUM): if batch_id % 300 == 0: print("Epoch {} batch {}: loss = {}".format( - epoch_id, batch_id, np.mean(loss.numpy()))) + epoch_id, batch_id, paddle.mean(loss).numpy())) opt.step() opt.clear_grad() @@ -1364,59 +1570,7 @@ int main() { 编写 `CMakeList` 编译构建文件,示例如下: -由于目前自定义算子仍然依赖于 `boost` 库,所以需要编写 `boost` 的编译文件,在当前目录下创建文件夹 `cmake/external` ,在其中创建文件 `boost.cmake` ,文件内容如下: - -- cmake/external/boost.cmake -```cmake -include(ExternalProject) - -set(BOOST_PROJECT "extern_boost") -# To release PaddlePaddle as a pip package, we have to follow the -# manylinux1 standard, which features as old Linux kernels and -# compilers as possible and recommends CentOS 5. Indeed, the earliest -# CentOS version that works with NVIDIA CUDA is CentOS 6. And a new -# version of boost, say, 1.66.0, doesn't build on CentOS 6. We -# checked that the devtools package of CentOS 6 installs boost 1.41.0. -# So we use 1.41.0 here. -set(BOOST_VER "1.41.0") -set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE) -set(BOOST_URL "http://paddlepaddledeps.bj.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE) - -MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}") - -set(BOOST_SOURCES_DIR ${THIRD_PARTY_PATH}/boost) -set(BOOST_DOWNLOAD_DIR "${BOOST_SOURCES_DIR}/src/${BOOST_PROJECT}") - -set(BOOST_INCLUDE_DIR "${BOOST_DOWNLOAD_DIR}" CACHE PATH "boost include directory." FORCE) -set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1) -include_directories(${BOOST_INCLUDE_DIR}) - -ExternalProject_Add( - ${BOOST_PROJECT} - ${EXTERNAL_PROJECT_LOG_ARGS} - DOWNLOAD_DIR ${BOOST_DOWNLOAD_DIR} - URL ${BOOST_URL} - DOWNLOAD_NO_PROGRESS 1 - PREFIX ${BOOST_SOURCES_DIR} - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - INSTALL_COMMAND "" - UPDATE_COMMAND "" - ) - -if (${CMAKE_VERSION} VERSION_LESS "3.3.0" OR NOT WIN32) - set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/boost_dummy.c) - file(WRITE ${dummyfile} "const char *dummy = \"${dummyfile}\";") - add_library(boost STATIC ${dummyfile}) -else() - add_library(boost INTERFACE) -endif() - -add_dependencies(boost ${BOOST_PROJECT}) -set(Boost_INCLUDE_DIR ${BOOST_INCLUDE_DIR}) -``` - -然后在当前目录创建文件 `CMakeLists.txt` ,其内容为: +在当前目录创建文件 `CMakeLists.txt` ,其内容为: - CMakeLists.txt ```cmake @@ -1428,7 +1582,6 @@ option(USE_TENSORRT "Compile demo with TensorRT." ON) option(CUSTOM_OPERATOR_FILES "List of file names for custom operators" "") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") -include(external/boost) if(WITH_GPU) find_package(CUDA REQUIRED) @@ -1667,8 +1820,8 @@ make -j ```sh # 根据预编译库中的version.txt信息判断是否将以下三个标记打开 -WITH_MKL=ON -WITH_GPU=ON +WITH_MKL=ON +WITH_GPU=ON USE_TENSORRT=OFF # 配置预测库的根目录 diff --git a/docs/guides/new_op/new_python_op_cn.md b/docs/guides/custom_op/new_python_op_cn.md similarity index 100% rename from docs/guides/new_op/new_python_op_cn.md rename to docs/guides/custom_op/new_python_op_cn.md diff --git a/docs/guides/index_cn.rst b/docs/guides/index_cn.rst index c20007ced31..0ef0a570059 100644 --- a/docs/guides/index_cn.rst +++ b/docs/guides/index_cn.rst @@ -16,7 +16,7 @@ - `性能调优 <./performance_improving/index_cn.html>`_ - `模型迁移 <./model_convert/index_cn.html>`_ - `硬件支持 <./hardware_support/index_cn.html>`_ -- `自定义算子 <./new_op/index_cn.html>`_ +- `自定义算子 <./custom_op/index_cn.html>`_ - `环境变量 <./flags/flags_cn.html>`_ .. toctree:: @@ -30,5 +30,5 @@ performance_improving/index_cn.rst model_convert/index_cn.rst hardware_support/index_cn.rst - new_op/index_cn.rst + custom_op/index_cn.rst flags/flags_cn.rst