
[LANG] Support for Tuple Inputs of Reducer and ComputeOp #175

Merged, 16 commits, Jun 11, 2017
37 changes: 23 additions & 14 deletions include/tvm/ir.h
@@ -47,23 +47,27 @@ struct CommReducer : public NodeRef {
* binary operator with identity element
*/
struct CommReducerNode : public Node {
/*! \brief The arguments of reducer */
Array<Var> args;
/*! \brief The left argument of reducer */
Array<Var> lhs;
/*! \brief The right argument of reducer */
Array<Var> rhs;
/*! \brief The result of reducer */
Expr result;
Array<Expr> result;
/*!
* \brief The identity element of reducer, which leaves other
* elements unchanged when combined with it, with respect to
* the binary operation this reducer uses.
*/
Expr identity_element;
Array<Expr> identity_element;
/*! \brief Function call operator to combine a and b */
Expr operator()(Expr a, Expr b) const;
Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
/*! \brief construct CommReducer from args, result and identity_element */
static CommReducer make(Array<Var> args, Expr result, Expr identity_element);
static CommReducer make(Array<Var> lhs, Array<Var> rhs,
Array<Expr> result, Array<Expr> identity_element);

void VisitAttrs(AttrVisitor* v) final {
v->Visit("args", &args);
v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs);
v->Visit("result", &result);
v->Visit("identity_element", &identity_element);
}
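
The switch from a single args/result/identity_element to parallel lhs, rhs, result, and identity_element arrays is what lets one reducer combine tuples of values element-wise. A minimal sketch of how this surfaces in the Python frontend — an argmax reducer threading an (index, value) pair through the reduction — assuming tvm.comm_reducer is the public entry point that wraps the fcombine/fidentity pair handled in python/tvm/api.py below:

import tvm

# combine two (index, value) pairs, keeping the pair with the larger value
def fcombine(x, y):
    idx = tvm.select(x[1] >= y[1], x[0], y[0])
    val = tvm.select(x[1] >= y[1], x[1], y[1])
    return idx, val

# identity element: an invalid index paired with the smallest representable value
def fidentity(t0, t1):
    return tvm.const(-1, t0), tvm.min_value(t1)

argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')

m, n = tvm.var('m'), tvm.var('n')
idx = tvm.placeholder((m, n), name='idx', dtype='int32')
val = tvm.placeholder((m, n), name='val', dtype='float32')
k = tvm.reduce_axis((0, n), name='k')
# two reduction sources in, two output tensors out
T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')
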
@@ -84,26 +88,30 @@ struct Reduce : public ExprNode<Reduce> {
/*! \brief The commutative combiner */
CommReducer combiner;
/*! \brief The source operand */
Expr source;
Array<Expr> source;
/*! \brief The reduction axis */
Array<IterVar> axis;
/*!
* \brief Predicate on the reduction
* Only add the body to reduction if condition is true.
*/
Expr condition;
/*! \brief the index of this reduce node */
int value_index;

/*! \brief construct expr from op and rdom */
static Expr make(CommReducer combiner,
Expr src,
Array<Expr> src,
Array<IterVar> rdom,
Expr condition = const_true());
Expr condition,
int value_index);

void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("source", &source);
v->Visit("axis", &axis);
v->Visit("condition", &condition);
v->Visit("value_index", &value_index);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce";
@@ -292,11 +300,12 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
/*!
* \brief See pseudo code
*
* Expr tvm_thread_allreduce(CommReducer combiner, Expr value, Expr cond,
* Var thread_idx1, thread_idx2...) {
* void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
* Var reduce_temp0, .., Var thread_idx1, ...) {
* // constrained such that the other thread_idx values remain the same.
* return reduce(combiner, value, cond,
* over [thread_idx1, thread_idx2] passed by any caller)
* // reduce_temp is used to save intermediate result.
* reduce_temp0, ... = reduce(combiner, source0, ..., cond
* over [thread_idx1, thread_idx2] passed by any caller)
* }
*/
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
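
A cross-thread reduction schedule is what lowers to this intrinsic. A hedged sketch of one way to trigger it (using the standard GPU thread-binding API; none of these schedule calls appear in this diff):

import tvm

n = tvm.var('n')
A = tvm.placeholder((n, 32), name='A')
k = tvm.reduce_axis((0, 32), name='k')
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')

s = tvm.create_schedule(B.op)
s[B].bind(B.op.axis[0], tvm.thread_axis("blockIdx.x"))
# binding the reduction axis across threads makes lowering emit tvm_thread_allreduce
s[B].bind(k, tvm.thread_axis("threadIdx.x"))
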
4 changes: 2 additions & 2 deletions include/tvm/ir_pass.h
@@ -96,10 +96,10 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
/*!
* \brief inline all calls of f in stmt.
*
* \param stmt The statement to apply inline optimization.
* \param f The function reference to be inlined
* \param args The arguments variable of the function.
* \param body The defintion body of the function.
* \param stmt The statement to apply inline optimization.
* \param body The definition body of the function.
* \return The result stmt
*
* \note All the passes in this file uses SSA form and outputs SSA form.
16 changes: 14 additions & 2 deletions include/tvm/operation.h
@@ -182,7 +182,7 @@ class ComputeOpNode : public OperationNode {
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
Array<IterVar> reduce_axis;
/*! \brief the compute expression */
Expr body;
Array<Expr> body;
/*! \brief constructor */
ComputeOpNode() {}
// override functions
@@ -218,7 +218,7 @@ }
}
static Operation make(std::string name,
Array<IterVar> axis,
Expr body);
Array<Expr> body);

static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
@@ -358,6 +358,9 @@ class ExternOpNode : public OperationNode {
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;

/*! \brief The compute function to specify the input source of multiple Tensors */
using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;

/*!
* \brief create a place holder tensor.
* \param shape The shape of the tensor.
@@ -377,6 +380,15 @@ Tensor placeholder(Array<Expr> shape,
*/
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");

/*!
* \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[i][axis] = fcompute(axis)[i]
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensors.
* \param name The optional name of the tensor.
*/
Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name = "tensor");

/*!
* \brief Construct new tensors by scan.
*
8 changes: 4 additions & 4 deletions include/tvm/schedule.h
@@ -252,15 +252,15 @@ class Schedule : public NodeRef {
/*!
* \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
* This will create a new stage that generated the new tensor with axis
* as the first dimension. The tensor's body wil be rewriten as a reduction
* as the first dimension. The tensor's body will be rewritten as a reduction
* over the factored tensor.
*
* \param tensor The tensor to be factored.
* \param axis The reduction axis in tensor's schedule to be factored.
* \return The created factored tensor.
* \return The created factored tensors.
*/
Tensor rfactor(const Tensor& tensor,
const IterVar& axis);
Array<Tensor> rfactor(const Tensor& tensor,
const IterVar& axis);
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
51 changes: 41 additions & 10 deletions python/tvm/api.py
@@ -174,10 +174,14 @@ def compute(shape, fcompute, name="compute"):

dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
body = fcompute(*[v.var for v in dim_var])
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, dim_var, body)
return op_node.output(0)
num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
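
With fcompute now allowed to return a list or tuple, a single op node can carry several output tensors. A small usage sketch (tensor names hypothetical):

import tvm

m = tvm.var('m')
A = tvm.placeholder((m,), name='A')
# one ComputeOp with two outputs sharing the same iteration domain
B0, B1 = tvm.compute((m,), lambda i: (A[i] + 1, A[i] * 2), name='B')
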


def scan(init, update, state_placeholder, inputs=None, name="scan"):
@@ -525,18 +529,45 @@ def _reduce_directly(*args):
return res

def _make_reduce(expr, axis, where=None):
expr = convert(expr)
dtype = expr.dtype
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
arg_vars = [var(name, dtype) for name in code.co_varnames]
result = fcombine(*[v for v in arg_vars])
expr = convert(expr)
if isinstance(expr, _collections.Array):
size = len(expr)
larr = []
rarr = []
dtypes = []
for i in range(size):
dtype = expr[i].dtype
dtypes.append(dtype)
lname = code.co_varnames[0] + '_' + str(i)
larr.append(var(lname, dtype))
rname = code.co_varnames[1] + '_' + str(i)
rarr.append(var(rname, dtype))
lhs = convert(larr)
rhs = convert(rarr)
result = fcombine(lhs, rhs)
id_elem = fidentity(*dtypes)
else:
assert isinstance(expr, _expr.Expr)
size = 1
dtype = expr.dtype
lvar = var(code.co_varnames[0], dtype)
rvar = var(code.co_varnames[1], dtype)
result = [fcombine(lvar, rvar)]
id_elem = [fidentity(dtype)]
lhs = convert([lvar])
rhs = convert([rvar])
expr = convert([expr])
result = convert(result)
id_elem = fidentity(dtype)
assert isinstance(id_elem, _expr.Expr)
combiner = _make.CommReducer(arg_vars, result, id_elem)
axis = axis if isinstance(axis, list) else [axis]
return _make.Reduce(combiner, expr, axis, where)
id_elem = convert(id_elem)
combiner = _make.CommReducer(lhs, rhs, result, id_elem)
axis = convert(axis if isinstance(axis, list) else [axis])
if where is None:
where = convert(True)
outputs = tuple(_make.Reduce(combiner, expr, axis, where, i)
for i in range(size))
return outputs[0] if size == 1 else outputs
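
Both branches above build the same array-based CommReducer and Reduce nodes; the scalar branch just wraps single expressions into one-element arrays. A sketch of the scalar path (assuming tvm.comm_reducer is the public wrapper around this helper):

import tvm

n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n), name='k')
# scalar fcombine/fidentity: internally wrapped into one-element lhs/rhs/result arrays
product = tvm.comm_reducer(lambda x, y: x * y, lambda t: tvm.const(1, t), name='product')
B = tvm.compute((1,), lambda i: product(A[k], axis=k), name='B')
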

def reducer(expr, axis, where=None, *args):
if isinstance(axis, (_schedule.IterVar, list)):
7 changes: 4 additions & 3 deletions python/tvm/schedule.py
@@ -181,7 +181,7 @@ def rfactor(self, tensor, axis):
""" Factor a reduction axis in tensor's schedule to be an explicit axis.

This will create a new stage that generated the new tensor with axis
as the first dimension. The tensor's body wil be rewriten as a reduction
as the first dimension. The tensor's body will be rewritten as a reduction
over the factored tensor.

Parameters
@@ -193,10 +193,11 @@

Returns
-------
tfactor : Tensor
tfactor : Tensor or Array of Tensor
The created factored tensor.
"""
return _api_internal._ScheduleRFactor(self, tensor, axis)
factored = _api_internal._ScheduleRFactor(self, tensor, axis)
return factored[0] if len(factored) == 1 else factored
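
A quick sketch of the updated return convention (a single-output reduction yields one Tensor; a tuple reduction would yield the whole array):

import tvm

n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
k = tvm.reduce_axis((0, n), name='k')
B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

s = tvm.create_schedule(B.op)
ko, ki = s[B].split(k, factor=16)
BF = s.rfactor(B, ki)  # one Tensor here; an Array of Tensor for multi-output reduces
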


@register_node
10 changes: 6 additions & 4 deletions src/api/api_ir.cc
@@ -68,11 +68,13 @@ TVM_REGISTER_API("make.Call")
});

TVM_REGISTER_API("make.CommReducer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CommReducerNode::make(args[0], args[1], args[2]);
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CommReducerNode::make(args[0],
args[1],
args[2],
args[3]);
});


// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API("make."#Node) \
@@ -112,7 +114,7 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(a, b); \
})

REGISTER_MAKE4(Reduce);
REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);

REGISTER_MAKE2(IntImm);
15 changes: 9 additions & 6 deletions src/lang/expr.cc
@@ -50,24 +50,27 @@ Expr sum(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Add::make(x, y);
Expr identity_element = make_zero(source.type());
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}

Expr max(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Max::make(x, y);
Expr identity_element = source.type().min();
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}

Expr min(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Min::make(x, y);
Expr identity_element = source.type().max();
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}

std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)