[Relay, Topi, TF Frontend] Isfinite operator #4981

Merged: 25 commits, Mar 23, 2020
4 changes: 4 additions & 0 deletions docs/api/python/topi.rst
@@ -33,6 +33,8 @@ List of operators
   topi.round
   topi.abs
   topi.isnan
   topi.isfinite
   topi.isinf
   topi.exp
   topi.tanh
   topi.log
@@ -134,6 +136,8 @@ topi
.. autofunction:: topi.round
.. autofunction:: topi.abs
.. autofunction:: topi.isnan
.. autofunction:: topi.isfinite
.. autofunction:: topi.isinf
.. autofunction:: topi.exp
.. autofunction:: topi.tanh
.. autofunction:: topi.log
2 changes: 2 additions & 0 deletions docs/frontend/tensorflow.rst
@@ -160,6 +160,8 @@ Supported Ops
- Greater
- GreaterEqual
- Identity
- IsFinite
- IsInf
- LeakyRelu
- LeftShift
- Less
2 changes: 2 additions & 0 deletions include/tvm/tir/expr.h
@@ -832,6 +832,8 @@ class CallNode : public PrimExprNode {
  static constexpr const char* glsl_texture_store = "glsl_texture_store";
  static constexpr const char* prefetch = "prefetch";
  static constexpr const char* isnan = "isnan";

Contributor:
Is it worth adding an isinf operator to this PR? That way we can check for both halves of isfinite separately if needed. I'm not sure if you'd ever need to check for just infinity rather than both nan and infinity though.

Contributor (Author):
Added support for isinf as well.
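
For reference on the question above, NumPy's three predicates relate as follows (editorial illustration, not part of the diff; the ops added in this PR follow the same element-wise semantics):

import numpy as np

x = np.array([1.0, np.inf, -np.inf, np.nan])
print(np.isfinite(x))  # [ True False False False]
print(np.isinf(x))     # [False  True  True False]
print(np.isnan(x))     # [False False False  True]

That is, isfinite(x) holds exactly when neither isinf(x) nor isnan(x) does.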

  static constexpr const char* isfinite = "isfinite";
  static constexpr const char* isinf = "isinf";

  /*! \brief Vectorizable intrinsic list. */
  static const char* vectorizable_intrinsics[];
21 changes: 21 additions & 0 deletions include/tvm/tir/op.h
@@ -83,6 +83,13 @@ TVM_DLL PrimExpr max_value(const DataType& dtype);
*/
TVM_DLL PrimExpr min_value(const DataType& dtype);

/*!
* \brief Get the value of infinity.
* \param dtype The data type.
* \return the infinity value in this format.
*/
TVM_DLL PrimExpr infinity(const DataType& dtype);

/*!
* \brief cast value to type.
*
@@ -439,6 +446,20 @@ TVM_DLL PrimExpr abs(PrimExpr x);
*/
TVM_DLL PrimExpr isnan(PrimExpr x);

/*!
* \brief Check if x is finite.
* \param x The input data
* \return The result expression.
*/
TVM_DLL PrimExpr isfinite(PrimExpr x);

/*!
* \brief Check if x is infinite.
* \param x The input data
* \return The result expression.
*/
TVM_DLL PrimExpr isinf(PrimExpr x);

/*!
* \brief sum of of source expression over axis
* \param source The source expression.
2 changes: 2 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1667,6 +1667,8 @@ def _impl(inputs, attr, params):
    'Greater' : _broadcast('greater'),
    'GreaterEqual' : _broadcast('greater_equal'),
    'Identity' : _identity(),
    'IsFinite' : AttrCvt('isfinite'),
    'IsInf' : AttrCvt('isinf'),
    'LeakyRelu' : AttrCvt('leaky_relu'),
    'LeftShift' : AttrCvt('left_shift'),
    'Less' : _broadcast('less'),
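
A hedged sketch of exercising the new frontend mappings (editorial; assumes TF 1.x and a TVM build that includes this patch):

import tensorflow as tf
from tvm import relay

tf.reset_default_graph()
x = tf.placeholder(tf.float32, (4,), name="x")
tf.is_finite(x, name="out")
graph_def = tf.get_default_graph().as_graph_def()
mod, params = relay.frontend.from_tensorflow(graph_def, shape={"x": (4,)}, outputs=["out"])
print(mod)  # main should contain a call to isfinite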
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_tensor.py
@@ -66,6 +66,8 @@
register_broadcast_schedule("less_equal")
register_broadcast_schedule("greater")
register_broadcast_schedule("greater_equal")
register_broadcast_schedule("isfinite")
register_broadcast_schedule("isinf")
register_injective_schedule("maximum")
register_injective_schedule("minimum")
register_injective_schedule("right_shift")
32 changes: 32 additions & 0 deletions python/tvm/relay/op/tensor.py
@@ -1008,3 +1008,35 @@ def ndarray_size(data, dtype="int32"):
        The number of elements of input tensor.
    """
    return _make.ndarray_size(data, dtype)


def isfinite(data):
    """Compute element-wise finiteness of data.

    Parameters
    ----------
    data : relay.Expr
        The input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.isfinite(data)


def isinf(data):
    """Compute element-wise infiniteness of data.

    Parameters
    ----------
    data : relay.Expr
        The input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.isinf(data)
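
A minimal usage sketch for the new Relay ops (editorial, not part of the diff; assumes a TVM build containing this patch and the relay executor API of this era):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="float32")
func = relay.Function([x], relay.isfinite(x))
mod = tvm.IRModule.from_expr(func)

data = np.array([1.0, np.inf, -np.inf, np.nan], dtype="float32")
result = relay.create_executor("graph", mod=mod).evaluate()(data)
print(result)  # expected: [ True False False False]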
3 changes: 2 additions & 1 deletion python/tvm/te/__init__.py
@@ -20,7 +20,8 @@
# expose all operators in tvm tir.op
from tvm.tir import any, all, min_value, max_value, trace
from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
from tvm.tir import isnan, isfinite, isinf
from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from tvm.tir import comm_reducer, min, max, sum

3 changes: 2 additions & 1 deletion python/tvm/tir/__init__.py
@@ -38,7 +38,8 @@
from .op import exp, exp2, exp10, log, log2, log10
from .op import cos, sin, cosh, sinh, tan, tanh, atan
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum

32 changes: 32 additions & 0 deletions python/tvm/tir/op.py
@@ -706,6 +706,38 @@ def isnan(x):
    return _ffi_api.isnan(x)


def isfinite(x):
    """Check if input value is finite.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    return _ffi_api.isfinite(x)


def isinf(x):
    """Check if input value is infinite.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    return _ffi_api.isinf(x)


def power(x, y):
    """x power y

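
Since the functions above expose isfinite/isinf at the TIR level (and te re-exports them), a small te-based sketch may help (editorial; assumes an LLVM-enabled TVM build with this patch):

import numpy as np
import tvm
from tvm import te

A = te.placeholder((4,), name="A", dtype="float32")
B = te.compute(A.shape, lambda i: te.isfinite(A[i]), name="B")
s = te.create_schedule(B.op)
f = tvm.build(s, [A, B], target="llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.array([1.0, np.inf, -np.inf, np.nan], dtype="float32"), ctx)
b = tvm.nd.array(np.zeros(4, dtype="bool"), ctx)
f(a, b)
print(b.asnumpy())  # expected: [ True False False False]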
18 changes: 18 additions & 0 deletions src/relay/op/tensor/unary.cc
@@ -415,5 +415,23 @@ ElemwiseArbitraryLayout)
.set_support_level(10)
.set_attr<FTVMCompute>("FTVMCompute", NdarraySizeCompute);

RELAY_REGISTER_UNARY_OP("isfinite")
.describe(R"code(Returns the finiteness of input, computed element-wise.
.. math::
   isfinite(x)
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("IdentityCompRel", IdentityCompRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isfinite));

RELAY_REGISTER_UNARY_OP("isinf")
.describe(R"code(Returns the infiniteness of input, computed element-wise.
.. math::
   isinf(x)
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("IdentityCompRel", IdentityCompRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isinf));

} // namespace relay
} // namespace tvm
12 changes: 12 additions & 0 deletions src/relay/op/type_relations.cc
@@ -136,6 +136,18 @@ bool BroadcastCompRel(const Array<Type>& types,
  return false;
}

// Type relation shared by isfinite/isinf: the output keeps the input shape
// but always has a boolean element type.
bool IdentityCompRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  if (auto* t0 = types[0].as<TensorTypeNode>()) {
    Type out_type = TensorType(GetRef<TensorType>(t0)->shape, DataType::Bool());
    reporter->Assign(types[1], out_type);
    return true;
  }
  return false;
}

Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
  if (shape.size() == 0) {
    return {};
5 changes: 5 additions & 0 deletions src/relay/op/type_relations.h
@@ -79,6 +79,11 @@ bool BroadcastCompRel(const Array<Type>& types,
                      const Attrs& attrs,
                      const TypeReporter& reporter);

bool IdentityCompRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter);

Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);

} // namespace relay
16 changes: 16 additions & 0 deletions src/target/intrin_rule.cc
@@ -78,6 +78,22 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
    *rv = one / (one + exp(-call->args[0]));
  });

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isfinite")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
*rv = isfinite(call->args[0]);
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
*rv = isinf(call->args[0]);
});

} // namespace intrin
} // namespace codegen
} // namespace tvm
36 changes: 36 additions & 0 deletions src/tir/ir/op.cc
@@ -180,6 +180,21 @@ PrimExpr min_value(const DataType& dtype) {
  return PrimExpr();
}

// infinity
PrimExpr infinity(const DataType& dtype) {
  using namespace tir;
  CHECK_EQ(dtype.lanes(), 1);
  if (dtype.is_float()) {
    if (dtype.bits() == 64) {
      return FloatImm(dtype, std::numeric_limits<double>::infinity());
    } else if (dtype.bits() == 32 || dtype.bits() == 16) {
      return FloatImm(dtype, std::numeric_limits<float>::infinity());
    }
  }
  LOG(FATAL) << "Cannot decide infinity for type " << dtype;
  return PrimExpr();
}

namespace tir {
template<typename ValueType>
inline bool ConstPowerHelper(ValueType val, int *shift) {
@@ -575,6 +590,21 @@ PrimExpr isnan(PrimExpr x) {
  }
}

PrimExpr isinf(PrimExpr x) {
  DataType t = DataType::Bool(x.dtype().lanes());
  if (x.dtype().is_int() || x.dtype().is_uint()) {
    return make_const(t, false);
  } else if (x.dtype().is_float()) {
    PrimExpr infX = infinity(x.dtype());
    return abs(x) == infX && !isnan(x);
  } else {
    LOG(FATAL) << "Data type " << x.dtype() << " not supported for finiteness ops";
    return x;
  }
}

PrimExpr isfinite(PrimExpr x) { return !isinf(x) && !isnan(x); }

PrimExpr sum(PrimExpr source, Array<IterVar> rdom) {
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::AddNode::make(x, y);
@@ -721,6 +751,12 @@ TVM_REGISTER_GLOBAL("tir.abs")
TVM_REGISTER_GLOBAL("tir.isnan")
.set_body_typed(tvm::isnan);

TVM_REGISTER_GLOBAL("tir.isfinite")
.set_body_typed(tvm::isfinite);

TVM_REGISTER_GLOBAL("tir.isinf")
.set_body_typed(tvm::isinf);

TVM_REGISTER_GLOBAL("tir.floor")
.set_body_typed(tvm::floor);

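
One editorial note on the lowering above: isinf is computed as abs(x) == infinity(dtype) && !isnan(x), and isfinite as !isinf(x) && !isnan(x). Because IEEE-754 comparisons involving NaN always evaluate to false, the !isnan(x) term inside isinf is strictly redundant (abs(NaN) == inf is already false), though harmless; the !isnan(x) in isfinite is what actually rejects NaN. The registered globals make both reachable from Python, e.g. (hedged sketch, assuming the patched build):

import tvm

x = tvm.tir.Var("x", "float32")
print(tvm.tir.isinf(x))     # prints the lowered boolean expression
print(tvm.tir.isfinite(x))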
34 changes: 33 additions & 1 deletion tests/python/frontend/tensorflow/test_forward.py
@@ -3152,7 +3152,37 @@ def test_forward_dilation():
    _test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 2, 2, 1], "SAME")
    _test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 2, 1], "VALID")

# #######################################################################

#######################################################################
# infinity ops
# ------------
def _verify_infiniteness_ops(tf_op, name):
    """Test TensorFlow infinity/NaN checking ops."""

    # Only float types are allowed in TensorFlow for is_finite and is_inf;
    # float16 is failing on CUDA.
    tf_dtypes = ["float32", "float64"]
    for tf_dtype in tf_dtypes:
        shape = (8, 8)
        data = np.random.uniform(size=shape).astype(tf_dtype)
        data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.infty
        data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.nan

        tf.reset_default_graph()
        in_data = tf.placeholder(tf_dtype, shape, name="in_data")
        tf_op(in_data, name=name)
        compare_tf_with_tvm([data], ['in_data:0'], '{}:0'.format(name))


def test_forward_isinf():
    _verify_infiniteness_ops(tf.is_inf, "isinf")


def test_forward_isfinite():
    _verify_infiniteness_ops(tf.is_finite, "isfinite")


#######################################################################
# Main
# ----
if __name__ == '__main__':
@@ -3224,6 +3254,8 @@ def test_forward_dilation():
    test_forward_squared_difference()
    test_forward_add_n()
    test_forward_floormod()
    test_forward_isfinite()
    test_forward_isinf()
    test_forward_unravel_index()

    # Reductions