Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LANG] Support for Tuple Inputs of Reducer and ComputeOp #175

Merged
merged 16 commits into from
Jun 11, 2017
Merged
28 changes: 18 additions & 10 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,27 @@ struct CommReducer : public NodeRef {
* binary operator with identity element
*/
struct CommReducerNode : public Node {
/*! \brief The arguments of reducer */
Array<Var> args;
/*! \brief The left argument of reducer */
Array<Var> lhs;
/*! \brief The right argument of reducer */
Array<Var> rhs;
/*! \brief The result of reducer */
Expr result;
Array<Expr> result;
/*!
 * \brief The identity element of reducer, which leaves other
 * elements unchanged when combined with it, with respect to
 * the binary operation this reducer uses.
*/
Expr identity_element;
Array<Expr> identity_element;
/*! \brief Function call operator to combine a and b */
Expr operator()(Expr a, Expr b) const;
Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
/*! \brief construct CommReducer from args, result and identity_element */
static CommReducer make(Array<Var> args, Expr result, Expr identity_element);
static CommReducer make(Array<Var> lhs, Array<Var> rhs,
Array<Expr> result, Array<Expr> identity_element);

void VisitAttrs(AttrVisitor* v) final {
v->Visit("args", &args);
v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs);
v->Visit("result", &result);
v->Visit("identity_element", &identity_element);
}
Expand All @@ -84,26 +88,30 @@ struct Reduce : public ExprNode<Reduce> {
/*! \brief The commutative combiner */
CommReducer combiner;
/*! \brief The source operand */
Expr source;
Array<Expr> source;
/*! \brief The reduction axis */
Array<IterVar> axis;
/*!
* \brief Predicate on the reduction
* Only add the body to reduction if condition is true.
*/
Expr condition;
/*! \brief the index of this reduce node */
int value_index;

/*! \brief construct expr from op and rdom */
static Expr make(CommReducer combiner,
Expr src,
Array<Expr> src,
Array<IterVar> rdom,
Expr condition = const_true());
Expr condition = const_true(),
int value_index = 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the default value, to be safe

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

never mind, forget this comment


void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("source", &source);
v->Visit("axis", &axis);
v->Visit("condition", &condition);
v->Visit("value_index", &value_index);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce";
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
/*!
* \brief inline all calls of f in stmt.
*
* \param stmt The statement to apply inline optimization.
* \param f The function reference to be inlined
* \param args The arguments variable of the function.
* \param body The defintion body of the function.
* \param stmt The statement to apply inline optimization.
* \param body The definition body of the function.
* \return The result stmt
*
* \note All the passes in this file uses SSA form and outputs SSA form.
Expand Down
16 changes: 14 additions & 2 deletions include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class ComputeOpNode : public OperationNode {
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
Array<IterVar> reduce_axis;
/*! \brief the compute expression */
Expr body;
Array<Expr> body;
/*! \brief constructor */
ComputeOpNode() {}
// override functions
Expand Down Expand Up @@ -218,7 +218,7 @@ class ComputeOpNode : public OperationNode {
}
static Operation make(std::string name,
Array<IterVar> axis,
Expr body);
Array<Expr> body);

static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
Expand Down Expand Up @@ -358,6 +358,9 @@ class ExternOpNode : public OperationNode {
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;

/*! \brief The compute function to specify the inputs source of Tensors */
using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;

/*!
* \brief create a place holder tensor.
* \param shape The shape of the tensor.
Expand All @@ -377,6 +380,15 @@ Tensor placeholder(Array<Expr> shape,
*/
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");

/*!
 * \brief Construct new tensors by computing over shape,
 * using the computation rule: result_tensor[axis] = fcompute(axis)
 * \param shape Shape of the tensors.
 * \param fcompute The compute function to create the tensors.
 * \param name The optional name of the tensors.
*/
Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name = "tensor");

/*!
* \brief Construct new tensors by scan.
*
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class Schedule : public NodeRef {
/*!
* \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
* This will create a new stage that generated the new tensor with axis
* as the first dimension. The tensor's body wil be rewriten as a reduction
* as the first dimension. The tensor's body will be rewritten as a reduction
* over the factored tensor.
*
* \param tensor The tensor to be factored.
Expand Down
56 changes: 46 additions & 10 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,18 @@ def compute(shape, fcompute, name="compute"):

dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
body = fcompute(*[v.var for v in dim_var])
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, dim_var, body)
return op_node.output(0)
outputs = []
num = op_node.num_outputs
if num == 1:
return op_node.output(0)
for i in range(num):
outputs.append(op_node.output(i))
return tuple(outputs)


def scan(init, update, state_placeholder, inputs=None, name="scan"):
Expand Down Expand Up @@ -525,18 +533,46 @@ def _reduce_directly(*args):
return res

def _make_reduce(expr, axis, where=None):
expr = convert(expr)
dtype = expr.dtype
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
arg_vars = [var(name, dtype) for name in code.co_varnames]
result = fcombine(*[v for v in arg_vars])
expr = convert(expr)
if isinstance(expr, _collections.Array):
size = len(expr)
larr = []
rarr = []
dtypes = []
for i in range(size):
dtype = expr[i].dtype
dtypes.append(dtype)
lname = code.co_varnames[0] + '_' + str(i)
larr.append(var(lname, dtype))
rname = code.co_varnames[1] + '_' + str(i)
rarr.append(var(rname, dtype))
lhs = convert(larr)
rhs = convert(rarr)
result = fcombine(lhs, rhs)
id_elem = fidentity(*dtypes)
else:
assert isinstance(expr, _expr.Expr)
size = 1
dtype = expr.dtype
lvar = var(code.co_varnames[0], dtype)
rvar = var(code.co_varnames[1], dtype)
result = [fcombine(lvar, rvar)]
id_elem = [fidentity(dtype)]
lhs = convert([lvar])
rhs = convert([rvar])
expr = convert([expr])
result = convert(result)
id_elem = fidentity(dtype)
assert isinstance(id_elem, _expr.Expr)
combiner = _make.CommReducer(arg_vars, result, id_elem)
axis = axis if isinstance(axis, list) else [axis]
return _make.Reduce(combiner, expr, axis, where)
id_elem = convert(id_elem)
combiner = _make.CommReducer(lhs, rhs, result, id_elem)
axis = convert(axis if isinstance(axis, list) else [axis])
if where is None:
where = convert(True)
if size == 1:
return _make.Reduce(combiner, expr, axis, where, 0)
return [_make.Reduce(combiner, expr, axis, where, i)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to tuple

for i in range(size)]

def reducer(expr, axis, where=None, *args):
if isinstance(axis, (_schedule.IterVar, list)):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def rfactor(self, tensor, axis):
""" Factor a reduction axis in tensor's schedule to be an explicit axis.

This will create a new stage that generated the new tensor with axis
as the first dimension. The tensor's body wil be rewriten as a reduction
as the first dimension. The tensor's body will be rewritten as a reduction
over the factored tensor.

Parameters
Expand Down
10 changes: 6 additions & 4 deletions src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,13 @@ TVM_REGISTER_API("make.Call")
});

TVM_REGISTER_API("make.CommReducer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CommReducerNode::make(args[0], args[1], args[2]);
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CommReducerNode::make(args[0],
args[1],
args[2],
args[3]);
});


// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API("make."#Node) \
Expand Down Expand Up @@ -112,7 +114,7 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(a, b); \
})

REGISTER_MAKE4(Reduce);
REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);

REGISTER_MAKE2(IntImm);
Expand Down
15 changes: 9 additions & 6 deletions src/lang/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,27 @@ Expr sum(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Add::make(x, y);
Expr identity_element = make_zero(source.type());
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true));
}

Expr max(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Max::make(x, y);
Expr identity_element = source.type().min();
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true));
}

Expr min(Expr source, Array<IterVar> rdom) {
Var x("x"), y("y");
Expr result = ir::Min::make(x, y);
Expr identity_element = source.type().max();
ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true));
}

std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
Expand Down
50 changes: 30 additions & 20 deletions src/lang/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <ir/IR.h>
#include <ir/IRPrinter.h>
#include <memory>
#include "../pass/ir_util.h"

namespace Halide {
namespace Internal {
Expand All @@ -25,23 +26,20 @@ void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce(combiner="
<< op->combiner
<< ", ";
p->print(op->source);
<< op->combiner;
p->stream << ", source=" << op->source;
p->stream << ", axis=" << op->axis;
if (!is_const(op->condition, 1)) {
p->stream << ", where=" << op->condition;
}
p->stream << ", where=" << op->condition;
p->stream << ", value_index=" << op->value_index;
p->stream << ")";
});

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) {
p->stream << "comm_reducer(result="
<< op->result
<< ", args=" << op->args
<< ", identity_element="
<< op->identity_element
p->stream << "comm_reducer(result=" << op->result
<< ", lhs=" << op->lhs
<< ", rhs=" << op->rhs
<< ", identity_element=" << op->identity_element
<< ")";
});
} // namespace Internal
Expand All @@ -50,23 +48,34 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
namespace tvm {
namespace ir {

CommReducer CommReducerNode::make(Array<Var> args, Expr result, Expr identity_element) {
CommReducer CommReducerNode::make(Array<Var> lhs,
Array<Var> rhs,
Array<Expr> result,
Array<Expr> identity_element) {
auto node = std::make_shared<CommReducerNode>();
node->args = args;
node->lhs = lhs;
node->rhs = rhs;
node->result = result;
node->identity_element = identity_element;
return CommReducer(node);
}

Expr CommReducerNode::operator()(Expr a, Expr b) const {
Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const {
CHECK_EQ(a.size(), b.size());
CHECK_EQ(lhs.size(), a.size());
CHECK_EQ(rhs.size(), b.size());
Map<Var, Expr> value_map;
value_map.Set(args[0], a);
value_map.Set(args[1], b);
return Substitute(result, value_map);
for (size_t i = 0; i < a.size(); ++i) {
value_map.Set(lhs[i], a[i]);
value_map.Set(rhs[i], b[i]);
}
return UpdateArray(result, [&value_map] (const Expr& e) {
return Substitute(e, value_map);
});
}

Expr Reduce::make(CommReducer combiner, Expr source,
Array<IterVar> axis, Expr condition) {
Expr Reduce::make(CommReducer combiner, Array<Expr> source,
Array<IterVar> axis, Expr condition, int value_index) {
for (size_t i = 0; i < axis.size(); ++i) {
CHECK_EQ(axis[i]->iter_type, kCommReduce)
<< "Can only take axis created by reduce_axis";
Expand All @@ -79,11 +88,12 @@ Expr Reduce::make(CommReducer combiner, Expr source,
for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined());
}
n->type = source.type();
n->type = source[value_index].type();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if argument is passing by value, do std::move to save copy constructor

n->combiner = combiner;
n->source = source;
n->axis = axis;
n->condition = condition;
n->value_index = value_index;
return Expr(n);
}

Expand Down
Loading