Commit
propagate nan and inf in relu and scaled_fc
hc-www authored and zmxdream committed Oct 10, 2023
1 parent 7733f57 commit a7e2d64
Showing 13 changed files with 76 additions and 34 deletions.
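In short: the relu op gains a float attribute "safe" (default 0.0). When it is nonzero, the CPU, GPU and sparse kernels compute relu by multiplication rather than a max-style select, so NaN (and negative Inf) inputs propagate to the output instead of being silently zeroed; the scaled_fc CUDA kernel additionally converts Inf to NaN for the same reason. A minimal usage sketch based on the Python changes below (static-graph style; tensor names and sizes are illustrative):

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    x = fluid.data(name='x', shape=[None, 128], dtype='float32')
    # Plain string activation keeps the old behaviour: relu zeroes NaN and -inf.
    y_default = fluid.layers.fc(input=x, size=64, act='relu')
    # A dict activation can now carry extra attributes such as the new 'safe' flag,
    # so abnormal values produced earlier in the network stay visible.
    y_safe = fluid.layers.fc(input=x, size=64, act={'type': 'relu', 'safe': 1.})
    # The standalone layer gains the same attribute.
    z_safe = fluid.layers.relu(x, safe=1.)
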
23 changes: 22 additions & 1 deletion paddle/fluid/operators/activation_op.cc
@@ -890,12 +890,33 @@ It is recommended to use the defaults for this activation.
}
};

class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"Input of Relu operator, an N-D Tensor, with data type float32, float64 or float16.");
AddOutput("Out",
"Output of Relu operator, a Tensor with shape same as input.");
AddAttr<float>("safe",
"QuantSafe relu")
.SetDefault(0.);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
.AsExtra();
AddAttr<bool>("use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false)
.AsExtra();
AddComment(ReluDoc);
}
};

REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Silu, SiluDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Expm1, Expm1Doc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);

9 changes: 7 additions & 2 deletions paddle/fluid/operators/scaled_fc_op.cu
@@ -80,8 +80,13 @@ __global__ void kernel_cast_and_cut(const int N,
int col_idx = i % coln_ori;
int row_idx = i / coln_ori;
int idx = row_idx * coln_pad + col_idx;
matrix[i] = static_cast<T>(matrix_pad[idx]);
matrix[i] *= scale_factor;
T tmp = static_cast<T>(matrix_pad[idx]) * scale_factor;
// Some functions, such as fmax, replace inf with a normal float value, which prevents us from
// finding abnormal instances in the final output tensors. Replace inf with nan so the bad value propagates.
if (isinf(tmp)) {
tmp = std::numeric_limits<T>::quiet_NaN();
}
matrix[i] = tmp;
}
}
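For reference, the masking behaviour this comment guards against can be reproduced with numpy, whose fmax follows the same IEEE rule as the C/CUDA fmax (this is only an illustration of the rationale, not Paddle code):

    import numpy as np

    # fmax ignores a NaN operand, so a bad value silently disappears.
    print(np.fmax(np.nan, 0.0))   # 0.0
    # -inf is clamped to 0 just like any other negative value.
    print(np.fmax(-np.inf, 0.0))  # 0.0
    # Ordinary arithmetic keeps NaN alive, which is why inf is turned into NaN above.
    print(np.nan * 1.0)           # nan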

2 changes: 1 addition & 1 deletion paddle/phi/kernels/activation_kernel.h
@@ -51,7 +51,6 @@ DECLARE_ACTIVATION_KERNEL(Cosh)
DECLARE_ACTIVATION_KERNEL(Asinh)
DECLARE_ACTIVATION_KERNEL(Acosh)
DECLARE_ACTIVATION_KERNEL(Atanh)
DECLARE_ACTIVATION_KERNEL(Relu)
DECLARE_ACTIVATION_KERNEL(Tanh)
DECLARE_ACTIVATION_KERNEL(TanhShrink)
DECLARE_ACTIVATION_KERNEL(Silu)
@@ -73,6 +72,7 @@ DECLARE_ACTIVATION_KERNEL(Floor)
DECLARE_ACTIVATION_KERNEL(Ceil)
DECLARE_ACTIVATION_KERNEL(Negative)

DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu, safe)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu6, threshold)

2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_kernel.cc
@@ -70,7 +70,6 @@ DEFINE_CPU_ACTIVATION_KERNEL(Cosh, CoshFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Asinh, AsinhFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Acosh, AcoshFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Atanh, AtanhFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Relu, ReluCPUFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Tanh, TanhFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(TanhShrink, TanhShrinkFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Silu, SiluFunctor)
@@ -92,6 +91,7 @@ DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Negative, NegativeFunctor)

DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Relu, ReluCPUFunctor, safe)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
ThresholdedReluFunctor,

24 changes: 21 additions & 3 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -777,10 +777,18 @@ struct Expm1GradFunctor : public BaseActivationFunctor<T> {
// relu(x) = max(x, 0)
template <typename T>
struct ReluCPUFunctor : public BaseActivationFunctor<T> {
float safe_ {0.0};
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"safe", &safe_}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) {
return v > static_cast<T>(0) ? v : static_cast<T>(0);
out.device(d) = x.unaryExpr([this] HOSTDEVICE(T v) {
if (safe_) {
return v * (v > static_cast<T>(0) ? static_cast<T>(1) : static_cast<T>(0));
} else {
return v > static_cast<T>(0) ? v : static_cast<T>(0);
}
});
}
};
@@ -2094,11 +2102,21 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
// the attribute data type is limited to float, so "safe" doubles as a boolean flag
float safe_ {0.0};
T zero = static_cast<T>(0.0f);

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"safe", &safe_}};
}

// relu(x) = max(x, 0)
__device__ __forceinline__ T operator()(const T x) const {
return x > zero ? x : zero;
if (safe_) {
return x * (x > static_cast<T>(0) ? static_cast<T>(1) : static_cast<T>(0));
} else {
return x > zero ? x : zero;
}
}
};
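The behavioural difference between the default branch and the "safe" branch can be sketched in plain Python (illustration only; the real implementations are the C++/CUDA functors above):

    def relu_default(v):
        # max-style relu: comparisons with NaN are false, so NaN and -inf become 0
        return v if v > 0 else 0.0

    def relu_safe(v):
        # multiplication-based relu: NaN * 0 is NaN and (-inf) * 0 is NaN,
        # so abnormal inputs stay visible in the output
        return v * (1.0 if v > 0 else 0.0)

    nan, ninf = float('nan'), float('-inf')
    print(relu_default(nan), relu_safe(nan))    # 0.0 nan
    print(relu_default(ninf), relu_safe(ninf))  # 0.0 nan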

2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_kernel.cu
@@ -87,7 +87,6 @@ DEFINE_GPU_ACTIVATION_KERNEL(Cosh, CudaCoshFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Asinh, CudaAsinhFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Acosh, CudaAcoshFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Atanh, CudaAtanhFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Relu, CudaReluFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Tanh, CudaTanhFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(TanhShrink, CudaTanhShrinkFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Silu, CudaSiluFunctor)
@@ -108,6 +107,7 @@ DEFINE_GPU_ACTIVATION_KERNEL(Round, CudaRoundFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Ceil, CudaCeilFunctor)

DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Relu, CudaReluFunctor, safe)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
CudaThresholdedReluFunctor,

3 changes: 2 additions & 1 deletion paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
@@ -84,9 +84,10 @@ DEFINE_SPARSE_UNARY_KERNEL(Atanh)
DEFINE_SPARSE_UNARY_KERNEL(Sqrt)
DEFINE_SPARSE_UNARY_KERNEL(Square)
DEFINE_SPARSE_UNARY_KERNEL(Log1p)
DEFINE_SPARSE_UNARY_KERNEL(Relu)
DEFINE_SPARSE_UNARY_KERNEL(Abs)
DEFINE_SPARSE_UNARY_KERNEL(Expm1)

DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu, safe)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu6, threshold)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)

2 changes: 1 addition & 1 deletion paddle/phi/kernels/sparse/unary_kernel.h
@@ -49,12 +49,12 @@ DECLARE_SPARSE_UNARY_KERNEL(Atan)
DECLARE_SPARSE_UNARY_KERNEL(Sinh)
DECLARE_SPARSE_UNARY_KERNEL(Asinh)
DECLARE_SPARSE_UNARY_KERNEL(Atanh)
DECLARE_SPARSE_UNARY_KERNEL(Relu)
DECLARE_SPARSE_UNARY_KERNEL(Tanh)
DECLARE_SPARSE_UNARY_KERNEL(Square)
DECLARE_SPARSE_UNARY_KERNEL(Sqrt)
DECLARE_SPARSE_UNARY_KERNEL(Log1p)
DECLARE_SPARSE_UNARY_KERNEL(Abs)
DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu, safe)
DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)

template <typename T, typename Context>

20 changes: 1 addition & 19 deletions python/paddle/fluid/contrib/layers/nn.py
@@ -2621,25 +2621,7 @@ def scaled_fc(input,
attrs={
'input_scale_factor': input_scale_factor,
'bias_scale_factor': bias_scale_factor,
'grad_scale_factor': grad_scale_factor


















'grad_scale_factor': grad_scale_factor,
},
outputs={"Out": pre_act})

4 changes: 4 additions & 0 deletions python/paddle/fluid/dygraph/layer_object_helper.py
@@ -158,6 +158,10 @@ def append_activation(self, input_var, act=None, use_cudnn=None):
raise TypeError(
str(act) + " should be unicode or str in %s ", self.name)

# TODO: act should also support dict type here,
# see LayerHelper.append_activation in python/paddle/fluid/layer_helper.py

if (use_cudnn is not None) and use_cudnn:
act['use_cudnn'] = use_cudnn
use_mkldnn = _global_flags()["FLAGS_use_mkldnn"]

4 changes: 4 additions & 0 deletions python/paddle/fluid/layer_helper.py
@@ -147,6 +147,10 @@ def append_activation(self, input_var):
return input_var
if isinstance(act, six.string_types):
act = {'type': act}
elif isinstance(act, dict) and 'type' in act:
# act can be a dict to pass some attributes to activation op:
# e.g. fluid.layers.fc(..., act={'type': 'relu', 'safe': 1.})
act = copy.deepcopy(act)
else:
raise TypeError(str(act) + " should be unicode or str")
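For context, here is a rough sketch of how such a dict would be consumed further down in append_activation. The pop/attribute forwarding is an assumption based on the existing string-based path and is not part of this diff:

    import copy

    act = {'type': 'relu', 'safe': 1.}
    act = copy.deepcopy(act)    # keep the caller's dict untouched
    act_type = act.pop('type')  # 'relu' selects which activation op to append
    attrs = act                 # remaining keys, e.g. {'safe': 1.0}, become op attributes
    print(act_type, attrs)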

8 changes: 6 additions & 2 deletions python/paddle/fluid/layers/nn.py
@@ -9392,7 +9392,7 @@ def log(x, name=None):


@deprecated(since="2.0.0", update_to="paddle.nn.functional.relu")
def relu(x, name=None):
def relu(x, name=None, safe=0.):
"""
${comment}

@@ -9431,9 +9431,13 @@ def relu(x, name=None):
helper = LayerHelper('relu', **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
attrs = {}
if safe != 0.:
attrs['safe'] = safe
helper.append_op(type="relu",
inputs={"X": helper.input('x')},
outputs={"Out": out})
outputs={"Out": out},
attrs=attrs)
return out


Expand Down
7 changes: 5 additions & 2 deletions python/paddle/nn/functional/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None):
return out


def relu(x, name=None):
def relu(x, name=None, safe=0.):
"""
relu activation.
@@ -701,7 +701,10 @@ def relu(x, name=None):
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu')
helper = LayerHelper('relu', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(type='relu', inputs={'X': x}, outputs={'Out': out})
attrs = {}
if safe != 0.:
attrs['safe'] = safe
helper.append_op(type='relu', inputs={'X': x}, outputs={'Out': out}, attrs=attrs)
return out
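A small usage sketch for this functional entry point (static-graph mode, since only the static path is modified in this hunk; tensor name and shape are illustrative):

    import paddle
    import paddle.nn.functional as F

    paddle.enable_static()
    x = paddle.static.data(name='x', shape=[None, 32], dtype='float32')
    y = F.relu(x)                # default behaviour: NaN and -inf inputs become 0
    y_safe = F.relu(x, safe=1.)  # appends a relu op with attrs={'safe': 1.0}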

