[Relay/TOPI][Op] Add erf intrinsic and op (apache#3702)
* add more ops

* stop vectorization for erf

* x

* cleanup

* fix

* add whitelist for vectorizable intrin

* add tf converter

* fix dense

* fix

* add missing intrin

* fix mxnet frontend

* fix nvptx
icemelon authored and wweic committed Sep 16, 2019
1 parent 9798b95 commit 2d58005
Showing 23 changed files with 233 additions and 83 deletions.
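
For reference, the function this commit adds is the Gauss error function, erf(x) = 2/sqrt(pi) * integral from 0 to x of exp(-t^2) dt. A quick sanity check against scipy.special.erf, which the new Relay test below uses as ground truth (a minimal sketch; only numpy and scipy are assumed):

```python
import numpy as np
from scipy import special

# erf is odd and saturates to +/-1 as |x| grows.
x = np.array([-1.0, 0.0, 0.5, 2.0], dtype="float32")
print(special.erf(x))  # approx [-0.8427, 0.0, 0.5205, 0.9953]
```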
1 change: 1 addition & 0 deletions include/tvm/expr_operator.h
@@ -512,6 +512,7 @@ TVM_DLL Expr trunc(Expr x);
} \

TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(erf);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
6 changes: 6 additions & 0 deletions include/tvm/ir.h
@@ -556,6 +556,9 @@ class Call : public ExprNode {
name == intrin_name);
}

/*! \return Whether call node can be vectorized. */
bool is_vectorizable() const;

static constexpr const char* _type_key = "Call";
TVM_DECLARE_NODE_TYPE_INFO(Call, ExprNode);

@@ -571,6 +574,9 @@
static constexpr const char* likely = "likely";
static constexpr const char* glsl_texture_store = "glsl_texture_store";
static constexpr const char* prefetch = "prefetch";

/*! \brief Vectorizable intrinsic list. */
static const char* vectorizable_intrinsics[];
};

/*!
16 changes: 16 additions & 0 deletions python/tvm/intrin.py
@@ -211,6 +211,22 @@ def exp(x):
return call_pure_intrin(x.dtype, "exp", x)


def erf(x):
"""Take gauss error function of the input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "erf", x)


def tanh(x):
"""Take hyperbolic tanh of input x.
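
With the new tvm.erf wrapper, the intrinsic can be used like tvm.exp inside a compute definition. A minimal sketch against the TVM 0.6-era API used throughout this commit (target and shapes are illustrative):

```python
import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A", dtype="float32")
B = tvm.compute(A.shape, lambda i: tvm.erf(A[i]), name="B")
s = tvm.create_schedule(B.op)
# The default intrin rule registered below dispatches erf to the
# float-suffixed libm call (erff for float32).
f = tvm.build(s, [A, B], target="llvm")
```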
14 changes: 7 additions & 7 deletions python/tvm/relay/backend/interpreter.py
@@ -170,8 +170,8 @@ def _convert_args(self, expr, args, kwargs):
return args

if kwargs and not isinstance(expr, Function):
raise Exception("can only supply keyword parameters for a \
relay.Function, found {0}".format(expr))
raise Exception("can only supply keyword parameters for a "
"relay.Function, found {0}".format(expr))

params = expr.params
param_names = [p.name_hint for p in params]
@@ -182,16 +182,16 @@ def _convert_args(self, expr, args, kwargs):
if i < num_of_args:
if kwargs.get(name):
raise Exception(
"duplicate argument supplied in \
both positional args (at position: {0}), \
and keyword argument (with name: {1})".format(i, name))
"duplicate argument supplied in "
"both positional args (at position: {0}), "
"and keyword argument (with name: {1})".format(i, name))
else:
cargs.append(kwargs[name])

if len(cargs) != len(params):
raise Exception(
"insufficient arguments, expected" \
" {0}, provided {1}".format(len(cargs), len(params)))
"insufficient arguments, expected "
"{0}, provided {1}".format(len(cargs), len(params)))

return tuple(cargs)

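
The interpreter change is purely about message formatting: a backslash line continuation inside a string literal keeps the source indentation in the message, while adjacent literals concatenate cleanly. A small illustration (hypothetical strings, same pattern as the diff):

```python
# Backslash continuation embeds the next line's leading spaces in the string.
msg_bad = "can only supply keyword parameters for a \
        relay.Function"
# Adjacent string literals concatenate with no stray whitespace.
msg_good = ("can only supply keyword parameters for a "
            "relay.Function")
assert "  " in msg_bad and "  " not in msg_good
```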
11 changes: 10 additions & 1 deletion python/tvm/relay/frontend/common.py
@@ -124,7 +124,16 @@ def get_int_tuple(self, key, default=RequiredAttr()):
"""
if key in self.attrs:
tshape = self.attrs[key]
return tuple(int(x.strip()) for x in tshape.strip('()[]').split(',') if x)
ret = []
for x in tshape.strip('()[]').split(','):
x = x.strip()
if not x:
continue
if x == "None":
ret.append(None)
else:
ret.append(int(x))
return tuple(ret)
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
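
A standalone sketch of the new get_int_tuple behavior: "None" entries in an MXNet shape-like attribute string now become Python None instead of crashing int() (hypothetical helper name; the logic mirrors the diff):

```python
def parse_int_tuple(tshape):
    """Parse "(None, 0, 4)" into (None, 0, 4)."""
    ret = []
    for x in tshape.strip('()[]').split(','):
        x = x.strip()
        if not x:
            continue
        ret.append(None if x == "None" else int(x))
    return tuple(ret)

assert parse_int_tuple("(None, 0, 4)") == (None, 0, 4)
assert parse_int_tuple("[2, 3]") == (2, 3)
```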
26 changes: 18 additions & 8 deletions python/tvm/relay/frontend/mxnet.py
@@ -55,10 +55,17 @@ def _mx_fully_connected(inputs, attrs):
use_flatten = attrs.get_bool("flatten", True)
if has_flatten and use_flatten:
inputs[0] = _op.nn.batch_flatten(inputs[0])
data_shape = _infer_type(inputs[0]).checked_type.shape
if len(data_shape) > 2:
inputs[0] = _op.reverse_reshape(inputs[0], [-1, 0])
res = _op.nn.dense(inputs[0], inputs[1], units=units)
if use_bias:
assert len(inputs) == 3
res = _op.nn.bias_add(res, inputs[2], axis=-1)
if len(data_shape) > 2:
new_shape = data_shape[:-1]
new_shape.append(units)
res = _op.reshape(res, new_shape)
return res


@@ -241,8 +248,8 @@ def _mx_layer_norm(inputs, attrs):

def _mx_slice(inputs, attrs):
new_attrs = {}
begin = attrs.get_int_tuple('begin', None)
end = attrs.get_int_tuple('end', None)
begin = list(attrs.get_int_tuple('begin', None))
end = list(attrs.get_int_tuple('end', None))
stride = attrs.get_int_tuple('step', None)
if begin is None:
raise tvm.error.OpAttributeRequired(
@@ -251,11 +258,12 @@ def _mx_slice(inputs, attrs):
raise tvm.error.OpAttributeRequired(
'Attribute "end" not found in operator Slice.')
if None in begin:
raise tvm.error.OpAttributeInvalid(
'Value None in attribute "begin" of operator Slice is not valid.')
if None in end:
raise tvm.error.OpAttributeInvalid(
'Value None in attribute "end" of operator Slice is not valid.')
data_shape = _infer_type(inputs[0]).checked_type.shape
for i, beg in enumerate(begin):
if beg is None:
assert end[i] is None
begin[i] = 0
end[i] = data_shape[i]
new_attrs = {'begin': begin, 'end': end}
if stride is not None:
new_attrs['strides'] = stride
@@ -497,7 +505,8 @@ def _mx_arange(inputs, attrs):
'Attribute "repeat" is not supported in operator arange.')
new_attrs = {}
new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
new_attrs["stop"] = _expr.const(attrs.get_float("stop"))
stop = attrs.get_str("stop", "None")
new_attrs["stop"] = None if stop == "None" else _expr.const(float(stop))
new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.arange(**new_attrs)
@@ -910,6 +919,7 @@ def _mx_one_hot(inputs, attrs):
_identity_list = [
"log",
"exp",
"erf",
"sqrt",
"floor",
"ceil",
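
The Slice fix above turns None bounds into explicit defaults instead of rejecting them: a None begin becomes 0 and the matching None end becomes that axis extent. A plain-Python sketch of the bookkeeping (the real code works on inferred IR shapes, not ints):

```python
def fill_slice_defaults(begin, end, data_shape):
    begin, end = list(begin), list(end)
    for i, b in enumerate(begin):
        if b is None:
            assert end[i] is None  # the diff assumes None comes in matching pairs
            begin[i] = 0
            end[i] = data_shape[i]
    return begin, end

assert fill_slice_defaults([None, 1], [None, 3], [4, 5]) == ([0, 1], [4, 3])
```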
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1261,6 +1261,7 @@ def _impl(inputs, attr, params):
'DepthToSpace' : _depth_to_space(),
'Equal' : _broadcast('equal'),
'Elu' : _elu(),
'Erf' : AttrCvt('erf'),
'Exp' : AttrCvt('exp'),
'ExpandDims' : _expand_dims(),
'Fill' : _fill(),
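
With the 'Erf' mapping in place, a TensorFlow graph containing tf.math.erf imports directly. A hedged end-to-end sketch (TF 1.x and the TVM 0.6-era frontend API assumed):

```python
import tensorflow as tf
from tvm import relay

with tf.Graph().as_default() as g:
    x = tf.placeholder(tf.float32, shape=(4,), name="x")
    y = tf.math.erf(x, name="y")

mod, params = relay.frontend.from_tensorflow(g.as_graph_def(), shape={"x": (4,)})
print(mod)  # the main function body contains a call to erf
```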
1 change: 1 addition & 0 deletions python/tvm/relay/op/_tensor.py
@@ -30,6 +30,7 @@
register_schedule("cos", schedule_broadcast)
register_schedule("sin", schedule_broadcast)
register_schedule("exp", schedule_broadcast)
register_schedule("erf", schedule_broadcast)
register_schedule("sqrt", schedule_broadcast)
register_schedule("rsqrt", schedule_broadcast)
register_schedule("sigmoid", schedule_broadcast)
16 changes: 16 additions & 0 deletions python/tvm/relay/op/tensor.py
@@ -92,6 +92,22 @@ def exp(data):
return _make.exp(data)


def erf(data):
"""Compute elementwise error function of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.erf(data)


def sqrt(data):
"""Compute elementwise sqrt of data.
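
Relay-level usage of the new op, as a minimal sketch using the interpreter-backed executor (API as of this commit; shapes are illustrative):

```python
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(3,), dtype="float32")
func = relay.Function([x], relay.erf(x))
ex = relay.create_executor(ctx=tvm.cpu(0), target="llvm")
print(ex.evaluate(func)(np.array([0.0, 0.5, 1.0], dtype="float32")))
# approx [0.0, 0.5205, 0.8427]
```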
3 changes: 3 additions & 0 deletions src/codegen/intrin_rule.cc
@@ -31,6 +31,9 @@ namespace intrin {
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
.set_body(DispatchExtern<FloatSuffix>);

3 changes: 3 additions & 0 deletions src/codegen/intrin_rule_cuda.cc
@@ -92,6 +92,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.set_body(DispatchExtern<CUDAFastMath>);

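
Note the dispatch class: erf goes through CUDAMath (plain erff from the CUDA math library) rather than CUDAFastMath, since CUDA has no fast-math __erff intrinsic. A hedged sketch showing the generated kernel calling erff (thread binding factors are illustrative):

```python
import tvm

n = 1024
A = tvm.placeholder((n,), name="A", dtype="float32")
B = tvm.compute(A.shape, lambda i: tvm.erf(A[i]), name="B")
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B], target="cuda")
print(f.imported_modules[0].get_source())  # kernel body calls erff(...)
```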
3 changes: 3 additions & 0 deletions src/codegen/llvm/intrin_rule_nvptx.cc
@@ -64,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma")
.set_body(DispatchExternLibDevice);

16 changes: 16 additions & 0 deletions src/lang/ir.cc
@@ -176,6 +176,22 @@ Expr Let::make(Var var, Expr value, Expr body) {
return Expr(node);
}

const char* Call::vectorizable_intrinsics[] = {
"floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt",
"log", "sin", "cos", "pow", ir::Call::shift_left, ir::Call::shift_right,
ir::Call::likely, ir::Call::popcount
};

bool Call::is_vectorizable() const {
size_t cnt = sizeof(Call::vectorizable_intrinsics) / sizeof(char*);
for (size_t i = 0; i < cnt; ++i) {
if (name == Call::vectorizable_intrinsics[i]) {
return true;
}
}
return false;
}

Expr Call::make(DataType type,
std::string name,
Array<Expr> args,
36 changes: 27 additions & 9 deletions src/pass/vectorize_loop.cc
@@ -268,16 +268,34 @@ class Vectorizer : public IRMutator {
if (op->name == intrinsic::tvm_if_then_else) {
return MutateIfThenElseExpr_(op, e);
}
int lane = 0;
Array<Expr> new_args = MutateArray(op->args, &lane);

// normal code path.
if (op->args.same_as(new_args)) {
return e;
if (!op->is_vectorizable()) {
// Cannot vectorize this op
Array<Expr> new_args;
for (auto arg : op->args) {
auto new_arg = this->Mutate(arg);
if (new_arg.type().is_vector()) {
need_scalarize_ = true;
return e;
}
new_args.push_back(new_arg);
}
if (op->args.same_as(new_args)) {
return e;
} else {
return Call::make(
op->type, op->name, new_args, op->call_type, op->func, op->value_index);
}
} else {
return Call::make(
op->type.with_lanes(lane), op->name, new_args,
op->call_type, op->func, op->value_index);
int lane = 0;
Array<Expr> new_args = MutateArray(op->args, &lane);
// normal code path.
if (op->args.same_as(new_args)) {
return e;
} else {
return Call::make(
op->type.with_lanes(lane), op->name, new_args,
op->call_type, op->func, op->value_index);
}
}
}
// Load
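
The effect of the whitelist plus the new scalarization path: asking to vectorize a loop whose body calls a non-whitelisted intrinsic such as erf no longer produces an invalid vector call; the loop is scalarized instead. A small sketch to observe this (llvm target assumed):

```python
import tvm

n = 8
A = tvm.placeholder((n,), name="A", dtype="float32")
B = tvm.compute(A.shape, lambda i: tvm.erf(A[i]), name="B")
s = tvm.create_schedule(B.op)
s[B].vectorize(B.op.axis[0])
# The lowered IR keeps a serial loop over i because "erf" is not in
# Call::vectorizable_intrinsics.
print(tvm.lower(s, [A, B], simple_mode=True))
```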
12 changes: 12 additions & 0 deletions src/relay/op/tensor/unary.cc
@@ -85,6 +85,18 @@ RELAY_REGISTER_UNARY_OP("exp")
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));


RELAY_REGISTER_UNARY_OP("erf")
.describe(R"code(Returns the error function value for input array, computed element-wise.
.. math::
\operatorname{erf}(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf));


RELAY_REGISTER_UNARY_OP("sqrt")
.describe(R"code(Returns the sqrt input array, computed element-wise.
9 changes: 9 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -1844,6 +1844,14 @@ def test_forward_zeros_like():
_test_forward_zeros_like((2, 3, 11), "float32")
_test_forward_zeros_like((2, 3, 11), "float64")

def test_forward_erf():
ishape = (1, 3, 10, 10)
inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
tf.math.erf(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Erf:0')

def _test_forward_reverse_v2(in_shape, axis, dtype):
np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
tf.reset_default_graph()
@@ -2244,6 +2252,7 @@ def test_forward_one_hot():
test_forward_log_softmax()
test_forward_bias_add()
test_forward_zeros_like()
test_forward_erf()

# Reductions
test_forward_argminmax()
2 changes: 2 additions & 0 deletions tests/python/relay/test_op_level1.py
@@ -16,6 +16,7 @@
# under the License.
import numpy as np
import tvm
import scipy
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import ctx_list
@@ -67,6 +68,7 @@ def check_single_op(opfunc, ref):

for opfunc, ref in [(tvm.relay.log, np.log),
(tvm.relay.exp, np.exp),
(tvm.relay.erf, scipy.special.erf),
(tvm.relay.sqrt, np.sqrt),
(tvm.relay.rsqrt, rsqrt),
(tvm.relay.sigmoid, sigmoid),
1 change: 1 addition & 0 deletions topi/include/topi/elemwise.h
@@ -46,6 +46,7 @@ using namespace tvm;
}

TOPI_DECLARE_UNARY_OP(exp);
TOPI_DECLARE_UNARY_OP(erf);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log);
17 changes: 17 additions & 0 deletions topi/python/topi/math.py
@@ -74,6 +74,23 @@ def exp(x):
return tvm.compute(x.shape, lambda *i: tvm.exp(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def erf(x):
"""Take gauss error function of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.erf(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def tanh(x):
"""Take hyperbolic tanh of input x.
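
TOPI-level usage mirrors topi.exp (a minimal sketch; shapes and target are illustrative):

```python
import tvm
import topi

A = tvm.placeholder((4, 4), name="A", dtype="float32")
B = topi.erf(A)  # elementwise compute tagged tag.ELEMWISE
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], target="llvm")
```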