diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 62b805cf422d95..96526c4c8cd398 100755
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -890,12 +890,33 @@ It is recommended to use the defaults for this activation.
   }
 };
 
+class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "Input of Relu operator, an N-D Tensor, with data type float32, float64 or float16.");
+    AddOutput("Out",
+              "Output of Relu operator, a Tensor with the same shape as the input.");
+    AddAttr<float>("safe",
+                   "(float, default 0.0) Non-zero enables the QuantSafe relu form, which propagates NaN instead of clamping it to 0")
+        .SetDefault(0.0f);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false)
+        .AsExtra();
+    AddAttr<bool>("use_cudnn",
+                  "(bool, default false) Only used in cudnn kernel, need install cudnn")
+        .SetDefault(false)
+        .AsExtra();
+    AddComment(ReluDoc);
+  }
+};
+
 REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
 REGISTER_ACTIVATION_OP_MAKER(Silu, SiluDoc);
 REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
 REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
 REGISTER_ACTIVATION_OP_MAKER(Expm1, Expm1Doc);
-REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
 REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
 REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
 REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
diff --git a/paddle/fluid/operators/scaled_fc_op.cu b/paddle/fluid/operators/scaled_fc_op.cu
index 93dcd0efff00fc..20bd9dbf073612 100644
--- a/paddle/fluid/operators/scaled_fc_op.cu
+++ b/paddle/fluid/operators/scaled_fc_op.cu
@@ -80,8 +80,13 @@ __global__ void kernel_cast_and_cut(const int N,
     int col_idx = i % coln_ori;
     int row_idx = i / coln_ori;
     int idx = row_idx * coln_pad + col_idx;
-    matrix[i] = static_cast<T>(matrix_pad[idx]);
-    matrix[i] *= scale_factor;
+    T tmp = static_cast<T>(matrix_pad[idx]) * scale_factor;
+    // Some functions, such as fmax, replace Inf with a normal float, which hides
+    // abnormal values in the output tensors. Replace Inf with NaN so bad values propagate.
+    if (isinf(tmp)) {
+      tmp = std::numeric_limits<T>::quiet_NaN();
+    }
+    matrix[i] = tmp;
   }
 }
 
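Note (not part of the patch): a minimal NumPy sketch of why `kernel_cast_and_cut` swaps Inf for NaN. Clamping helpers such as `fmax` in later ops silently absorb a negative Inf, so the anomaly never reaches the final output, whereas a NaN survives ordinary arithmetic. The array values below are illustrative only.

```python
# Illustrative only -- plain NumPy, not Paddle kernel code.
import numpy as np

x = np.array([1.5, -np.inf, np.inf], dtype=np.float32)

# An fmax-style clamp absorbs -Inf (and would also ignore NaN), so the bad
# value vanishes from the result: [1.5, 0.0, inf]
print(np.fmax(x, 0.0))

# Marking Inf as NaN first, as the kernel now does, keeps the anomaly visible
# because NaN propagates through arithmetic: [4.0, nan, nan]
x_marked = np.where(np.isinf(x), np.nan, x)
print(x_marked * 2.0 + 1.0)
```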
diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h
index 8a83226b23027e..f7680528502b39 100644
--- a/paddle/phi/kernels/activation_kernel.h
+++ b/paddle/phi/kernels/activation_kernel.h
@@ -51,7 +51,6 @@ DECLARE_ACTIVATION_KERNEL(Cosh)
 DECLARE_ACTIVATION_KERNEL(Asinh)
 DECLARE_ACTIVATION_KERNEL(Acosh)
 DECLARE_ACTIVATION_KERNEL(Atanh)
-DECLARE_ACTIVATION_KERNEL(Relu)
 DECLARE_ACTIVATION_KERNEL(Tanh)
 DECLARE_ACTIVATION_KERNEL(TanhShrink)
 DECLARE_ACTIVATION_KERNEL(Silu)
@@ -73,6 +72,7 @@ DECLARE_ACTIVATION_KERNEL(Floor)
 DECLARE_ACTIVATION_KERNEL(Ceil)
 DECLARE_ACTIVATION_KERNEL(Negative)
 
+DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu, safe)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu6, threshold)
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index ac75c4ad3479fe..de10c781b3f765 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -70,7 +70,6 @@ DEFINE_CPU_ACTIVATION_KERNEL(Cosh, CoshFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL(Asinh, AsinhFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL(Acosh, AcoshFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL(Atanh, AtanhFunctor)
-DEFINE_CPU_ACTIVATION_KERNEL(Relu, ReluCPUFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL(Tanh, TanhFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL(TanhShrink, TanhShrinkFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL(Silu, SiluFunctor)
@@ -92,6 +91,7 @@ DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL(Negative, NegativeFunctor)
 
+DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Relu, ReluCPUFunctor, safe)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
                                      ThresholdedReluFunctor,
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index ee79cafd155dcf..f1699d019721ed 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -777,10 +777,18 @@ struct Expm1GradFunctor : public BaseActivationFunctor<T> {
 // relu(x) = max(x, 0)
 template <typename T>
 struct ReluCPUFunctor : public BaseActivationFunctor<T> {
+  float safe_{0.0};
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"safe", &safe_}};
+  }
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) {
-      return v > static_cast<T>(0) ? v : static_cast<T>(0);
+    out.device(d) = x.unaryExpr([this] HOSTDEVICE(T v) {
+      if (safe_) {
+        return v * (v > static_cast<T>(0) ? static_cast<T>(1) : static_cast<T>(0));
+      } else {
+        return v > static_cast<T>(0) ? v : static_cast<T>(0);
+      }
     });
   }
 };
@@ -2094,11 +2102,21 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
 
 #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
 template <typename T>
 struct CudaReluFunctor : public BaseActivationFunctor<T> {
+  // The attribute data type is limited to float.
+  float safe_{0.0};
   T zero = static_cast<T>(0.0f);
 
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"safe", &safe_}};
+  }
+
   // relu(x) = max(x, 0)
   __device__ __forceinline__ T operator()(const T x) const {
-    return x > zero ? x : zero;
+    if (safe_) {
+      return x * (x > static_cast<T>(0) ? static_cast<T>(1) : static_cast<T>(0));
+    } else {
+      return x > zero ? x : zero;
+    }
   }
 };
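Note (not part of the patch): the two relu forms used in `ReluCPUFunctor`/`CudaReluFunctor` above differ only on abnormal inputs. A NumPy sketch of the same arithmetic (illustrative values, not Paddle code):

```python
# Illustrative only -- mirrors the branch vs. mask-multiply forms above.
import numpy as np

x = np.array([2.0, -3.0, np.nan, -np.inf], dtype=np.float32)

# Branch form (safe == 0): x > 0 ? x : 0. NaN > 0 is False, so NaN and -Inf
# are clamped to 0 and the anomaly disappears: [2.0, 0.0, 0.0, 0.0]
relu_branch = np.where(x > 0, x, np.float32(0.0))

# Mask-multiply form (safe != 0): x * (x > 0). NaN * 0 and -Inf * 0 are both
# NaN, so the anomaly keeps propagating: [2.0, -0.0, nan, nan]
relu_masked = x * (x > 0).astype(np.float32)

print(relu_branch)
print(relu_masked)
```

For finite inputs the two forms agree (up to the sign of zero), so the default `safe=0` path is unchanged.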
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 6e116a3e157503..acf4fa8b4a68d9 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -87,7 +87,6 @@ DEFINE_GPU_ACTIVATION_KERNEL(Cosh, CudaCoshFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL(Asinh, CudaAsinhFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL(Acosh, CudaAcoshFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL(Atanh, CudaAtanhFunctor)
-DEFINE_GPU_ACTIVATION_KERNEL(Relu, CudaReluFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL(Tanh, CudaTanhFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL(TanhShrink, CudaTanhShrinkFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL(Silu, CudaSiluFunctor)
@@ -108,6 +107,7 @@ DEFINE_GPU_ACTIVATION_KERNEL(Round, CudaRoundFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL(Ceil, CudaCeilFunctor)
 
+DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Relu, CudaReluFunctor, safe)
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
                                      CudaThresholdedReluFunctor,
diff --git a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
index 338bb13d287275..4765fcdf915bbe 100644
--- a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
+++ b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
@@ -84,9 +84,10 @@ DEFINE_SPARSE_UNARY_KERNEL(Atanh)
 DEFINE_SPARSE_UNARY_KERNEL(Sqrt)
 DEFINE_SPARSE_UNARY_KERNEL(Square)
 DEFINE_SPARSE_UNARY_KERNEL(Log1p)
-DEFINE_SPARSE_UNARY_KERNEL(Relu)
 DEFINE_SPARSE_UNARY_KERNEL(Abs)
 DEFINE_SPARSE_UNARY_KERNEL(Expm1)
+
+DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu, safe)
 DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
 DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu6, threshold)
 DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)
diff --git a/paddle/phi/kernels/sparse/unary_kernel.h b/paddle/phi/kernels/sparse/unary_kernel.h
index fdb6b21a44427c..c1f5ea3654cdd5 100644
--- a/paddle/phi/kernels/sparse/unary_kernel.h
+++ b/paddle/phi/kernels/sparse/unary_kernel.h
@@ -49,12 +49,12 @@ DECLARE_SPARSE_UNARY_KERNEL(Atan)
 DECLARE_SPARSE_UNARY_KERNEL(Sinh)
 DECLARE_SPARSE_UNARY_KERNEL(Asinh)
 DECLARE_SPARSE_UNARY_KERNEL(Atanh)
-DECLARE_SPARSE_UNARY_KERNEL(Relu)
 DECLARE_SPARSE_UNARY_KERNEL(Tanh)
 DECLARE_SPARSE_UNARY_KERNEL(Square)
 DECLARE_SPARSE_UNARY_KERNEL(Sqrt)
 DECLARE_SPARSE_UNARY_KERNEL(Log1p)
 DECLARE_SPARSE_UNARY_KERNEL(Abs)
+DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu, safe)
 DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
 
 template <typename T, typename Context>
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index a0a04e90084fee..84b422fe1722ee 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -2621,25 +2621,7 @@ def scaled_fc(input,
                      attrs={
                          'input_scale_factor': input_scale_factor,
                          'bias_scale_factor': bias_scale_factor,
-                         'grad_scale_factor': grad_scale_factor
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                         'grad_scale_factor': grad_scale_factor,
                      },
                      outputs={"Out": pre_act})
diff --git a/python/paddle/fluid/dygraph/layer_object_helper.py b/python/paddle/fluid/dygraph/layer_object_helper.py
index 394df321811d83..d7b1c6c45624f5 100644
--- a/python/paddle/fluid/dygraph/layer_object_helper.py
+++ b/python/paddle/fluid/dygraph/layer_object_helper.py
@@ -158,6 +158,10 @@ def append_activation(self, input_var, act=None, use_cudnn=None):
             raise TypeError(
                 str(act) + " should be unicode or str in %s ", self.name)
 
+        # TODO: act should also support dict type here, matching
+        # LayerHelper.append_activation in
+        # python/paddle/fluid/layer_helper.py.
+
         if (use_cudnn is not None) and use_cudnn:
             act['use_cudnn'] = use_cudnn
         use_mkldnn = _global_flags()["FLAGS_use_mkldnn"]
diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py
index 42b67a5a0dfa85..c4888faf89dae0 100644
--- a/python/paddle/fluid/layer_helper.py
+++ b/python/paddle/fluid/layer_helper.py
@@ -147,6 +147,10 @@ def append_activation(self, input_var):
             return input_var
         if isinstance(act, six.string_types):
             act = {'type': act}
+        elif isinstance(act, dict) and 'type' in act:
+            # act can be a dict that passes extra attributes to the activation
+            # op, e.g. fluid.layers.fc(..., act={'type': 'relu', 'safe': 1.})
+            act = copy.deepcopy(act)
         else:
             raise TypeError(str(act) + " should be unicode or str")
 
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index a67e827ac5838c..cf8c955e6e8e21 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -9392,7 +9392,7 @@ def log(x, name=None):
 
 
 @deprecated(since="2.0.0", update_to="paddle.nn.functional.relu")
-def relu(x, name=None):
+def relu(x, name=None, safe=0.):
     """
     ${comment}
 
@@ -9431,9 +9431,13 @@ def relu(x, name=None):
     helper = LayerHelper('relu', **locals())
     dtype = helper.input_dtype(input_param_name='x')
     out = helper.create_variable_for_type_inference(dtype)
+    attrs = {}
+    if safe != 0.:
+        attrs['safe'] = safe
     helper.append_op(type="relu",
                      inputs={"X": helper.input('x')},
-                     outputs={"Out": out})
+                     outputs={"Out": out},
+                     attrs=attrs)
     return out
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 373186096bda0c..4b2db6289b6892 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -667,7 +667,7 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None):
     return out
 
 
-def relu(x, name=None):
+def relu(x, name=None, safe=0.):
     """
     relu activation.
 
@@ -701,7 +701,10 @@ def relu(x, name=None):
     check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu')
     helper = LayerHelper('relu', **locals())
     out = helper.create_variable_for_type_inference(x.dtype)
-    helper.append_op(type='relu', inputs={'X': x}, outputs={'Out': out})
+    attrs = {}
+    if safe != 0.:
+        attrs['safe'] = safe
+    helper.append_op(type='relu', inputs={'X': x}, outputs={'Out': out}, attrs=attrs)
     return out
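Note (not part of the patch): an illustrative usage sketch, assuming a Paddle build that includes this change; it exercises the static-graph `fluid` paths touched above, and names such as `main_prog` are placeholders.

```python
# Illustrative usage sketch -- assumes a Paddle build with this patch applied.
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 8], dtype='float32')

    # Non-zero `safe` is forwarded as the relu op's "safe" attribute,
    # selecting the NaN-propagating kernel path.
    y = fluid.layers.relu(x, safe=1.)

    # Dict form of `act`: LayerHelper.append_activation now accepts it and,
    # per the comment in layer_helper.py, passes the extra key through as an
    # attribute of the appended activation op.
    z = fluid.layers.fc(input=x, size=4, act={'type': 'relu', 'safe': 1.})
```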