diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 131d504795fb4..d7f4971724711 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -605,7 +605,6 @@ void DiagonalInferMeta(const MetaTensor& input, int offset_ = offset; int axis1_ = axis1 < 0 ? x_dims.size() + axis1 : axis1; int axis2_ = axis2 < 0 ? x_dims.size() + axis2 : axis2; - PADDLE_ENFORCE_GE( x_dims.size(), 2, @@ -621,6 +620,15 @@ void DiagonalInferMeta(const MetaTensor& input, -(x_dims.size()), (x_dims.size() - 1), axis1)); + PADDLE_ENFORCE_GE( + axis1_, + 0, + phi::errors::OutOfRange( + "Attr(axis1) is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size()), + (x_dims.size() - 1), + axis1)); PADDLE_ENFORCE_LT( axis2_, x_dims.size(), @@ -630,6 +638,15 @@ void DiagonalInferMeta(const MetaTensor& input, -(x_dims.size()), (x_dims.size() - 1), axis2)); + PADDLE_ENFORCE_GE( + axis2_, + 0, + phi::errors::OutOfRange( + "Attr(axis2) is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size()), + (x_dims.size() - 1), + axis2)); PADDLE_ENFORCE_NE( axis1_, axis2_,