[SCHEDULE] Allow mutate dataflow during schedule phase #44

Merged · 1 commit · Feb 17, 2017
22 changes: 11 additions & 11 deletions include/tvm/operation.h
@@ -136,7 +136,7 @@ using FCompute = std::function<Expr (const Array<Var>& i)>;
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
*/
Tensor Placeholder(Array<Expr> shape,
Tensor placeholder(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "placeholder");

@@ -147,7 +147,7 @@ Tensor Placeholder(Array<Expr> shape,
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor.
*/
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");

/*!
* \brief Construct new tensors by scan over scan_axis.
@@ -158,36 +158,36 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
* \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor.
*/
Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name = "scan");

// same as compute, specialized for different fcompute function
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}
inline Tensor Compute(Array<Expr> shape,
inline Tensor compute(Array<Expr> shape,
std::function<Expr(Var, Var, Var, Var)> f,
std::string name = "tensor") {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
return Compute(shape, fc, name);
return compute(shape, fc, name);
}

} // namespace tvm
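The rename to lowercase `placeholder`/`compute`/`scan` keeps the C++ surface consistent with the Python bindings. A minimal usage sketch (assuming the Python bindings of this revision mirror the declarations above):

```python
import tvm

# Declare a symbolic input, then a tensor computed from it.
A = tvm.placeholder((1024,), name="A")          # maps to tvm::placeholder
B = tvm.compute((1024,), lambda i: A[i] + 1.0,  # maps to the 1-D tvm::compute overload
                name="B")
```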
70 changes: 62 additions & 8 deletions include/tvm/schedule.h
@@ -131,6 +131,13 @@ class Stage : public NodeRef {
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
/*!
* \brief Specify the thread launching group at the
* outermost scope of the stage.
* This is only valid for composite operators.
* \param threads The threads to be launched.
*/
Stage& outermost_threads(Array<IterVar> threads);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
@@ -179,6 +186,28 @@ class Schedule : public NodeRef {
Stage operator[](const Tensor& tensor) {
return this->operator[](tensor->op);
}
/*!
* \brief Create a cache read of the original tensor for readers.
* This will mutate the body of the readers.
* A new stage will be created for the tensor.
* \param tensor The tensor to be cached.
* \param scope The scope of the cache.
* \param readers The readers to redirect to the tensor.
* \return The created tensor.
*/
Tensor cache_read(const Tensor& tensor,
const std::string& scope,
const Array<Operation>& readers);
/*!
* \brief Create a cache write tensor for the producing tensor.
* The cache tensor will take over the body of the original tensor op.
* The original tensor's body will be changed to an identity read
* from the corresponding cache.
* \param tensor The tensor to be produced.
* \param scope The scope of the storage.
* \return The created tensor.
*/
Tensor cache_write(const Tensor& tensor, const std::string& scope);
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
@@ -193,6 +222,11 @@
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline ScheduleNode* operator->();
// declare container type
using ContainerType = ScheduleNode;
};
@@ -244,17 +278,28 @@ class IterVarAttr : public NodeRef {
*/
class StageNode : public Node {
public:
/*! \brief The operation to be scheduled */
Operation op;
/*! \brief The thread scope level of the stage */
std::string scope;
/*! \brief The operation of stage, can be different from original op. */
Operation op;
/*!
* \brief The original operator.
* The op field can change during scheduling to alter the dataflow,
* while origin_op remains fixed.
*/
Operation origin_op;
/*! \brief All the nodes in the iter var */
Array<IterVar> all_iter_vars;
/*!
* \brief The current leaves in the schedule.
* Operations can only be performed in leaves.
*/
Array<IterVar> leaf_iter_vars;
/*!
* \brief Specify threads to be launched at the stage.
* This is only valid for composite ops such as Scan.
*/
Array<IterVar> outermost_threads;
/*! \brief The relations between IterVars */
Array<IterVarRelation> relations;
/*! \brief additional attributes about iter var. */
@@ -265,17 +310,22 @@ class StageNode : public Node {
IterVar attach_ivar;
/*! \brief The stage this node attaches to */
Stage attach_stage;
/*! \brief Whether this is an output stage */
bool is_output{false};

void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
v->Visit("op", &op);
v->Visit("origin_op", &origin_op);
v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("outermost_threads", &outermost_threads);
v->Visit("relations", &relations);
v->Visit("iter_var_attrs", &iter_var_attrs);
v->Visit("attach_type", &attach_type);
v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage);
v->Visit("is_output", &is_output);
}

static constexpr const char* _type_key = "Stage";
@@ -285,18 +335,18 @@
/*! \brief node container for schedule */
class ScheduleNode : public Node {
public:
/*! \brief The root operations */
Array<Operation> roots;
/*! \brief The output operations in original data flow graph */
Array<Operation> outputs;
/*!
* \brief list of all stages for non-placeholder ops
* The stage are ordered in PostDFS order of their op.
* \brief list of all stages for non-placeholder ops.
* The stages are sorted in dependency order.
*/
Array<Stage> stages;
/*! \brief map of operation to the stages */
Map<Operation, Stage> stage_map;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("roots", &roots);
v->Visit("outputs", &outputs);
v->Visit("stages", &stages);
v->Visit("stage_map", &stage_map);
}
@@ -412,12 +462,16 @@ inline StageNode* Stage::operator->() {

inline bool Stage::is_scheduled() const {
const StageNode* n = operator->();
return !(n->relations.empty() && n->attach_type == kNone);
return !(n->relations.empty() && n->attach_type == kNone &&
n->all_iter_vars.same_as(n->leaf_iter_vars));
}

inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
inline ScheduleNode* Schedule::operator->() {
return static_cast<ScheduleNode*>(node_.get());
}

inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
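`outermost_threads` only applies to composite operators such as scan, whose whole body must run inside a single thread launch. A hedged sketch follows; the `IterVar` thread-tag construction and the Python `scan` signature are assumptions about this revision's bindings, mirroring the C++ declarations above:

```python
import tvm

# Cumulative sum along the first axis, expressed as a scan.
m, n = 128, 128
X = tvm.placeholder((m, n), name="X")
state = tvm.placeholder((m, n), name="state")
init = tvm.compute((1, n), lambda _, i: X[0, i])
update = tvm.compute((m, n), lambda t, i: state[t - 1, i] + X[t, i])
t_axis = tvm.IterVar((1, m), name="t")
# Assumes a single-output scan returns one tensor.
res = tvm.scan(t_axis, init, update, state)

s = tvm.Schedule(res.op)
# Launch the entire scan stage under one block of threads.
block_x = tvm.IterVar(thread_tag="blockIdx.x")
s[res.op].outermost_threads([block_x])
```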
1 change: 0 additions & 1 deletion python/tvm/build.py
@@ -63,7 +63,6 @@ def build(sch,
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")

# lowering
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
60 changes: 60 additions & 0 deletions python/tvm/schedule.py
@@ -4,6 +4,7 @@
from ._ctypes._node import NodeBase, register_node
from . import _api_internal
from . import tensor as _tensor
from . import collections as _collections

@register_node
class Buffer(NodeBase):
@@ -41,6 +42,53 @@ def normalize(self):
"""
_api_internal._ScheduleNormalize(self)

def cache_read(self, tensor, scope, readers):
"""Create a cache read of original tensor for readers.

This will mutate the body of the readers.
A new cache stage will be created for the tensor.
Call this before applying any split/fuse scheduling.

Parameters
----------
tensor : Tensor
The tensor to be cached.
scope : str
The scope of the cache storage.
readers : list of Tensor or Operation
The readers to read the cache.

Returns
-------
cache : Tensor
The created cache tensor.
"""
if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
readers = [readers]
readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
return _api_internal._ScheduleCacheRead(self, tensor, scope, readers)

def cache_write(self, tensor, scope):
"""Create a cache write of original tensor, before storing into tensor.

This will mutate the body of the tensor.
A new cache stage will be created before feeding into the tensor.

Parameters
----------
tensor : Tensor
The tensor to be fed to.
scope : str
The scope of the cache storage.

Returns
-------
cache : Tensor
The created cache tensor.
"""
return _api_internal._ScheduleCacheWrite(self, tensor, scope)


@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
Expand Down Expand Up @@ -104,6 +152,18 @@ def set_scope(self, scope):
"""
return _api_internal._StageSetScope(self, scope)

def outermost_threads(self, threads):
"""Force launch threads at outermost scope of the stage.

Parameters
----------
threads : list of IterVar
The threads to be launched.
"""
if isinstance(threads, _collections.IterVar):
threads = [threads]
_api_internal._StageOutermostThreads(self, threads)

def compute_at(self, parent, scope):
"""Attach the stage at parent's scope

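Taken together, the two primitives let a schedule rewrite its own dataflow: `cache_read` redirects readers through a new stage, and `cache_write` takes over a producer's body. A usage sketch (the "shared"/"local" scopes are illustrative GPU storage scopes, not requirements):

```python
import tvm

A = tvm.placeholder((1024,), name="A")
B = tvm.compute((1024,), lambda i: A[i] + 1.0, name="B")
s = tvm.Schedule(B.op)

# Stage a copy of A for its reader B; B's body now reads AA instead of A.
AA = s.cache_read(A, "shared", [B])
# Compute B into a local cache; the original B becomes an identity
# copy out of BL.
BL = s.cache_write(B, "local")
```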
20 changes: 19 additions & 1 deletion src/api/api_lang.cc
@@ -161,7 +161,7 @@ TVM_REGISTER_API(_TensorHash)

TVM_REGISTER_API(_Placeholder)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Placeholder(args[0],
*ret = placeholder(args[0],
args[1],
args[2]);
});
@@ -262,6 +262,12 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});

TVM_REGISTER_API(_StageOutermostThreads)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.outermost_threads(args[1]);
});

TVM_REGISTER_API(_StageUnroll)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
@@ -280,4 +286,16 @@ TVM_REGISTER_API(_ScheduleNormalize)
.normalize();
});

TVM_REGISTER_API(_ScheduleCacheRead)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.cache_read(args[1], args[2], args[3]);
});

TVM_REGISTER_API(_ScheduleCacheWrite)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
.cache_write(args[1], args[2]);
});

} // namespace tvm
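These registered globals are the FFI glue that the Python methods above call through; invoking one directly is equivalent. A sketch of that path (not a recommended public entry point):

```python
import tvm
from tvm import _api_internal

# Equivalent to s.cache_write(B, "local") in python/tvm/schedule.py:
# the registered lambda converts args to (Schedule, Tensor, str).
BL = _api_internal._ScheduleCacheWrite(s, B, "local")
```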
6 changes: 3 additions & 3 deletions src/lang/operation.cc
@@ -53,7 +53,7 @@ Operation PlaceholderOpNode::make(std::string name,



Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}

@@ -82,7 +82,7 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
return Array<Expr>(shape);
}

Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
@@ -188,7 +188,7 @@ Operation ScanOpNode::make(std::string name,
return Operation(n);
}

Array<Tensor> Scan(IterVar scan_axis,
Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,