diff --git a/docs/dev_guides/api_contributing_guides/api_contributing_guides_cn.rst b/docs/dev_guides/api_contributing_guides/api_contributing_guides_cn.rst index 7bd2a84a8b2..c027d16e4d2 100644 --- a/docs/dev_guides/api_contributing_guides/api_contributing_guides_cn.rst +++ b/docs/dev_guides/api_contributing_guides/api_contributing_guides_cn.rst @@ -105,6 +105,5 @@ API设计文档的目的是为了社区开发者更容易的参与开源项目 api_design_guidelines_standard_cn.md new_python_api_cn.md new_cpp_op_cn.md - new_cpp_op_notes_cn.md api_docs_guidelines_cn.md api_accpetance_criteria_cn.md diff --git a/docs/dev_guides/api_contributing_guides/code_gen_by_yaml.png b/docs/dev_guides/api_contributing_guides/code_gen_by_yaml.png new file mode 100644 index 00000000000..32dd779608f Binary files /dev/null and b/docs/dev_guides/api_contributing_guides/code_gen_by_yaml.png differ diff --git a/docs/dev_guides/api_contributing_guides/new_cpp_op_cn.md b/docs/dev_guides/api_contributing_guides/new_cpp_op_cn.md index 5215f6ed5ee..f2357428237 100644 --- a/docs/dev_guides/api_contributing_guides/new_cpp_op_cn.md +++ b/docs/dev_guides/api_contributing_guides/new_cpp_op_cn.md @@ -1,10 +1,10 @@ -# C++ OP 开发 +# C++ 算子开发指南 -> 注:飞桨原生算子的开发范式正在进行重构与升级,升级后算子开发方式会大幅简化,我们会及时更新本文档内容,升级后的算子开发范式预计会在2.3版本正式上线。 +> 注:飞桨算子的开发范式正处在重构升级后的上线初期,如果在开发过程中遇到问题欢迎通过[Issue](https://github.com/PaddlePaddle/Paddle/issues)向我们反馈。 ## 1. 概念简介 -本教程对新增原生算子的方法进行介绍,首先新增一个算子大概需要以下几个步骤: +本教程对新增算子的方法进行介绍,首先新增一个算子大概需要以下几个步骤: 1. 新增算子描述及定义:描述前反向算子的输入、输出、属性,实现InferMeta函数 2. 新增算子Kernel:实现算子在各种设备上的计算逻辑 @@ -23,7 +23,7 @@ 算子描述及定义 -paddle/fluid/operators/xxx_op.cc +python/paddle/utils/code_gen/api.yaml & python/paddle/utils/code_gen/backward.yaml 算子InferMeta @@ -44,149 +44,192 @@ -关于Python API所处位置,可以参考 [飞桨官方 API 文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/index_cn.html) ,了解各个目录存放API的性质,从而决定具体的放置目录。 - -接下来,我们以Trace操作,计算输入 Tensor 在指定平面上的对角线元素之和,并输出相应的计算结果,即 [TraceOp](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/trace_op.cc) 为例来介绍如何新增算子。 +接下来,我们以Trace操作,计算输入 Tensor 在指定平面上的对角线元素之和,并输出相应的计算结果,即 [trace](../../api/paddle/trace_cn.html#trace) 为例来介绍如何新增算子。 ## 2. 新增算子描述及定义 -算子描述及定义是定义运算的基本属性,本身是设备无关的。 - -首先简单介绍新增算子(以下简称Op)描述需要用到的基类。 - -- `framework::OpProtoAndCheckerMaker`:描述该Op的输入、输出、属性、注释。 -- `framework::OperatorBase`: Operator(简写,Op)基类。 -- `framework::OperatorWithKernel`:继承自OperatorBase,Op有计算函数,称作有Kernel。 - -根据是否包含Kernel,可以将Op分为两种:包含Kernel的Op和不包含kernel的Op: - -- 包含 Kernel 的 Op 继承自 `OperatorWithKernel`:这类Op的功能实现与输入的数据类型、数据布局、数据所在的设备以及Op实现所调用第三方库等有关。比如ConvOp,如果使用CPU计算,一般通过调用mkl库中的矩阵乘操作实现,如果使用GPU计算,一般通过调用cublas库中的矩阵乘操作实现,或者直接调用cudnn库中的卷积操作。 -- 不包含 Kernel 的 Op 继承自 `OperatorBase`:因为这类Op的功能实现与设备以及输入的数据不相关。比如WhileOp、IfElseOp等。 - -> 注:本教程仅介绍如何实现带有计算Kernel的算子,不带Kernel的算子主要用于特殊场景,一般没有需求。 - -### 2.1 定义OpProtoMaker类 - -Trace运算由一个输入,三个属性与一个输出组成。 - -首先定义`ProtoMaker`来描述该Op的输入、输出、属性并添加注释: - -```cpp -class TraceOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Input", - "(Tensor) The input tensor, from which the diagonals are taken."); - AddOutput("Out", "(Tensor) the sum along diagonals of the input tensor"); - AddAttr( - "offset", - R"DOC((int, default 0), offset of the diagonal from the main diagonal. Can be both positive and negative. Defaults to 0. - )DOC") - .SetDefault(0); - AddAttr( - "axis1", - R"DOC((int, default 0), the first axis of the 2-D planes from which the diagonals should be taken. - Can be either positive or negative. Default: 0. - )DOC") - .SetDefault(0); - AddAttr( - "axis2", - R"DOC((int, default 1), the second axis of the 2-D planes from which the diagonals should be taken. - Can be either positive or negative. Default: 1. - )DOC") - .SetDefault(1); - AddComment(R"DOC( -Trace Operator. -Return the sum along diagonals of the input tensor. -The behavior of this operator is similar to how `numpy.trace` works. - -If Input is 2-D, returns the sum of diagonal. -If Input has larger dimensions, then returns an tensor of diagonals sum, diagonals be taken from -the 2-D planes specified by dim1 and dim2. - -)DOC"); - } -}; +算子描述及定义是定义运算的基本属性,主要包括算子的输入、输出以及各项非计算逻辑的配置,这些都是设备无关的。 + +### 2.1 算子Yaml文件配置 +我们在`python/paddle/utils/code_gen/api.yaml`和`python/paddle/utils/code_gen/backward.yaml`文件中对算子进行描述及定义,在框架编译时会根据Yaml文件中的配置自动生成C++端的相关代码接口以及内部实现(详见[Paddle基于Yaml配置的算子代码自动生成](new_cpp_op_cn.md#paddleyaml)),下面主要以Trace为例介绍算子的Yaml配置规则: + +python/paddle/utils/code_gen/api.yaml: +```yaml +- api : trace + args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1) + output : Tensor(out) + infer_meta : + func : TraceInferMeta + kernel : + func : trace + backward : trace_grad ``` - -[`TraceOpMaker`](https://github.com/PaddlePaddle/Paddle/blob/befa78ea3fa9d0dae096a7de91f626b0c31daee8/paddle/fluid/operators/trace_op.cc#L29)继承自`framework::OpProtoAndCheckerMaker`。 - -开发者通过覆盖`framework::OpProtoAndCheckerMaker`中的`Make`函数来定义Op所对应的Proto,通过`AddInput`添加输入参数,通过`AddOutput`添加输出参数,通过`AddAttr`添加属性参数,通过`AddComment`添加Op的注释。这些函数会将对应内容添加到`OpProto`中。 - -上面的代码在`TraceOp`中添加两个输入`X`和`Y`,添加了一个输出`Out`,并简要解释了各自含义,命名请遵守[命名规范](https://github.com/PaddlePaddle/FluidDoc/blob/release/1.2/doc/fluid/dev/name_convention.md)。 - -> 注意:OpProtoMaker中不允许定义未使用的输入、输出或属性。 - -### 2.2 定义GradOpMaker类 - -通常情况下,大部分Op只有一个对应的反向Op,每个Op都会有一个对应的`GradOpMaker`。为方便代码编写,paddle为只有一个反向的Op提供了一个模板类[`SingleGradOpMaker`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/grad_op_desc_maker.h#L188)。`TraceOp`的`GradOpMaker`需要继承这个模板类,并在`Apply()`方法中设置反向Op的输入、输出和属性。此外,paddle还提供了一个默认的`GradOpMaker`, -[`DefaultGradOpMaker`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/grad_op_desc_maker.h#L227),该模板类会使用前向Op的全部输入(`Input`)输出(`Output`)以及输出变量所对应的梯度(`Output@Grad`)作为反向Op的输入,将前向Op的输入变量所对应的的梯度(`Input@Grad`)作为输出。 - -**注意:** - -不要将反向Op不会用到的变量放到反向Op的输入列表中,这样会导致这些不会被反向Op用到的变量的空间不能够及时回收,进而有可能导致用到该Op的模型可以设置的batch_size较低。 -比如`relu`操作的前向操作为:`out.device(d) = x.cwiseMax(static_cast(0));`反向操作为:`dx.device(d) = dout * (out > static_cast(0)).template cast();`。显然,反向操作中只是用到了`out`、`dout`、`dx`,没有用到`x`。因此,通常不建议使用默认的`DefaultGradOpMaker`。 - -下面示例定义了`TraceOp`的`GradOpMaker`。 - -```cpp -template -class TraceGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("trace_grad"); - grad_op->SetInput("Input", this->Input("Input")); - grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - grad_op->SetOutput(framework::GradVarName("Input"), - this->InputGrad("Input")); - grad_op->SetAttrMap(this->Attrs()); - } -}; +python/paddle/utils/code_gen/backward.yaml: +```yaml +- backward_api : trace_grad + forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : trace_grad + data_type : x + no_need_buffer : x ``` -**注意:** - -- 有些Op的前向逻辑和反向逻辑是一样的,比如[`ScaleOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/scale_op.cc).这种情况下,前向Op和反向Op的Kernel可以为同一个。 -- 有些前向Op所对应的反向Op可能有多个,比如[`SumOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/sum_op.cc),这种情况下,`GradMaker`需要继承`framework::GradOpDescMakerBase`。 -- 有些Op的反向对应另一个Op的前向,比如[`SplitOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/split_op.h),这种情况下,[`SplitGradMaker`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/split_op.h#L157)中定义的`SplitOp`反向Op的Type就是`concat`, -- 为高效地同时支持命令式编程模式(动态图)和声明式编程模式(静态图),`SingleGradOpMaker`是一个模板类,在注册Operator时需要同时注册`TraceOpGradMaker`(静态图使用)和`TraceOpGradMaker`(动态图使用)。 - -### 2.3 定义Op类 - -下面实现了TraceOp的定义: - -```cpp -class TraceOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; -``` - -[`TraceOp`](https://github.com/PaddlePaddle/Paddle/blob/bd4dc3be34584f9b273ecec07297fb05e1cf4c52/paddle/fluid/operators/trace_op.cc#L24)继承自`OperatorWithKernel`。`public`成员: - -```cpp -using framework::OperatorWithKernel::OperatorWithKernel; -``` - -这句表示使用基类`OperatorWithKernel`的构造函数,也可写成: - -```cpp -TraceOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorWithKernel(type, inputs, outputs, attrs) {} -``` - -此外,Operator类需要在有必要时重写`GetExpectedKernelType`接口。 +`api.yaml`和`backward.yaml`分别对算子的前向和反向进行配置,首先`api.yaml`中前向算子的配置规则如下: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
配置项配置内容及规则
api算子名称,与该算子Python API函数名相同(命名方式为:全小写+下划线),示例中为trace
args算子输入参数,与该算子Python API函数的输入参数对应(当前支持的输入数据类型包括:Tensor, Tensor[], float, double, bool, int, int64_t, int[], int64_t[], str, Place, DataType, DataLayout, IntArray, Scalar)。我们一般称这里Tensor类型的参数为Input(输入),非Tensor类型的参数为Attribute(属性)
+注:Tensor[]表示Tensor数组;IntArray为int类型数组,主要用于表示shape,index和axes等类型数据,可以直接使用Tensor或者普通整型数组构造,目前仍在测试阶段,如非必要暂不建议使用;Scalar表示标量,可以支持不同的普通数据类型 +
output算子输出类型(目前支持Tensor和Tensor[]类型),多个输出间用逗号“,”分隔开。可以使用”()”选择性标记输入的名字,如未标记默认为'out'
+注:当返回类型为Tensor[]时,由于数组的size要在kernel执行前推导完成,所以需要在Tensor[]后的'{}'内通过表达式指定返回数组的size,如:Tensor[](out){input.size()} +
infer_metaInferMeta函数负责根据输入变量推断返回Tensor的维度与类型,这里是对算子使用的InferMeta函数进行配置
infer_meta:func调用的InferMeta函数,这里trace调用的是TraceInferMeta函数
infer_meta:paramInferMeta函数的输入参数,可以对args中的参数进行选择传入,未配置则默认传入args中的所有参数。示例中未配置本项,所以传入的参数为[x, offset, axis1, axis2]。output项中的参数作为输出无需配置会自动传入InferMeta函数中
kernel算子的计算Kernel配置
kernel:func算子对应kernel函数的注册名
kernel:paramkernel函数的输入参数,配置规则与InferMeta函数的param配置项相同
kernel:data_type根据指定参数推导调用kernel的data_type(对应kernel函数的模板参数'T'),默认不进行配置,会根据输入Tensor自动进行推导。如果kernel的data_type类型由某个输入参数(Tensor或者DataType参数),需要将该参数的变量名填入该项。示例中未配置则kernel的data_type由输入变量'x'决定
kernel:backend根据指定参数来选择调用kernel的Backend(Kernel执行的具体设备,如CPU、GPU等),默认不进行配置,会根据输入Tensor自动进行推导。如果kernel执行的backend类型由某个输入参数(Tensor或者Backend参数)决定,需要将该参数的变量名填入该项。示例中未配置则kernel执行的Backend与输入变量'x'的Backend相同
backward算子对应的反向算子名称,如果没有反向则不需要配置,示例中trace算子的反向为trace_grad
特殊配置项(目前特殊配置项还处于不稳定阶段,后续可能会有调整更新)
optional指定输入Tensor为可选输入,用法可参考dropout中seed_tensor(python/paddle/utils/code_gen/legacy_api.yaml中)
inplace算子对指定的输入做原位处理并作为输出结果返回,使用格式:(x -> out),具体用法可参考relu算子
+特殊规则:如果api中算子名称有'_'后缀则只生成支持inplace功能的接口,如果算子名称没有'_'后缀,则会同时生成支持inplace操作的接口(自动添加'_'后缀)和不支持inplace的普通接口共两套接口 +
view与inplace机制类似,区别在于view模式返回的结果只是与输入共享内存,并不是输入Tensor变量本身,使用格式:(x -> out),具体用法可参考reshape算子
intermediate标记前向计算中输出的用于反向计算的中间变量,不会出现在Python API的返回结果中,相关设计正在完善中,新增算子时不建议使用
invoke复用已有的算子接口或实现自定义的C++ API,配置时以函数调用的形式配置即可,使用invoke时则不需要配置infer_meta和kernel。
+a. 如果是复用已有算子,需要被复用的算子为前向算子且两者的返回值类型相同,可参考zeros_like算子
+b. 如果是实现自定义的C++ API,需要在'paddle/phi/api/lib/api_custom_impl.h'声明自定义实现函数并在'paddle/phi/api/lib/api_custom_impl.cc'中进行实现,具体可参考embedding算子
-`GetExpectedKernelType`接口OperatorWithKernel类中用于获取指定设备(例如CPU,GPU)上指定数据类型(例如double,float)的OpKernel的方法。该方法的重写可见请参考 [原生算子开发注意事项](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/07_new_op/op_notes_cn.html#getexpectedkerneltype) 第4点 GetExpectedKernelType方法重写。 -通常`OpProtoMaker`和`Op`类的定义写在`.cc`文件中,和下面将要介绍的注册函数一起放在`.cc`中 +`backward.yaml`中反向算子的配置规则如下: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
配置项配置内容及规则
backward_api反向算子名称,一般命名方式为:前向算子名称+'_grad',二阶算子则为前向算子名称+'_double_grad'
forward对应前向算子的名称、参数、返回值,需要与api.yaml中前向算子配置一致
args反向算子输入参数, 示例中'x'表示将前向的'x'变量输入到反向,'out_grad'表示前向输出'out'对应的反向梯度
+约束条件1:所有参数需要在forward配置项的参数中(输入、输出以及输出对应的反向梯度)找到对应(根据变量名匹配)
+约束条件2:反向输入参数需要以:a.前向输入Tensor b.前向输出Tensor c.前向输出Tensor的反向梯度 d.前向非Tensor类型属性变量(Attribute) 的顺序排列,反向计算中不需要使用的前向变量无须添加
+
output反向算子输出,顺序需要与前向输入Tensor一致,比如前向输入(Tensor x, Tensor y),则反向输出必须为Tensor(x_grad), Tensor(y_grad)
infer_meta与前向配置规则相同
kernel与前向配置规则相同
backward反向算子对应的更高阶反向算子名称,如一阶反向算子的反向为二阶反向算子
特殊配置项(目前特殊配置项还处于不稳定阶段,后续可能会有调整更新)
no_need_buffer可选配置,标记的Tensor变量在前向运行完成后,持有的内存或显存会被释放,以减少训练过程中的内存使用。trace_grad由于反向算子只需要前向变量'x'的维度信息,不需要内存数据,所以可以标记为no_need_buffer提前释放内存
+注意:由于Tensor内存被释放后会影响dtype接口的使用,所以需要在kernel的data_type配置项中指定其他的Tensor来推导kernel的data_type
optional与前向配置规则相同
inplace与前向配置规则相同
-### 2.4 实现InferMeta函数 +### 2.2 实现InferMeta函数 `InferMeta`函数是根据输入参数,推断算子输出Tensor基本信息的函数,推断的信息包括输出Tensor的`shape`、`data type`及`data layout`,同时它也承担了检查输入数据维度、类型等是否合法的功能。 @@ -409,25 +452,6 @@ void ConcatInferMeta(const std::vector& x, } ``` -### 2.5 注册Op - -在`xxx_op.cc`文件中声明InferShapeFunctor,并注册前向、反向Op。 - -```cpp -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(trace, TraceInferShapeFunctor, - PD_INFER_META(phi::TraceInferMeta)); -REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker, - ops::TraceGradOpMaker, - ops::TraceGradOpMaker, - TraceInferShapeFunctor); -REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad, - ops::TraceGradNoNeedBufferVarsInferer); -``` - -在上面的代码中,首先使用`DECLARE_INFER_SHAPE_FUNCTOR`声明InferShapeFunctor,然后使用`REGISTER_OPERATOR`注册了`ops::TraceOp`类,算子名为`trace`,该类的`ProtoMaker`为`ops::TraceOpMaker`,其`GradOpMaker`分别是`ops::TraceOpGradMaker`(静态图模式使用)和`ops::TraceOpGradMaker`(动态图模式使用),同时将前面声明的TraceInferShapeFunctor一并放入注册列表。 -前向算子注册完成后,再使用`REGISTER_OPERATOR`注册`ops::TraceGradOp`,类型名为`trace_grad`。 - ## 3. 新增算子Kernel 新增算子Kernel在 `paddle/phi/kernels` 目录中完成 @@ -514,11 +538,41 @@ void TraceKernel(const Context& dev_ctx, > **特殊情况说明:** > 1. **特殊模板参数**:对于某些Kernel (如reshape ,copy),这些kernel不关注数据类型T, 可以省去第一个模板参数,即为:`template ` -> 2. **特殊输入类型**:对于某些特殊Kernel (如concat 和split kernel)的部分输入或输出是数组类型的DenseTensor(OpMaker中有`AsDuplicable`标记), 此时输入类型为:`const std::vector&`; 输出类型为:`std::vector` +> 2. **特殊输入类型**:对于某些特殊Kernel (如concat 和split kernel)的部分输入或输出是数组类型的DenseTensor, 此时输入类型为:`const std::vector&`; 输出类型为:`std::vector` #### 3.2.2 实现 Kernel 函数 -此处trace op的kernel属于前述第2中情况,即CPU与GPU Kernel需要分别实现。 +**复用已有Kernel实现设备无关Kernel函数** + +由于目前的Kernel复用机制为新推出的功能,暂未对已有算子进行升级改造,所以这里我们以一个不在框架中的linear算子(out = x * w + b)为例来介绍复用已有Kernel实现设备无关Kernel函数。(linear kernel 的实现源码需要放置在`paddle/phi/kernels/linear_kernel.cc`) + +`LinearKernel` 的实现代码如下: + +```cpp +#include ... +#include "paddle/phi/kernels/elementwise_add_kernel.h" +#include "paddle/phi/kernels/elementwise_multiply_kernel.h" + +template +void LinearKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& w, + const DenseTensor& b, + DenseTensor* out) { + dev_ctx.template Alloc(out); // 为 out 分配内存 + MultiplyKernel(dev_ctx, x, w, out); // 复用 MultiplyKernel + AddKernel(dev_ctx, out, b, out); // 复用 AddKernel +} +``` +复用Kernel的流程包括: +1. 在源文件中 include 要复用 Kernel 的头文件 +2. 直接调用相应的Kernel函数进行复用 + +注意:设备无关Kernel实现时计算逻辑部分只能复用现有Kernel或设备无关的Functor,不能使用设备相关的语法或者函数接口(如cuda、cudnn等)进行计算处理 + +**实现设备相关Kernel函数** + +此处 trace 算子的kernel属于前述第2中情况,即与设备相关,CPU和GPU Kernel需要分别实现。 - cpu kernel实现位于:[paddle/phi/kernels/cpu/trace_kernel.cc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/cpu/trace_kernel.cc) - gpu kernel实现位于:[paddle/phi/kernels/gpu/trace_kernel.cu](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/gpu/trace_kernel.cu) @@ -549,9 +603,7 @@ void TraceKernel(const Context& dev_ctx, } ``` -**Kernel复用:** - -此处TraceKernel的实现并未复用其他Kernel,但如果有需要也是可以复用的,Kernel复用时,直接 include 相应Kernel头文件,在函数中调用即可,例如 [triangular_solve_kernel](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/cpu/triangular_solve_kernel.cc) 复用 empty和expand kernel。 +此处TraceKernel的实现并未复用其他Kernel,但如果有需要也是可以复用的,Kernel复用时,同样是直接 include 相应Kernel头文件,在函数中调用即可,例如 [triangular_solve_kernel](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/cpu/triangular_solve_kernel.cc) 复用 empty和expand kernel。 首先在triangular_solve_kernel.cc头部include相应头文件: @@ -564,12 +616,14 @@ void TraceKernel(const Context& dev_ctx, ```cpp // Tensor broadcast to 'out' and temp 'x_bst' - ScalarArray x_bst_dims(x_bst_dims_vec); + IntArray x_bst_dims(x_bst_dims_vec); DenseTensor x_bst = phi::Empty(dev_ctx, x_bst_dims); const T* x_bst_data = x_bst.data(); ExpandKernel(dev_ctx, x, x_bst_dims, &x_bst); ``` +补充:对于Kernel内部临时使用的`DenseTensor`目前推荐使用`Empty`、`EmptyLike`、`Full`和`FullLike`接口进行创建。 + 反向Kernel的实现与前向是类似的,此处不再赘述,可以直接参考前述对应链接中的代码实现。 **公共函数管理:** @@ -583,38 +637,6 @@ void TraceKernel(const Context& dev_ctx, 3. 有跨设备多个kernel使用的辅助函数,在`kernels/funcs`目录下创建`.h/cc/cu`管理代码 4. 如果当前依赖的辅助函数可以直接归类到`kernels/funcs`目录下已有的文件中,则直接放过去,不用创建新的文件 -**反向Kernel参数映射函数添加** - -现阶段,反向Kernel除了实现外,还需要添加一个参数映射函数。 - -仍然以trace op为例,首先在`paddle/phi/ops/compat`目录下新建`trace_sig.cc`文件,用于放置这里的映射函数。 - -- 由于函数式kernel的一个最重要的特别就是参数顺序和类型(顺序和类型是关键,变量名称不影响),我们需要定义一个函数来做一个从OpMaker中如何获取信息,并且按照顺序传递给新的kernel函数; 这个模块就是OpArgumentMapping, trace反向op的OpArgumentMapping定义如下, KernelSignature共包含4个内容 - 1. kernel名称,这个是我们给kernel注册的时候的名称 - 2. input list: 这个要和OpMaker(或者GradOpMaker)中定义的Key要完全一致 - 3. attribute list: 这个要和OpMaker(或者GradOpMaker)中定义的Key要完全一致 - 4. output list: 这个要和OpMaker(或者GradOpMaker)中定义的Key要完全一致 - - - ```cpp - #include "paddle/phi/core/compat/op_utils.h" - - namespace phi { - - KernelSignature TraceGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("trace_grad", - {GradVarName("Out"), "Input"}, - {"offset", "axis1", "axis2"}, - {GradVarName("Input")}); - } - - } // namespace phi - - PD_REGISTER_ARG_MAPPING_FN(trace_grad, phi::TraceGradOpArgumentMapping); - ``` - ->注:没有input list或attribute list的,相应花括号内留空,不能省略花括号 - #### 3.2.3 注册 Kernel 函数 注册kernel的方式比较简单,直接使用注册宏注册即可,示例如下: @@ -748,7 +770,10 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None): __check_input(input, offset, axis1, axis2) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_trace( x, offset, axis1, axis2 ) + + if _in_legacy_dygraph(): return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2) inputs = {'Input': [x]} @@ -767,39 +792,32 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None): return out ``` -> 概念解释:LayerHelper是一个用于创建op输出变量、向program中添加op的辅助工具类 - -- Python API 实现要点 +- Python API 实现要点(详见[飞桨API Python 端开发指南](./new_python_api_cn.html)) - 对输入参数进行合法性检查,即 `__check_input(input, offset, axis1, axis2)` - - 添加动态图分支调用,即 `if paddle.in_dynamic_mode()` 分支 - - 添加静态图分支调用,即dygraph mode分支后剩余的代码 - -- Python API 放置位置 - - 根据 API 自身属性,结合现有目录分类情况,放置导致对应子目录中的相应文件中 - - 可以参考 [飞桨官方 API 文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/index_cn.html) 中对各个子目录 **功能和包含的API** 的介绍 + - 添加动态图分支调用,即 `if in_dygraph_mode` 新动态图分支和 `if _in_legacy_dygraph` 旧动态图分支 + - 添加静态图分支调用,即dygraph分支后剩余的代码 -- Python API 文档 - - 参考示例格式进行添加,内容尽可能准确、翔实,详细规范请参考 [PaddlePaddle 文档](https://github.com/PaddlePaddle/docs/wiki) ## 5. 添加单元测试 -单测包括对比前向Op不同设备(CPU、CUDA)的实现、对比反向OP不同设备(CPU、CUDA)的实现、反向Op的梯度测试。下面介绍介绍[`TraceOp`的单元测试](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/unittests/test_trace_op.py)。 +单测包括对比前向算子不同设备(CPU、CUDA)的实现、对比反向算子不同设备(CPU、CUDA)的实现、反向算子的梯度测试。下面介绍介绍[`TraceOp`的单元测试](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/tests/unittests/test_trace_op.py)。 **注意:** 单测中的测试用例需要尽可能的覆盖Kernel中的所有分支。 -### 5.1 前向 Operator 单测 +### 5.1 前向算子单测 -Op单元测试继承自`OpTest`。各项具体的单元测试在`TestTraceOp`里完成。测试Operator,需要: +算子单元测试继承自`OpTest`。各项具体的单元测试在`TestTraceOp`里完成。测试算子,需要: 1. 在`setUp`函数定义输入、输出,以及相关的属性参数。 2. 生成随机的输入数据。 -3. 在Python脚本中实现与前向operator相同的计算逻辑,得到输出值,与operator前向计算的输出进行对比。 +3. 在Python脚本中实现与前向算子相同的计算逻辑,得到输出值,与算子前向计算的输出进行对比。 4. 反向计算已经自动集成进测试框架,直接调用相应接口即可。 ```python + import paddle import unittest import numpy as np from op_test import OpTest @@ -808,14 +826,15 @@ Op单元测试继承自`OpTest`。各项具体的单元测试在`TestTraceOp`里 class TestTraceOp(OpTest): def setUp(self): self.op_type = "trace" + self.python_api = paddle.trace self.init_config() self.outputs = {'Out': self.target} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['Input'], 'Out') + self.check_grad(['Input'], 'Out', check_eager=True) def init_config(self): self.case = np.random.randn(20, 6).astype('float64') @@ -826,19 +845,21 @@ Op单元测试继承自`OpTest`。各项具体的单元测试在`TestTraceOp`里 上面的代码首先导入依赖的包,下面是对`setUp`函数中操作的重要变量的详细解释: - - `self.op_type = "trace" ` : 定义类型,与operator注册时注册的类型一致。 + - `self.op_type = "trace" ` : 定义类型,与算子定义的名称相同。 + - `self.python_api = paddle.trace` : 定义python api,与python调用接口一致。 - `self.inputs` : 定义输入,类型为`numpy.array`,并初始化。 - - `self.outputs` : 定义输出,并在Python脚本中完成与operator同样的计算逻辑,返回Python端的计算结果。 + - `self.outputs` : 定义输出,并在Python脚本中完成与算子同样的计算逻辑,返回Python端的计算结果。 -### 5.2 反向 operator 单测 +### 5.2 反向算子单测 而反向测试中: - `test_check_grad`中调用`check_grad`使用数值法检测梯度正确性和稳定性。 - 第一个参数`['Input']` : 指定对输入变量`Input`做梯度检测。 - 第二个参数`'Out'` : 指定前向网络最终的输出目标变量`Out`。 + - 第三个参数`check_eager` : `check_eager=True`表示开启新动态图(eager模式)单测,`check_eager`默认为`False`。 -- 对于存在多个输入的反向Op测试,需要指定只计算部分输入梯度的case +- 对于存在多个输入的反向算子测试,需要指定只计算部分输入梯度的case - 例如,`test_elementwise_sub_op.py`中的`test_check_grad_ingore_x`和`test_check_grad_ingore_y`分支用来测试只需要计算一个输入梯度的情况 - 此处第三个参数max_relative_error:指定检测梯度时能容忍的最大错误值。 @@ -852,6 +873,9 @@ Op单元测试继承自`OpTest`。各项具体的单元测试在`TestTraceOp`里 ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) ``` +### 5.3 Python API 单元测试 +Python API也需要编写相关的单测进行测试,详见[添加 Python API 单元测试](new_python_api_cn.html#id2) + 其他有关单元测试添加的注意事项请参考 [《Op开发手册》](https://github.com/PaddlePaddle/Paddle/wiki/Operator-Development-Manual-Index) 及 [《Paddle单元测试规范》](https://github.com/PaddlePaddle/Paddle/wiki/PaddlePaddle-Unit-test-specification)。 @@ -873,15 +897,11 @@ make test ARGS="-R test_trace_op -V" ctest -R test_trace_op -V ``` -**注意事项:** - -- 注册Op时的类型名,需要和该Op的名字一样。即不允许在`A_op.cc`里面,注册`REGISTER_OPERATOR(B, ...)`等,这将会导致单元测试出错。 - -## 6. 其他编码要点 +## 6. 开发算子注意事项 ### 6.1 报错检查 -实现Op时检查数据的合法性需要使用PADDLE_ENFORCE以及PADDLE_ENFORCE_EQ等宏定义,基本格式如下: +实现算子时检查数据的合法性需要使用PADDLE_ENFORCE以及PADDLE_ENFORCE_EQ等宏定义,基本格式如下: ``` PADDLE_ENFORCE(表达式, 错误提示信息) @@ -909,3 +929,152 @@ PADDLE_ENFORCE_EQ(比较对象A, 比较对象B, 错误提示信息) - 例如:`Suggested Fix:If your classifier expects one-hot encoding label,check your n_classes argument to the estimatorand/or the shape of your label.Otherwise, check the shape of your label.` 更详细的报错检查规范介绍请参考 [《Paddle报错信息文案书写规范》](https://github.com/PaddlePaddle/Paddle/wiki/Paddle-Error-Message-Writing-Specification)。 + +### 6.2 算子兼容性问题 +对算子的修改需要考虑兼容性问题,要保证算子修改之后,之前的模型都能够正常加载及运行,即新版本的Paddle预测库能成功加载运行旧版本训练的模型。**所以,需要保证算子当前的所有输入输出参数不能被修改(文档除外)或删除,可以新增参数,但是新增的Tensor类型变量需要设置为optional,非Tensor变量需要设置默认值。更多详细内容请参考[OP修改规范:Input/Output/Attribute只能做兼容修改](https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification)** 。 + +### 6.3 显存优化 + +#### 6.3.1 为可原位计算的算子注册inplace +有些算子的计算逻辑中,输出可以复用输入的显存空间,也可称为原位计算。例如[reshape](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/reshape_kernel.cc)中,输出`out`可以复用输入`x`的显存空间,因为该算子的计算逻辑不会改变`x`的实际数据,只是修改它的shape,输出和输入复用同一块显存空间不影响结果。对于这类算子,可以注册`inlace`,从而让框架在运行时自动地进行显存优化。 + +注册方式为在算子的Yaml配置中添加`inplace`配置项,格式如:`(x -> out)`,详见[Yaml配置规则](new_cpp_op_cn.html#yaml)。示例: + +```yaml +- api : reshape + args : (Tensor x, IntArray shape) + output : Tensor(out) + ... + inplace : (x -> out) +``` + +#### 6.3.2 减少反向算子中的无关变量 +通常反向算子会依赖于前向算子的某些输入、输出Tensor,以供反向算子计算使用。但有些情况下,反向算子不需要前向Op的所有输入和输出;有些情况下,反向算子只需要前向算子的部分输入和输出;有些情况下,反向算子只需要使用前向算子中输入和输出变量的Shape和LoD信息。若开发者在注册反向算子时,将不必要的前向算子输入和输出作为反向算子的输入,会导致这部分显存无法被框架现有的显存优化策略优化,从而导致模型显存占用过高。 + +所以在定义反向算子时需要注意以下几点: + +- 如果反向不需要前向的某些输入或输出参数,则无需在args中设置。 +- 如果有些反向算子需要依赖前向算子的输入或输出变量的的Shape或LoD,但不依赖于变量中Tensor的内存Buffer数据,且不能根据其他变量推断出该Shape和LoD,则可以通过`no_need_buffer`对该变量进行配置,详见[Yaml配置规则](new_cpp_op_cn.html#yaml)。示例: +```yaml +- backward_api : trace_grad + forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2) + output : Tensor(x_grad) + ... + no_need_buffer : x +``` + +### 6.4 性能优化 +#### 6.4.1 第三方库的选择 +在写算子过程中优先使用高性能(如cudnn、mkldnn、mklml、eigen等)中提供的操作,但是一定要做benchmark,有些库中的操作在深度学习任务中可能会比较慢。因为高性能库(如eigen等)中提供的操作为了更为通用,在性能方面可能并不是很好,通常深度学习模型中数据量较小,所以有些情况下可能高性能库中提供的某些操作速度较慢。比如Elementwise系列的所有算子(前向和反向),Elementwise操作在模型中调用的次数比较多,尤其是Elementwise_add,在很多操作之后都需要添加偏置项。在之前的实现中Elementwise_op直接调用Eigen库,由于Elementwise操作在很多情况下需要对数据做Broadcast,而实验发现Eigen库做Broadcast的速度比较慢,慢的原因在这个PR[#6229](https://github.com/PaddlePaddle/Paddle/pull/6229)中有描述。 + +#### 6.4.2 算子性能优化 +算子的计算速度与输入的数据量有关,对于某些算子可以根据输入数据的Shape和算子的属性参数来选择不同的计算方式。比如concat_op,当axis>=1时,在对多个tensor做拼接过程中需要对每个tensor做很多次拷贝,如果是在GPU上,需要调用cudaMemCopy。相对CPU而言,GPU属于外部设备,所以每次调用GPU的操作都会有一定的额外开销,并且当需要拷贝的次数较多时,这种开销就更为凸现。目前concat_op的实现会根据输入数据的Shape以及axis值来选择不同的调用方式,如果输入的tensor较多,且axis不等于0,则将多次拷贝操作转换成一个CUDA Kernel来完成;如果输入tensor较少,且axis等于0,使用直接进行拷贝。相关实验过程在该PR([#8669](https://github.com/PaddlePaddle/Paddle/pull/8669))中有介绍。 + +由于CUDA Kernel的调用有一定的额外开销,所以如果算子中出现多次调用CUDA Kernel,可能会影响算子的执行速度。比如之前的sequence_expand_op中包含很多CUDA Kernel,通常这些CUDA Kernel处理的数据量较小,所以频繁调用这样的Kernel会影响算子的计算速度,这种情况下最好将这些小的CUDA Kernel合并成一个。在优化sequence_expand_op过程(相关PR[#9289](https://github.com/PaddlePaddle/Paddle/pull/9289))中就是采用这种思路,优化后的sequence_expand_op比之前的实现平均快出约1倍左右,相关实验细节在该PR([#9289](https://github.com/PaddlePaddle/Paddle/pull/9289))中有介绍。 + +减少CPU与GPU之间的拷贝和同步操作的次数。比如fetch操作,在每个迭代之后都会对模型参数进行更新并得到一个loss,并且数据从GPU端到没有页锁定的CPU端的拷贝是同步的,所以频繁的fetch多个参数会导致模型训练速度变慢。 + +更多算子性能优化方法,请参考 [算子性能优化 方法介绍](../op_optimization/op_optimization_method_introduction_cn.html)。 + +### 6.5 稀疏梯度参数更新方法 +目前稀疏梯度在做更新的时候会先对梯度做merge,即对相同参数的梯度做累加,然后做参数以及附加参数(如velocity)的更新。 + +### 6.6 混合设备调用 +由于GPU是异步执行的,当CPU调用返回之后,GPU端可能还没有真正的执行,所以如果在算子中创建了GPU运行时需要用到的临时变量,当GPU开始运行的时候,该临时变量可能在CPU端已经被释放,这样可能会导致GPU计算出错。 + +关于GPU中的一些同步和异步操作: +``` +The following device operations are asynchronous with respect to the host: + Kernel launches; + Memory copies within a single device's memory; + Memory copies from host to device of a memory block of 64 KB or less; + Memory copies performed by functions that are suffixed with Async; + Memory set function calls. +``` + +关于cudaMemCpy和cudaMemCpyAsync注意事项: + +- 如果数据传输是从GPU端到非页锁定的CPU端,数据传输将是同步,即使调用的是异步拷贝操作。 +- 如果数据传输是从CPU端到CPU端,数据传输将是同步的,即使调用的是异步拷贝操作。 + +更多内容可参考:[Asynchronous Concurrent Execution](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#asynchronous-concurrent-execution),[API synchronization behavior](https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior) + +### 6.7 算子数值稳定性问题 +有些算子存在数值稳定性问题,出现数值稳定性的主要原因程序在多次运行时,对浮点型数据施加操作的顺序可能不同,进而导致最终计算结果不同。而GPU是通过多线程并行计算的方式来加速计算的,所以很容易出现对浮点数施加操作的顺序不固定现象。 + +目前发现cudnn中的卷积操作、cudnn中的MaxPooling、CUDA中CudaAtomicXX、ParallelExecutor的Reduce模式下参数梯度的聚合等操作运行结果是非确定的。 + +为此Paddle中添加了一些FLAGS,比如使用FLAGS_cudnn_deterministic来强制cudnn使用确定性算法、FLAGS_cpu_deterministic强制CPU端的计算使用确定性方法。 + +### 6.8 算子的数学公式 +如果算子有数学公式,一定要在代码中将数学公式写明,并在Python API的Doc中显示,因为用户在对比不同框架的计算结果时可能需要了解Paddle对算子是怎么实现的。 + +### 6.9 LoD 在算子内部的传导规范 + +[LoD](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/design/concepts/lod_tensor.md) 是 Paddle 框架用来表示变长序列数据的属性,除了仅支持输入是 padding data 的算子外,所有算子的实现都要考虑 LoD 的传导问题。 + +根据算子的计算过程中是否用到 LoD,我们可以将涉及到 LoD 传导问题的算子分为两类: LoD-Transparent 与 LoD-Based。 + + + + + + + + + + + + + + + + + + + + + +
类型特点示例
LoD-Transparent 计算过程不依赖 LoD,输入是否有 LoD 不会影响计算的结果,通常是 position-wise 的计算 conv2d_op、batch_norm_op、dropout_op 等
LoD-Based 计算以序列为单位, 计算过程依赖 LoD lstm_op、gru_op、sequence_ops 等
+ +这两类算子的 LoD 传导需要考虑前向和反向两个过程。 + +#### 前向传导 + +在前向传导过程,与输入的 LoD 相比较,算子输出的 LoD 可能出现不变、改变和消失这三种情况: + + - 不变:适用于所有的 LoD-Transparent 算子与部分的 LoD-Based算子。可以在`InferMeta` 中调用 `ShareLoD()` 直接将输入 Var 的 LoD 共享给输出 Var, 可参考 [lstm_op](https://github.com/PaddlePaddle/Paddle/blob/a88a1faa48a42a8c3737deb0f05da968d200a7d3/paddle/fluid/operators/lstm_op.cc#L92); 如果有多个输入且都可能存在 LoD 的情况,通常默认共享第一个输入, 例如 [elementwise_ops forward](https://github.com/PaddlePaddle/Paddle/blob/5d6a1fcf16bcb48d2e66306b27d9994d9b07433c/paddle/fluid/operators/elementwise/elementwise_op.h#L69); + + - 改变:适用于部分 LoD-Based 算子。在实现 OpKernel 时需考虑输出 LoD 的正确计算,真实的 LoD 在前向计算结束后才能确定,此时仍需要在`InferMeta` 中调用 `ShareLoD()`,以确保CompileTime 时对 LoD Level 做了正确的传导,可参考 [sequence_expand_op](https://github.com/PaddlePaddle/Paddle/blob/565d30950138b9f831caa33904d9016cf53c6c2e/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc); + + - 消失:适用于输出不再是序列数据的 LoD-Based 算子。此时不用再考虑前向的 LoD 传导问题,可参考 [sequence_pool_op](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc); + +其它重要的注意事项: + + - 实现 LoD-Based 算子时,需要处理好 LoD 传导的边界情况,例如对长度为零的输入的支持,并完善相应的单测,单测 case 覆盖空序列出现在 batch 开头、中间和末尾等位置的情况,可参考 [test_lstm_op.py](https://github.com/PaddlePaddle/Paddle/blob/4292bd8687ababc7737cffbddc0d38ead2138c00/python/paddle/fluid/tests/unittests/test_lstm_op.py#L203-L216) + + - 对 LoD Level 有明确要求的算子,推荐的做法是在 `InferMeta` 中即完成 LoD Level的检查,例如 [sequence_pad_op](https://github.com/PaddlePaddle/Paddle/blob/4292bd8687ababc7737cffbddc0d38ead2138c00/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc#L79)。 + + +#### 反向传导 + +通常来讲,算子的某个输入 Var 所对应的梯度 GradVar 的 LoD 应该与 Var 自身相同,所以应直接将 Var 的 LoD 共享给 GradVar,可以参考 [elementwise ops 的 backward](https://github.com/PaddlePaddle/Paddle/blob/a88a1faa48a42a8c3737deb0f05da968d200a7d3/paddle/fluid/operators/elementwise/elementwise_op.h#L189-L196) + + +## 7. 补充 +### 7.1 Paddle基于Yaml配置的算子代码自动生成 +Paddle支持动态图和静态图两种模式,在Yaml配置文件中完成算子基本属性的定义后,需要进行解析并分别生成动态图和静态图所对应的算子代码逻辑,从而将算子接入框架的执行体系。基于Yaml配置的算子代码自动生成示意图: +![code_gen_by_yaml](./code_gen_by_yaml.png) + +- 其中Yaml配置文件为前向:`python/paddle/utils/code_gen/api.yaml` 和反向:`python/paddle/utils/code_gen/backward.yaml`。 +- 动态图中自动生成的代码包括从Python API到计算Kernel间的各层调用接口实现,从底层往上分别为: + - C++ API:一套与Python API参数对齐的C++接口(只做逻辑计算,不支持自动微分),内部封装了底层kernel的选择和调用等逻辑,供上层灵活使用。 + - 注:前向算子生成C++ API头文件和实现代码分别为`paddle/phi/api/include/api.h`和`paddle/phi/api/lib/api.cc`,反向算子生成的头文件和实现代码分别为`paddle/phi/api/backward/backward_api.h`,`paddle/phi/api/lib/backward_api.cc`。 + - 动态图前向函数与反向节点(Autograd API):在C++ API的基础上进行了封装,组成一个提供自动微分功能的C++函数接口。 + - 注:生成的相关代码在`paddle/fluid/eager/api/generated/eager_generated`目录下 + - Python-C 接口:将支持自动微分功能的C++的函数接口(Autograd API)暴露到Python层供Python API调用。 + - 注:生成的Python-C 接口代码在`paddle/fluid/pybind/eager_final_state_op_function_impl.h`中 +- 静态图的执行流程与动态图不同,所以生成的代码也与动态图有较大差异。静态图由于是先组网后计算,Python API主要负责组网,算子的调度和kernel计算由静态图执行器来完成,因此自动生成的代码是将配置文件中的算子信息注册到框架内供执行器调度,主要包括[OpMaker](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/op_proto_maker.h)(静态图中定义算子的输入、输出以及属性等信息)和`REGISTER_OPERATOR`(将算子名称以及OpMaker等信息进行注册)等静态图算子注册组件,具体的代码逻辑可参考`paddle/fluid/operators/generated_op.cc` + +**注意:由于代码自动生成在编译时进行,所以查看上述生成代码需要先完成[框架的编译](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/install/compile/fromsource.html)。** diff --git a/docs/dev_guides/api_contributing_guides/new_cpp_op_notes_cn.md b/docs/dev_guides/api_contributing_guides/new_cpp_op_notes_cn.md deleted file mode 100644 index 98295ba7ec9..00000000000 --- a/docs/dev_guides/api_contributing_guides/new_cpp_op_notes_cn.md +++ /dev/null @@ -1,361 +0,0 @@ -# C++ OP 开发注意事项 - -## Paddle中Op的构建逻辑 -### 1.Paddle中Op的构建逻辑 -Paddle中所有的Op都继承自`OperatorBase`,且所有的Op都是无状态的,每个Op包含的成员变量只有四个:type、inputs、outputs、attribute。 - -Op的核心方法是Run,Run方法需要两方面的资源:数据资源和计算资源,这两个资源分别通过`Scope`和`Place`获取。框架内部有一个全局的`DeviceContextPool`,用来记录`Place`和`DeviceContext`之间的对应的关系,即每个`Place`有且仅有一个`DeviceContext`与之对应,`DeviceContext`中存放了当前设备的计算资源。比如对于GPU,这些资源包括`cudnn_handle`、`cublas_handle`、`stream`等,**Op内部所有的计算(数据拷贝和CUDA Kernel等)都必须在`DeviceContext`中进行**。 - -Paddle框架的设计理念是可以在多种设备及第三方库上运行,有些Op的实现可能会因为设备或者第三方库的不同而不同。为此,Paddle引入了OpKernel的方式,即一个Op可以有多个OpKernel,这类Op继承自`OperatorWithKernel`,这类Op的代表是conv_op,conv_op的OpKernel有:`GemmConvKernel`、`CUDNNConvOpKernel`、`ConvMKLDNNOpKernel`,且每个OpKernel都有double和float两种数据类型。不需要OpKernel的代表有`WhileOp`等。 - -Operator继承关系图: -![op_inheritance_relation_diagram](./op_inheritance_relation_diagram.png) - -进一步了解可参考:[multi_devices](https://github.com/PaddlePaddle/FluidDoc/tree/develop/doc/fluid/design/multi_devices),[scope](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/design/concepts/scope.md),[Developer's_Guide_to_Paddle_Fluid](https://github.com/PaddlePaddle/FluidDoc/blob/release/1.2/doc/fluid/getstarted/Developer's_Guide_to_Paddle_Fluid.md) - -### 2.Op的注册逻辑 -每个Operator的注册项包括: - ```C++ - OpCreator creator_; - GradOpMakerFN grad_op_maker_; - proto::OpProto* proto_{nullptr}; - OpAttrChecker* checker_{nullptr}; - InferVarTypeFN infer_var_type_; - InferShapeFN infer_shape_; - ``` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
注册项类型说明调用
proto::OpProto Class 存放Op的输入/输出/属性/Op类型 编译时调用
GradOpMakerFN Functor 返回当前Op对应的反向Op的一组OpDesc,因为正向Op的反向可能有多个Op构成 编译时调用
OpAttrChecker Class 对Op的attr进行check 编译时调用
InferVarTypeFN Functor 用于推断输出Var的Type,比如是LoDTensor还是SelectedRows,或者其他 编译时调用
InferShapeFN Functor 用于推断Output的Shape 分为编译时和运行时,编译时是在Python端调用;如果Op继承自OperatorWithKernel,运行时是在op.run中调用
OpCreator Functor 每次调用都会创建一个新的OperatorBase 运行时调用
- -通常Op注释时需要调用REGISTER_OPERATOR,即: - ``` - REGISTER_OPERATOR(op_type, - OperatorBase - op_maker_and_checker_maker, - op_grad_opmaker, - op_infer_var_shape, - op_infer_var_type) - ``` - -**注意:** - -1. 对于所有Op,前三个参数是必须的,op_type指明op的名字,OperatorBase是该Op的对象,op_maker_and_checker_maker是op的maker以及Op中attr的checker。 -2. 如果该Op有反向,则必须要有op_grad_opmaker,因为在backward会根据正向的Op中获取反向Op的Maker。 -3. 框架提供了一个默认的op_grad_opmaker:`DefaultGradOpDescMaker`,这个Maker会将前向Op的输入和输出都作为反向Op的输入,将前向Op的输入的梯度作为反向Op的输出,并将前向Op的属性拷贝过来。**注意:DefaultGradOpDescMaker会将前向Op的所有输入输出都做反向Op的输入,即使这个输入是没有必要的,这将会导致无法对没有用到的变量做内存优化**。 -4. 框架没有提供默认的op_infer_var_shape方法。如果该Op是无OpKernel的,通常需要用户添加对应的op_infer_var_shape方法;如果该Op是有OpKernel的,需要实现`OperatorWithKernel`中的`InferShape`方法,此时不需要提供op_infer_var_shape方法。具体实现可参考[while_op.cc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/controlflow/while_op.cc),[conv_op.cc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/conv_op.cc)。 -5. 框架没有提供默认的op_infer_var_type方法,用户需要根据实际情况添加op_infer_var_type。严格来说每个Op都应该注册一个InferVarType,op_infer_var_type根据输入的Var的type和dtype推断输出Var的type和dtype。**注意:在Python端的LayerHelper中create_variable_for_type_inference操作返回的Variable里面是LoDTensor,C++端的InferVarType可以修改`Variable`的type和dtype**。 - - - -更多内容请参考: [如何写新的Op](new_op.html) - -## 写Op注意事项 -### 1.Op可以支持输入输出类型 -Paddle的Op的输入输出都是`Variable`,从设计上讲,`Variable`中可以存放任意类型,Op的输入输出`Variable`可能是是任意类型,通常情况下`Variable`中存放的是`LoDTensor`、`SelectedRows`。 - -**注意:** - -- 代码中经常出现`context.Input("Input")`,并不表示"Input"的`Variable`是`Tensor`,而是从"Input"的`Variable`的`LoDTensor`中获取`Tensor`。如果"Input"的`Variable`是`SelectedRows`,则会报错。 -- 如果”Input”是`SelectedRows`,`context->GetInputDim("Input")`返回的是`var->Get().GetCompleteDims()`,而不是`SelectedRows`中`Tensor`的Dim。 - -### 2.在Op内部不能对输入的数据做任何的改写 -在Op内部绝不允许对输入数据做任何改写,因为可能存在其他Op需要读这个数据。 - -### 3.OpKernel需要注册的数据类型 -目前要求所有OpKernel都要注册double和float数据类型。 - -### 4.GetExpectedKernelType方法重写 -GetExpectedKernelType方法是OperatorWithKernel类中用于获取指定设备(例如CPU,GPU)上指定数据类型(例如double,float)的OpKernel的方法。该方法通过获取输入变量内部的Tensor数据类型得知需要的Kernel数据类型,但是由于Tensor在此处可能尚未被初始化,所以在该方法内使用输入变量时需要进行必要的初始化检查。在新增含Kernel的Op的时候,关于该方法的重写需要注意以下两点。 - -#### 4.1 仅在必要时重写此方法 - -基类OperatorWithKernel中的GetExpectedKernelType方法对于派生类Op的所有输入变量进行了完备的初始化检查,建议在新增的Op中直接使用基类的此方法,例如: - -- [MeanOp](https://github.com/PaddlePaddle/Paddle/blob/3556514e971bdbb98fdf0f556371c527f4dfa98c/paddle/fluid/operators/mean_op.cc#L39):该Op的所有输入变量在Run之前应该全部被初始化,初始化检查是必要且合理的 - -但是在一些情况下,直接使用基类的GetExpectedKernelType方法无法满足需求,则需要对该方法进行重写,具体情况及示例如下: - -1. OP的输入有多个,且数据类型不同,例如 [AccuracyOp](https://github.com/PaddlePaddle/Paddle/blob/370f0345b6d35a513c8e64d519a0edfc96b9276c/paddle/fluid/operators/metrics/accuracy_op.cc#L80),需要重写GetExpectedKernelType方法,指定用某一输入变量获取kernel类型 - -2. Op包含Dispensable的输入变量,该类输入变量是可选的,当用户未输入时,该类变量未被初始化属于合理情况,例如 [ConvOp](https://github.com/PaddlePaddle/Paddle/blob/250e72d254ccbe3521c29aa2801a1cb15b75ea73/paddle/fluid/operators/conv_op.cc#L206),存在Bias等可选的输入变量,需要重写GetExpectedKernelType方法,指定用必须提供的输入变量获取kernel类型 - -3. Op的部分输入变量即使未被初始化也属于合理情况,例如 [ConcatOp](https://github.com/PaddlePaddle/Paddle/blob/250e72d254ccbe3521c29aa2801a1cb15b75ea73/paddle/fluid/operators/concat_op.cc#L90),输入变量X中有个Tensor需要连接,其中可能包含未被初始化的Tensor,需要重写GetExpectedKernelType方法,使用输入变量X获取kernel的过程中,合理忽略掉部分Tensor为空的情况 - -4. OP的Kernel类型与输入变量无关(可能由其他参数指定),例如 [FillOp](https://github.com/PaddlePaddle/Paddle/blob/efbdad059634bef022d4a3f5b00aef6ef8e88ed6/paddle/fluid/operators/one_hot_op.cc#L72),该Op没有输入,Kernel类型通过Op的dtype参数指定,因此需要重写GetExpectedKernelType方法,用参数指定的数据类型获取kernel类型 - -5. Op Kernel的部分参数在使用某些库时,需要指定为相应的值,因此需要重写GetExpectedKernelType方法,覆盖默认参数 - - 使用CUDNN库:需要指定OpKernel的LibraryType为kCUDNN,例如 [AffineGridOp](https://github.com/PaddlePaddle/Paddle/blob/370f0345b6d35a513c8e64d519a0edfc96b9276c/paddle/fluid/operators/affine_grid_op.cc#L78) - - 使用MKLDNN库:需要指定OpKernel的LibraryType和DataLayout为kMKLDNN [MulOp](https://github.com/PaddlePaddle/Paddle/blob/250e72d254ccbe3521c29aa2801a1cb15b75ea73/paddle/fluid/operators/mul_op.cc#L89) - -#### 4.2 重写此方法时需要对输入变量进行初始化检查 - -在需要重写GetExpectedKernelType方法时,一般会根据某一输入变量获取Kernel的数据类型,此时请使用`OperatorWithKernel::IndicateVarDataType`接口获取变量的dtype,该方法对指定的输入变量进行了必要的初始化检查,详见[Paddle PR #20044](https://github.com/PaddlePaddle/Paddle/pull/20044),实现示例如下,: - -``` - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); - } -``` - -如果未使用带有初始化检查的方法,直接使用了`Tensor->type()`,可能会导致报出`holder_ should not be null. Tensor not initialized yet when Tensor::type()`的错误,例如[Paddle issue #19522](https://github.com/PaddlePaddle/Paddle/issues/19522) ,用户仅凭该错误信息将无法得知具体出错的Op,不利于调试。 - -### 5.Op兼容性问题 -对Op的修改需要考虑兼容性问题,要保证Op修改之后,之前的模型都能够正常加载及运行,即新版本的Paddle预测库能成功加载运行旧版本训练的模型。**所以,需要保证Op的Input、Output和Attribute不能被修改(文档除外)或删除,可以新增Input、Output和Attribute,但是新增的Input,Output必须设置AsDispensable,新增的Attribute必须设置默认值。更多详细内容请参考[OP修改规范:Input/Output/Attribute只能做兼容修改](https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification)** 。 - -### 6.ShareDataWith的调用 -ShareDataWith的功能是使两个Tensor共享底层buffer,在调用这个操作的时候需要特别注意,在Op内部不能将ShareDataWith作用在Op的输出上,即Op输出的Tensor必须是Malloc出来的。 - -### 7.稀疏梯度参数更新方法 -目前稀疏梯度在做更新的时候会先对梯度做merge,即对相同参数的梯度做累加,然后做参数以及附加参数(如velocity)的更新。 - -### 8.显存优化 - -#### 8.1 为可原位计算的Op注册Inplace -有些Op的计算逻辑中,输出可以复用输入的显存空间,也可称为原位计算。例如[reshape_op](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/reshape_op.cc)中,输出`Out`可以复用输入`X`的显存空间,因为该Op的计算逻辑不会改变`X`的实际数据,只是修改它的shape,输出和输入复用同一块显存空间不影响结果。对于这类OP,可以注册`Inlace`,从而让框架在运行时自动地进行显存优化。 - -Paddle提供了`DECLARE_INPLACE_OP_INFERER`宏用于注册`Inplace`,该宏第一个参数是一个类名,如`ReshapeOpInplaceInToOut`;第二个参数是一对复用的输入输出,以`{"X", "Out"}`的形式给出。在`REGISTER_OPERATOR`时, -可以将类名传传入,从而为该Op注册`Inplace`。 - -``` -DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInToOut, {"X", "Out"}); - -REGISTER_OPERATOR( - reshape, ops::ReshapeOp, ops::ReshapeOpMaker, - paddle::framework::DefaultGradOpMaker, - paddle::framework::DefaultGradOpMaker, - ops::ReshapeOpInplaceInToOut); -``` - -#### 8.2 减少OP中的无关变量 -通常反向Op会依赖于前向Op的某些输入(Input)、输出(Output),以供反向Op计算使用。但有些情况下,反向Op不需要前向Op的所有输入和输出;有些情况下,反向Op只需要前向Op的部分输入和输出;有些情况下,反向Op只需要使用前向Op中输入和输出变量的Shape和LoD信息。若Op开发者在注册反向Op时,将不必要的前向Op输入和输出作为反向Op的输入,会导致这部分显存无法被框架现有的显存优化策略优化,从而导致模型显存占用过高。 - -所以在写注册反向Op时需要注意以下几点: - -- Paddle提供的`DefaultGradOpMaker`,默认会将前向op的所有输入(`Input`)、输出(`Output`)以及输出变量所对应的梯度(`Output@Grad`)作为反向Op的输入,将前向Op输入所对应的梯度(`Input@Grad`)作为反向Op的输出。所以在使用`DefaultGradOpMaker`时需要考虑是否有些变量在计算中不被用到。 -- 如果`DefaultGradOpMaker`不能够满足需求,需要用户自己手动构建`GradOpMaker`,具体实现请参考[相关文档](new_op.html#gradopmaker); -- 如果有些反向Op需要依赖前向Op的输入或输出变量的的Shape或LoD,但不依赖于变量中Tensor的Buffer,且不能根据其他变量推断出该Shape和LoD,则可以通过`DECLARE_NO_NEED_BUFFER_VARS_INFERER`接口对该变量(以下称该变量为`X`)在反向Op中进行注册`NoNeedBufferVars`。**一旦注册了`NoNeedBufferVars`,反向op中就不能读写该变量对应的Tensor中的buffer,只能调用Tensor的dims()和lod()方法,同时,反向Op中的`GetExpectedKernelType()`必须要重写,并且`GetExpectedKernelType()`中不能访问`X`变量中Tensor的type()方法**。比如在`SliceOpGrad`中只会用到`Input`中变量的Shape信息,所以需要为对`Input`在`SliceOpGrad`上进行注册: -``` -namespace paddle { -namespace operators { -// ... -class SliceOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - // ... - } - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - // Note: don't get data type from ctx.Input("Input"); - auto dtype = ctx.Input(framework::GradVarName("Out"))->type(); - return framework::OpKernelType( dtype, ctx.GetPlace()); - } -}; - - -template -class SliceOpGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr bind) const override { - bind->SetInput("Input", this->Input("Input")); - if (this->HasInput("StartsTensor")) { - bind->SetInput("StartsTensor", this->Input("StartsTensor")); - } - if (this->HasInput("EndsTensor")) { - bind->SetInput("EndsTensor", this->Input("EndsTensor")); - } - if (this->HasInput("StartsTensorList")) { - bind->SetInput("StartsTensorList", this->Input("StartsTensorList")); - } - if (this->HasInput("EndsTensorList")) { - bind->SetInput("EndsTensorList", this->Input("EndsTensorList")); - } - bind->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - bind->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); - bind->SetAttrMap(this->Attrs()); - bind->SetType("slice_grad"); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(SliceOpGradNoNeedBufferVarsInference, - "Input"); -} // namespace operators -} // namespace paddle -namespace ops = paddle::operators; -REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker, - ops::SliceOpGradMaker, - ops::SliceOpGradMaker); -REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad, - ops::SliceDoubleOpGradMaker, - ops::SliceDoubleOpGradMaker, - ops::SliceOpGradNoNeedBufferVarsInference); -``` - -### 9.混合设备调用 -由于GPU是异步执行的,当CPU调用返回之后,GPU端可能还没有真正的执行,所以如果在Op中创建了GPU运行时需要用到的临时变量,当GPU开始运行的时候,该临时变量可能在CPU端已经被释放,这样可能会导致GPU计算出错。 - -关于GPU中的一些同步和异步操作: -``` -The following device operations are asynchronous with respect to the host: - Kernel launches; - Memory copies within a single device's memory; - Memory copies from host to device of a memory block of 64 KB or less; - Memory copies performed by functions that are suffixed with Async; - Memory set function calls. -``` - -关于cudaMemCpy和cudaMemCpyAsync注意事项: - -- 如果数据传输是从GPU端到非页锁定的CPU端,数据传输将是同步,即使调用的是异步拷贝操作。 -- 如果数据传输是从CPU端到CPU端,数据传输将是同步的,即使调用的是异步拷贝操作。 - -更多内容可参考:[Asynchronous Concurrent Execution](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#asynchronous-concurrent-execution),[API synchronization behavior](https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior) - -### 10. LoD 在 Op 内部的传导规范 - -[LoD](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/design/concepts/lod_tensor.md) 是 Paddle 框架用来表示变长序列数据的属性,除了仅支持输入是 padding data 的 Op 外,所有 Op 的实现都要考虑 LoD 的传导问题。 - -根据 OP 的计算过程中是否用到 LoD,我们可以将涉及到 LoD 传导问题的 OP 分为两类: LoD-Transparent 与 LoD-Based。 - - - - - - - - - - - - - - - - - - - - - -
类型特点示例
LoD-Transparent 计算过程不依赖 LoD,输入是否有 LoD 不会影响计算的结果,通常是 position-wise 的计算 conv2d_op、batch_norm_op、dropout_op 等
LoD-Based 计算以序列为单位, 计算过程依赖 LoD lstm_op、gru_op、sequence_ops 等
- -这两类 OP 的 LoD 传导需要考虑前向和反向两个过程。 - -#### 前向传导 - -在前向传导过程,与输入的 LoD 相比较,Op 输出的 LoD 可能出现不变、改变和消失这三种情况: - - - 不变:适用于所有的 LoD-Transparent OP 与部分的 LoD-Based OP。可以在`InferShape` 中调用 `ShareLoD()` 直接将输入 Var 的 LoD 共享给输出 Var, 可参考 [lstm_op](https://github.com/PaddlePaddle/Paddle/blob/a88a1faa48a42a8c3737deb0f05da968d200a7d3/paddle/fluid/operators/lstm_op.cc#L92); 如果有多个输入且都可能存在 LoD 的情况,通常默认共享第一个输入, 例如 [elementwise_ops forward](https://github.com/PaddlePaddle/Paddle/blob/5d6a1fcf16bcb48d2e66306b27d9994d9b07433c/paddle/fluid/operators/elementwise/elementwise_op.h#L69); - - - 改变:适用于部分 LoD-Based OP。在实现 OpKernel 时需考虑输出 LoD 的正确计算,真实的 LoD 在前向计算结束后才能确定,此时仍需要在`InferShape` 中调用 `ShareLoD()`,以确保CompileTime 时对 LoD Level 做了正确的传导,可参考 [sequence_expand_op](https://github.com/PaddlePaddle/Paddle/blob/565d30950138b9f831caa33904d9016cf53c6c2e/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc); - - - 消失:适用于输出不再是序列数据的 LoD-Based OP。此时不用再考虑前向的 LoD 传导问题,可参考 [sequence_pool_op](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc); - -其它重要的注意事项: - - - 实现 LoD-Based OP 时,需要处理好 LoD 传导的边界情况,例如对长度为零的输入的支持,并完善相应的单测,单测 case 覆盖空序列出现在 batch 开头、中间和末尾等位置的情况,可参考 [test_lstm_op.py](https://github.com/PaddlePaddle/Paddle/blob/4292bd8687ababc7737cffbddc0d38ead2138c00/python/paddle/fluid/tests/unittests/test_lstm_op.py#L203-L216) - - - 对 LoD Level 有明确要求的 OP,推荐的做法是在 `InferShape` 中即完成 LoD Level的检查,例如 [sequence_pad_op](https://github.com/PaddlePaddle/Paddle/blob/4292bd8687ababc7737cffbddc0d38ead2138c00/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc#L79)。 - - -#### 反向传导 - -通常来讲,OP 的某个输入 Var 所对应的梯度 GradVar 的 LoD 应该与 Var 自身相同,所以应直接将 Var 的 LoD 共享给 GradVar,可以参考 [elementwise ops 的 backward](https://github.com/PaddlePaddle/Paddle/blob/a88a1faa48a42a8c3737deb0f05da968d200a7d3/paddle/fluid/operators/elementwise/elementwise_op.h#L189-L196) - - -## Op性能优化 -### 1.第三方库的选择 -在写Op过程中优先使用高性能(如cudnn、mkldnn、mklml、eigen等)中提供的操作,但是一定要做benchmark,有些库中的操作在深度学习任务中可能会比较慢。因为高性能库(如eigen等)中提供的操作为了更为通用,在性能方面可能并不是很好,通常深度学习模型中数据量较小,所以有些情况下可能高性能库中提供的某些操作速度较慢。比如Elementwise系列的所有Op(前向和反向),Elementwise操作在模型中调用的次数比较多,尤其是Elementwise_add,在很多操作之后都需要添加偏置项。在之前的实现中Elementwise_op直接调用Eigen库,由于Elementwise操作在很多情况下需要对数据做Broadcast,而实验发现Eigen库做Broadcast的速度比较慢,慢的原因在这个PR[#6229](https://github.com/PaddlePaddle/Paddle/pull/6229)中有描述。 - -### 2.Op性能优化 -Op的计算速度与输入的数据量有关,对于某些Op可以根据输入数据的Shape和Op的属性参数来选择不同的计算方式。比如concat_op,当axis>=1时,在对多个tensor做拼接过程中需要对每个tensor做很多次拷贝,如果是在GPU上,需要调用cudaMemCopy。相对CPU而言,GPU属于外部设备,所以每次调用GPU的操作都会有一定的额外开销,并且当需要拷贝的次数较多时,这种开销就更为凸现。目前concat_op的实现会根据输入数据的Shape以及axis值来选择不同的调用方式,如果输入的tensor较多,且axis不等于0,则将多次拷贝操作转换成一个CUDA Kernel来完成;如果输入tensor较少,且axis等于0,使用直接进行拷贝。相关实验过程在该PR([#8669](https://github.com/PaddlePaddle/Paddle/pull/8669))中有介绍。 - -由于CUDA Kernel的调用有一定的额外开销,所以如果Op中出现多次调用CUDA Kernel,可能会影响Op的执行速度。比如之前的sequence_expand_op中包含很多CUDA Kernel,通常这些CUDA Kernel处理的数据量较小,所以频繁调用这样的Kernel会影响Op的计算速度,这种情况下最好将这些小的CUDA Kernel合并成一个。在优化sequence_expand_op过程(相关PR[#9289](https://github.com/PaddlePaddle/Paddle/pull/9289))中就是采用这种思路,优化后的sequence_expand_op比之前的实现平均快出约1倍左右,相关实验细节在该PR([#9289](https://github.com/PaddlePaddle/Paddle/pull/9289))中有介绍。 - -减少CPU与GPU之间的拷贝和同步操作的次数。比如fetch操作,在每个迭代之后都会对模型参数进行更新并得到一个loss,并且数据从GPU端到没有页锁定的CPU端的拷贝是同步的,所以频繁的fetch多个参数会导致模型训练速度变慢。 - -## Op数值稳定性问题 -### 1.有些Op存在数值稳定性问题 -出现数值稳定性的主要原因程序在多次运行时,对浮点型数据施加操作的顺序可能不同,进而导致最终计算结果不同。而GPU是通过多线程并行计算的方式来加速计算的,所以很容易出现对浮点数施加操作的顺序不固定现象。 - -目前发现cudnn中的卷积操作、cudnn中的MaxPooling、CUDA中CudaAtomicXX、ParallelExecutor的Reduce模式下参数梯度的聚合等操作运行结果是非确定的。 - -为此Paddle中添加了一些FLAGS,比如使用FLAGS_cudnn_deterministic来强制cudnn使用确定性算法、FLAGS_cpu_deterministic强制CPU端的计算使用确定性方法。 - -## 其他 -### 1.报错信息 -Enforce提示信息不能为空,并且需要写明,因为报错信息可以更快更方便地分析出错误的原因。 - -### 2.Op的数学公式 -如果Op有数学公式,一定要在代码中将数学公式写明,并在Python API的Doc中显示,因为用户在对比不同框架的计算结果时可能需要了解Paddle对Op是怎么实现的。 - -**注意:**在merge到develop分支之前一定进行公式预览。可参考[dynamic_lstmp](../../../api_cn/layers_cn/nn_cn.html#dynamic-lstmp)。 - -### 3.Op变量名的命名要规范 -在定义Op时,Op的输入输出以及属性的命名需要符合规范,具体命名规则请参考:[name_convention](https://github.com/PaddlePaddle/FluidDoc/blob/release/1.2/doc/fluid/dev/name_convention.md)。 - -### 4.Python端Op接口中参数的顺序 -Python API中参数的顺序一般按照重要性来排,以fc为例: -``` -def fc(input, - size, - num_flatten_dims=1, - param_attr=None, - bias_attr=None, - act=None, - is_test=False, - name=None) -``` diff --git a/docs/dev_guides/api_contributing_guides/new_python_api_cn.md b/docs/dev_guides/api_contributing_guides/new_python_api_cn.md index 4c309275309..5bcc910293b 100644 --- a/docs/dev_guides/api_contributing_guides/new_python_api_cn.md +++ b/docs/dev_guides/api_contributing_guides/new_python_api_cn.md @@ -4,10 +4,10 @@ ## 开发 Python API代码 -这分为两种情况,Paddle 的 API 包含需要开发 c++ operator 的和不需要开发 operator 而仅使用现有 Python API 组合得到的两种,但两种情况下均有 Python 端的开发工作。 +这分为两种情况,Paddle 的 API 包含需要开发 c++ 算子的和不需要开发 c++ 算子而仅使用现有 Python API 组合得到的两种,但两种情况下均有 Python 端的开发工作。 -1. 包含 c++ operator 的开发的情况,需要在 Python 端添加相应 API 以调用对应的 operator; -2. 不需要开发 c++ operator 的情况,需要在 Python 端添加相应 API 以调用其他 API 组合实现功能; +1. 包含 c++ 算子的开发的情况,需要在 Python 端添加相应 API 以调用对应的算子; +2. 不需要开发 c++ 算子的情况,需要在 Python 端添加相应 API 以调用其他 API 组合实现功能; ### 文件位置与 API 名称 @@ -76,7 +76,7 @@ from a import f # it's ok, too # Python/paddle/tensor/math.py def logsumexp(...): ... - + # Python/paddle/tensor/__init__.py from .math import logsumexp @@ -155,17 +155,19 @@ Python API 一般包含如下的部分: 例子: ```Python -def mm(input, mat2, name=None): - # 为了突出重点,省略部分代码 - - # 动态图,直接调用 op 对应的 CPython 函数 - if paddle.in_dynamic_mode(): +def mm(input, mat2, name=None): + # 为了突出重点,省略部分代码 + # 新动态图模式,直接调用 op 对应的 CPython 函数 + if in_dygraph_mode(): + return _C_ops.final_state_matmul(input, mat2, False, False) + # 旧动态图模式 + elif _in_legacy_dygraph(): return _C_ops.matmul_v2(input, mat2) - # 静态分支 + # 静态分支 ## 检测输入 - __check_input(input, mat2) - + __check_input(input, mat2) + ## 构造输出,添加 op,返回输出 helper = LayerHelper('mm', **locals()) out = helper.create_variable_for_type_inference(dtype=input.dtype) @@ -188,44 +190,58 @@ def ones(shape, dtype=None, name=None): 因为 `fill_constant` 里已经处理了动态图和静态图的情况,所以直接调用即可。 -而如果 API 的实现中需要调用一个 op 时,则需要根据动态图和静态图使用不同的写法,用 `paddle.in_dynamic_mode()` 获取当前状态走不同的分支。 +而如果 API 的实现中需要调用一个C++算子时,则需要根据动态图和静态图使用不同的写法。 #### 动静态图分支 +**动态图分支** + +由于目前动态图正处在重构升级阶段,所以需要为新旧动态图分别添加对应的代码分支。其中 `in_dygraph_mode()` 表示新动态图分支,`_in_legacy_dygraph()`表示旧动态图分支。 -参考前面 `paddle.nn.functional.kl_div` 的代码,动态图分支的写法一般是调用 API 对应的 CPython 函数。 +参考`paddle.trace` 的代码,动态图分支的写法一般是调用 API 对应的 CPython 函数。 ```Python -_C_ops.matmul_v2(input, mat2) +# 新动态图模式 +if in_dygraph_mode(): + return _C_ops.final_state_trace( x, offset, axis1, axis2 ) + +# 旧动态图模式 +if _in_legacy_dygraph(): + return _C_ops.trace(x, 'offset', offset, 'axis1', axis1, 'axis2', axis2) ``` -`_C_ops` 是 `Python/paddle/_C_ops.py`,其中从 paddle 编译得到的二进制文件中 import 了 c++ operator 对应的 Python C 函数,函数名和 operator 名一致。如希望调用名为 `matmul_v2` 的 operator,则使用 `_C_ops.matmul_v2`, 然后传入参数。 +`_C_ops` 是 `Python/paddle/_C_ops.py`,其中从 paddle 编译得到的二进制文件中 import 了 c++ 算子对应的 Python C 函数。 -其中参数分为两个部分,`Tensor` 对于 `Tensor` 类型的输入,直接按照定义 opmaker 时添加输入的次序,以按位置传参的方式传入。关于 opmaker 可以参考 [定义OpProtoMaker类](new_cpp_op_cn.html#opprotomaker)(本文中用 opmaker 简称 operator ). +- 在新动态图模式下,Python C 的调用函数名为`final_state_` + 算子名,然后将参数按照Yaml中定义的输入参数顺序传入即可。 +- 在旧动态图模式下,Python C 函数名和算子名一致。如希望调用名为 `trace` 的算子,则使用 `_C_ops.trace`, 然后传入参数。其中参数分为两个部分: + - 对于 `Tensor` 类型的输入,直接按照Yaml中的定义,按位置传参的方式传入 + - 对于非 `Tensor` 类型的输入,则以 `attribute 名,attribute 值` 交替的方式传入,这类似 Python 中的按关键字传参的方式。然后返回调用函数得到的结果。 -而对于非 `Tensor` 类型的输入(对应 opmaker 中的 Attribute),则以 `attribute 名,attribute 值` 交替的方式传入,这类似 Python 中的按关键字传参的方式。然后返回调用函数得到的结果。 -而对于静态图,则一般分为创建输出 Tensor,添加 operator 两步。 +**静态图分支** -```Python -loss = _C_ops.kldiv_loss(input, label, 'reduction', 'none') +对于静态图,一般分为创建输出 Tensor,添加 operator 两步。 -# layerhelper 创建准备工作 -helper = LayerHelper('kl_div', **locals()) +```Python +# LayerHelper是一个用于创建op输出变量、向program中添加op的辅助工具类 +helper = LayerHelper('trace', **locals()) # 创建输出 Tensor -loss = helper.create_variable_for_type_inference(dtype=input.dtype) +out = helper.create_variable_for_type_inference(dtype=x.dtype) # 将输入 Tensor,输出 Tensor, 非 Tensor 的 attributes 以三个字典的形式 # 作为参数添加 operator helper.append_op( - type='kldiv_loss', - inputs={'X': input, - 'Target': label}, - outputs={'Loss': loss}, - attrs={'reduction': 'none'}) + type='trace', + inputs={'Input': [x]}, + attrs={'offset': offset, + 'axis1': axis1, + 'axis2': axis2}, + outputs={'Out': [out]}) +return out ``` +注意:在`append_op`添加的`inputs`和`outputs`项,其中的key值(静态图中变量名)一般为Yaml中定义的输入输出Tensor变量名的首字母大写格式,静态图中的变量名可以在`paddle/fluid/operators/generated_op.cc`(需要先开发C++算子并完成编译)文件内对应算子的`OpMaker`中找到;`attrs`项的变量名与Yaml中相同。 +这里`trace`中的'Input'没有与Yaml配置的中'x'直接对应是由于为了兼容旧算子体系下`Trace`算子的`OpMaker`实现而做了额外的映射,新增算子时无需考虑这种情况。 -上述的代码中,动态图分支的 `input, label` 对应静态图分支中的 inputs 字典,其次序和 opmaker 中定义的有关。而静态图中的 attrs 字典在动态图分支中则以 `key, value` 交替的形式排列,如果有多个 attribute, 则依次排列。 ## 开发单元测试代码 @@ -235,7 +251,7 @@ helper.append_op( 单元测试相关的开发规范可以参考 - [C++ OP 开发(新增原生算子)](new_cpp_op_cn.html) ,[Op开发手册(Operator Development Manual)](https://github.com/PaddlePaddle/Paddle/wiki/Operator-Development-Manual-Index). + [C++ 算子开发指南-添加单元测试](new_cpp_op_cn.html#tianjiadanyuanceshi) ,[Op开发手册(Operator Development Manual)](https://github.com/PaddlePaddle/Paddle/wiki/Operator-Development-Manual-Index). 在此不作展开,主要讲述 Python API 的单元测试。 @@ -267,7 +283,7 @@ helper.append_op( self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32') self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \ else paddle.CPUPlace() - + def test_static_api(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): @@ -280,7 +296,7 @@ helper.append_op( out_ref = ref_hardtanh(self.x_np) for r in res: self.assertEqual(np.allclose(out_ref, r), True) - + def test_dygraph_api(self): paddle.disable_static(self.place) x = paddle.to_tensor(self.x_np) @@ -290,7 +306,7 @@ helper.append_op( out_ref = ref_hardtanh(self.x_np) for r in [out1, out2]: self.assertEqual(np.allclose(out_ref, r.numpy()), True) - + out1 = F.hardtanh(x, -2.0, 2.0) m = paddle.nn.Hardtanh(-2.0, 2.0) out2 = m(x) @@ -306,7 +322,7 @@ helper.append_op( paddle.disable_static(place=paddle.fluid.CPUPlace()) self.run_imperative() paddle.enable_static() - + with fluid.program_guard(fluid.Program()): self.run_static()