From 685aca44964549d9a750e9c910111742da9b9e42 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 15 Feb 2017 22:46:22 -0800 Subject: [PATCH] [SCAN/Refactor] Refactor scan interface, enable fix point analysis. --- include/tvm/operation.h | 4 +- include/tvm/schedule.h | 4 +- include/tvm/tensor.h | 2 + python/tvm/addon/nvcc_compiler.py | 9 +- python/tvm/api.py | 11 +- python/tvm/build.py | 3 +- src/api/api_schedule.cc | 3 + src/arithmetic/int_set.cc | 10 +- src/lang/operation.cc | 28 +- src/pass/inline.cc | 4 +- src/schedule/bound.cc | 117 +++---- src/schedule/graph.cc | 307 ++++++++++++++++- src/schedule/graph.h | 54 +++ src/schedule/schedule_dataflow_rewrite.cc | 312 ++++++++++++++++++ src/schedule/schedule_lang.cc | 265 +++------------ src/schedule/schedule_ops.cc | 51 ++- tests/python/integration/test_scan.py | 6 +- tests/python/unittest/test_lang_tensor.py | 5 +- .../unittest/test_schedule_bound_inference.py | 31 +- tests/python/unittest/test_schedule_graph.py | 101 ++++++ .../unittest/test_schedule_schedule_ops.py | 28 +- 21 files changed, 977 insertions(+), 378 deletions(-) create mode 100644 src/schedule/schedule_dataflow_rewrite.cc create mode 100644 tests/python/unittest/test_schedule_graph.py diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 745277308c70..85b289f5d220 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -152,14 +152,12 @@ Tensor compute(Array shape, FCompute fcompute, std::string name = "tensor" /*! * \brief Construct new tensors by scan over scan_axis. * - * \param scan_axis The iteration representing the scan. * \param init The intialize tensor of first K steps. * \param update The update tensor indicated the updated result after each timestamp. * \param state_placeholder The placeholder for the states. * \param name The optional name of the tensor. */ -Array scan(IterVar scan_axis, - Array init, +Array scan(Array init, Array update, Array state_placeholder, std::string name = "scan"); diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 18407567744a..c6bbc65660c4 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -26,7 +26,9 @@ enum AttachType : int { kNone = 0, kRoot = 1, kInline = 2, - kScope = 3 + kInlinedAlready = 3, + kScope = 4, + kScanUpdate = 5 }; /*! \brief IterVar type */ diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index 92786b33106d..11766cd005d5 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -175,6 +175,8 @@ class OperationNode : public FunctionBaseNode { virtual Type output_dtype(size_t i) const = 0; /*! \return shape of i-th output */ virtual Array output_shape(size_t i) const = 0; + + static constexpr const char* _type_key = "Operation"; }; // Implementations of inline functions diff --git a/python/tvm/addon/nvcc_compiler.py b/python/tvm/addon/nvcc_compiler.py index a1c2b938d58c..7895a2b98c56 100644 --- a/python/tvm/addon/nvcc_compiler.py +++ b/python/tvm/addon/nvcc_compiler.py @@ -4,7 +4,7 @@ import tempfile import subprocess -def compile_source(code, target="cubin"): +def compile_source(code, target="cubin", options=None): """Compile cuda code with NVCC from env. Parameters @@ -12,9 +12,12 @@ def compile_source(code, target="cubin"): code : str The cuda code. - target: str + target : str The target format + options : str + The additional options + Return ------ cubin : bytearray @@ -32,6 +35,8 @@ def compile_source(code, target="cubin"): cmd = ["nvcc"] cmd += ["--%s" % target, "-O3"] cmd += ["-o", path_target] + if options: + cmd += options cmd += [path_code] args = ' '.join(cmd) diff --git a/python/tvm/api.py b/python/tvm/api.py index 2c3f544836d4..d6c81bac69e3 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -140,14 +140,11 @@ def compute(shape, fcompute, name="compute"): return op_node.output(0) -def scan(axis, init, update, state_placeholder, name="scan"): +def scan(init, update, state_placeholder, name="scan"): """Construct new tensors by scanning over axis. Parameters ---------- - axis: IterVar - The scanning axis. - init: Tensor or list of Tensor The initial condition of first init.shape[0] timestamps @@ -170,12 +167,11 @@ def scan(axis, init, update, state_placeholder, name="scan"): # The following code is equivalent to numpy.cumsum m = tvm.Var("m") n = tvm.Var("n") - t = tvm.IterVar((1, m), name="t") X = tvm.placeholder((m, n), name="X") s_state = tvm.placeholder((m, n)) s_init = tvm.compute((1, n), lambda _, i: X[0, i]) - s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i]) - res = tvm.scan(t, s_init, s_update, s_state) + s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i]) + res = tvm.scan(s_init, s_update, s_state) """ if isinstance(init, _tensor.Tensor): init = [init] @@ -185,6 +181,7 @@ def scan(axis, init, update, state_placeholder, name="scan"): state_placeholder = [state_placeholder] if len(init) != len(update) or len(init) != len(state_placeholder): raise ValueError("init, update, state_placeholder must have same length") + axis = IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name) op = _api_internal._ScanOp(name, axis, init, update, state_placeholder) res = [op.output(i) for i in range(len(update))] return (res[0] if len(res) == 1 else res) diff --git a/python/tvm/build.py b/python/tvm/build.py index 40cb92b458aa..764db0ae5304 100644 --- a/python/tvm/build.py +++ b/python/tvm/build.py @@ -63,7 +63,8 @@ def build(sch, arg_list.append(x) else: raise ValueError("args must be Tensor, Buffer or Var") - # lowering + # normalize schedule first + sch.normalize() bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds) stmt = ir_pass.StorageFlatten(stmt, binds) diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index 882ff94bde21..d953e37e2353 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -34,6 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise) REGISTER_SCHEDULE_PASS1(InferBound); REGISTER_SCHEDULE_PASS1(CreateReadGraph); REGISTER_SCHEDULE_PASS2(PostDFSOrder); +REGISTER_SCHEDULE_PASS1(ScanGetBody); +REGISTER_SCHEDULE_PASS1(CreateAttachPath); +REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis); REGISTER_SCHEDULE_PASS2(ScheduleOps); } // namespace schedule diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 8fdba6650f25..8c89d93e6bc8 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -166,7 +166,15 @@ IntSet Union(const Array& set) { if (set.size() == 1) return set[0]; Interval x = set[0].cover_interval().as()->i; for (size_t i = 1; i < set.size(); ++i) { - x.include(set[i].cover_interval().as()->i); + IntSet s = set[i].cover_interval(); + const Interval& y = s.as()->i; + if (can_prove(x.max + 1 >= y.min)) { + x.max = y.max; + } else if (can_prove(y.max + 1 >= x.min)) { + x.min = y.min; + } else { + x.include(y); + } } return IntervalSet::make(x); } diff --git a/src/lang/operation.cc b/src/lang/operation.cc index ddc4770f0bb9..ac1e9541744d 100644 --- a/src/lang/operation.cc +++ b/src/lang/operation.cc @@ -51,8 +51,6 @@ Operation PlaceholderOpNode::make(std::string name, return Operation(n); } - - Tensor placeholder(Array shape, Type dtype, std::string name) { return PlaceholderOpNode::make(name, shape, dtype).output(0); } @@ -162,24 +160,25 @@ Operation ScanOpNode::make(std::string name, << " scan_axis.dom.min + scan_axis.dom.extent"; CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim()) << "The dimension of init need to match state_placeholder"; - CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim()) + CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim()) << "The update.ndim need to be state_placeholder.ndim - 1"; for (size_t k = 0; k < update[i].ndim(); ++k) { CHECK(prove_equal( - update[i]->shape[k], state_placeholder[i]->shape[k + 1])); - // setup spatial axis - std::ostringstream spatial_name; - spatial_name << name << ".out" << i << ".i" << k + 1; - n->spatial_axis_.push_back( - IterVar(Range::make_with_min_extent(0, update[i]->shape[k]), - spatial_name.str())); + update[i]->shape[k], state_placeholder[i]->shape[k])); + if (k != 0) { + // setup spatial axis + std::ostringstream spatial_name; + spatial_name << name << ".out" << i << ".i" << k; + n->spatial_axis_.push_back( + IterVar(Range::make_with_min_extent(0, update[i]->shape[k]), + spatial_name.str())); + } } for (size_t k = 1; k < init[i].ndim(); ++k) { CHECK(prove_equal( init[i]->shape[k], state_placeholder[i]->shape[k])); } } - n->name = name; n->scan_axis = axis; n->init = init; @@ -188,11 +187,14 @@ Operation ScanOpNode::make(std::string name, return Operation(n); } -Array scan(IterVar scan_axis, - Array init, +Array scan(Array init, Array update, Array state_placeholder, std::string name) { + IterVar scan_axis( + Range::make_with_min_extent( + init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), + name + ".idx"); Operation op = ScanOpNode::make( name, scan_axis, init, update, state_placeholder); Array res; diff --git a/src/pass/inline.cc b/src/pass/inline.cc index 1dee4776e6ab..87f54ce0b497 100644 --- a/src/pass/inline.cc +++ b/src/pass/inline.cc @@ -61,7 +61,9 @@ Stmt Inline(Stmt stmt, Expr body) { CHECK_EQ(f->num_outputs(), 1) << "can only inline output single value operation"; - return ConvertSSA(IRInline(f, args, body).Mutate(stmt)); + Stmt ret = IRInline(f, args, body).Mutate(stmt); + if (ret.same_as(stmt)) return ret; + return ConvertSSA(ret); } } // namespace ir } // namespace tvm diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index 4724d97627a7..c2fa061bde7c 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -1,3 +1,4 @@ + /*! * Copyright (c) 2016 by Contributors * \file bound.cc @@ -259,11 +260,14 @@ void BoundProp(const Operation& op, init_dom->data[0].push_back(IntSet::range( Range::make_with_min_extent(0, scan->init[i]->shape[0]))); } + if (update_dom) { + update_dom->data[0].push_back(dom_map.at(scan->scan_axis->var.get())); + } // The update dimensions - for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) { + for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) { IterVar sp_ax = scan->spatial_axis_[sp_idx]; if (init_dom) { - init_dom->data[k + 1].push_back(dom_map.at(sp_ax->var.get())); + init_dom->data[k].push_back(dom_map.at(sp_ax->var.get())); } if (update_dom) { update_dom->data[k].push_back(dom_map.at(sp_ax->var.get())); @@ -277,10 +281,12 @@ void BoundProp(const Operation& op, } } + // Given the bound of output of op // Pass the bound to the related axis in op. void GatherOpBound(const ScanOpNode* scan, const Operation& op, + const FeedGraph& fg, const std::unordered_map& tmap, std::unordered_map* rmap) { CHECK(!rmap->count(scan->scan_axis)); @@ -299,21 +305,29 @@ void GatherOpBound(const ScanOpNode* scan, Range r = arith::Union(time_dom).cover_range(sdom); (*rmap)[scan->scan_axis] = Range::make_with_min_extent( sdom->min, ir::Simplify(r->extent + r->min - sdom->min)); + Array body = ScanGetBody_(scan, fg); + Map fix_pt = ScanFixPointAnalysis(op, body); // Update for spatial axis. size_t sp_idx = 0; for (size_t i = 0; i < output.size(); ++i) { - for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) { + const TensorDom& d = tmap.at(output[i]); + for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) { IterVar sp_ax = scan->spatial_axis_[sp_idx]; CHECK(!rmap->count(sp_ax)); - // In default, we always need all spatial axis - // Unless that axis only refers back to itself as a fixed point. - // TODO(tqchen): Add fix point detection. - (*rmap)[sp_ax] = sp_ax->dom; + CHECK(fix_pt.count(sp_ax)); + if (fix_pt[sp_ax].as()->value) { + // fix point, we can slice it. + (*rmap)[sp_ax] = arith::Union(d.data[k + 1]).cover_range(sp_ax->dom); + } else { + // not a fix point, need to include everything. + (*rmap)[sp_ax] = sp_ax->dom; + } } } } void GatherOpBound(const Operation& op, + const FeedGraph& fg, const std::unordered_map& tmap, std::unordered_map* rmap) { if (op.as()) { @@ -329,7 +343,7 @@ void GatherOpBound(const Operation& op, (*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom; } } else if (op.as()) { - GatherOpBound(op.as(), op, tmap, rmap); + GatherOpBound(op.as(), op, fg, tmap, rmap); } else if (op.as()) { // dp nothing } else { @@ -347,20 +361,14 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) { return StorageScope::make(scope).rank <= ThreadScope::make(iv->thread_tag).rank; } -// The map beteen tensor and operation it feeds ti -using FeedGraph = std::unordered_map >; - -// AttachPath maps op-> a list of IterVar -// That represents the loop nest op sits in from inner most to outermost -using AttachPath = Map >; - - void InferRootBound(const Stage& stage, const FeedGraph& feed_graph, const AttachPath& attach_path, std::unordered_map* rmap) { - if (stage->attach_type == kInline) return; - if (stage->attach_type == kRoot || stage->attach_type == kNone) { + CHECK_NE(stage->attach_type, kInline) + << "call schedule.normalize before scheduleops"; + if (stage->attach_type == kInlinedAlready) return; + if (stage->is_output || stage->op.as()) { for (auto iv : OutputRelatedIterVars(stage->op)) { CHECK(iv->dom.defined()); CHECK(!rmap->count(iv)); @@ -368,11 +376,11 @@ void InferRootBound(const Stage& stage, } return; } - // Infer root bounds for the attached node. - CHECK_EQ(stage->attach_type, kScope); - Stage parent = stage->attach_stage; - CHECK(parent.defined()); - + // parent stage, if any + Stage parent; + if (stage->attach_type == kScope || stage->attach_type == kScanUpdate) { + parent = stage->attach_stage; + } // The tensor domain. std::unordered_map tmap; // consumers other than parent @@ -385,7 +393,7 @@ void InferRootBound(const Stage& stage, auto it = feed_graph.find(t); if (it != feed_graph.end()) { for (const Operation& op : it->second) { - if (op != parent->op) { + if (!parent.defined() || op != parent->op) { consumers.insert(op); } else { direct_consume_by_parent = true; @@ -404,16 +412,20 @@ void InferRootBound(const Stage& stage, relax_set[iv->var.get()] = IntSet::range(rmap->at(iv)); } } - if (direct_consume_by_parent) { + // parent stage if exist + Stage parent = stage->attach_stage; // Bound inference logics in parent. std::unordered_map up_state; bool fix_value = true; for (auto iv : parent->leaf_iter_vars) { - Range vrange = rmap->at(iv); + auto it = rmap->find(iv); + CHECK(it != rmap->end()); + Range vrange = it->second; CHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, " - << "call schedule.normalize to achieve this."; + << " call schedule.normalize to achieve this. " + << " stage=" << parent; // special optimization to remove trivial loop if (is_one(vrange->extent)) { up_state[iv] = IntSet::single_point(vrange->min); @@ -464,8 +476,9 @@ void InferRootBound(const Stage& stage, for (const Operation& op : consumers) { std::unordered_map dom_map; bool found = false; + Array attach = attach_path.at(stage->op); for (IterVar iv : attach_path.at(op)) { - if (iv == stage->attach_ivar) { + if (attach.size() != 0 && iv == attach[0]) { found = true; break; } Range vrange = rmap->at(iv); @@ -474,7 +487,7 @@ void InferRootBound(const Stage& stage, << "call schedule.normalize to achieve this."; relax_set[iv->var.get()] = IntSet::range(vrange); } - CHECK(found) + CHECK(found || attach.size() == 0) << "Invalid Schedule, cannot find the producer " << stage->op << " along the loop nest specified by compute_at of consumer " << op; for (auto iv : OutputRelatedIterVars(op)) { @@ -483,50 +496,15 @@ void InferRootBound(const Stage& stage, } BoundProp(op, dom_map, &tmap); } - GatherOpBound(stage->op, tmap, rmap); + GatherOpBound(stage->op, feed_graph, tmap, rmap); } -FeedGraph CreateFeedGraph(const Schedule& sch) { +Map InferBound(const Schedule& sch) { Array roots; for (Operation op : sch->outputs) { roots.push_back(sch->stage_map[op]->op); } - auto g = CreateReadGraph(roots); - FeedGraph fg; - for (auto kv : g) { - for (Tensor t : kv.second) { - fg[t].push_back(kv.first); - } - } - return fg; -} - -// Create AttachPath that maps op-> a list of IterVar -// That represents the loop nest op sits in from inner most to outermost -AttachPath CreateAttachPath(const Schedule& sch) { - AttachPath ret; - for (Stage stage : sch->stages) { - Array path; - for (Stage s = stage; s->attach_type == kScope;) { - IterVar attach_ivar = s->attach_ivar; - s = s->attach_stage; - bool start_attach = false; - for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) { - IterVar iv = s->leaf_iter_vars[i - 1]; - if (iv == attach_ivar) start_attach = true; - if (start_attach) path.push_back(iv); - } - CHECK(start_attach) - << "Invalid Schedule: cannot find attach point " << attach_ivar - << " in the schedule of " << s->op; - } - ret.Set(stage->op, path); - } - return ret; -} - -Map InferBound(const Schedule& sch) { - FeedGraph feed_graph = CreateFeedGraph(sch); + FeedGraph feed_graph = CreateFeedGraph(CreateReadGraph(roots)); AttachPath attach_path = CreateAttachPath(sch); std::unordered_map ret; @@ -535,6 +513,11 @@ Map InferBound(const Schedule& sch) { InferRootBound(stage, feed_graph, attach_path, &ret); // pass down to get bound of all iter vars. PassDown(stage, &ret); + // setup outer most threads. + for (IterVar iv : stage->outermost_threads) { + CHECK(iv->dom.defined()); + ret[iv] = iv->dom; + } } return Map(ret.begin(), ret.end()); } diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc index f1047bf95ac9..5cd6b95193a4 100644 --- a/src/schedule/graph.cc +++ b/src/schedule/graph.cc @@ -8,6 +8,46 @@ #include #include "./graph.h" +namespace tvm { +namespace schedule { +// key to specific tensor dimension. +struct TensorDimKey { + FunctionRef f; + int value_index; + int dim; + TensorDimKey() {} + TensorDimKey(const ir::Call* op, int dim) + : f(op->func), value_index(op->value_index), dim(dim) { + } + TensorDimKey(const Tensor& t, int dim) + : f(t->op), value_index(t->value_index), dim(dim) { + } + inline bool operator==(const TensorDimKey& other) const { + return f == other.f && + value_index == other.value_index && + dim == other.dim; + } + inline bool operator!=(const TensorDimKey& other) const { + return !operator==(other); + } +}; +} // namespace schedule +} // namespace tvm + +namespace std { +template <> +struct hash<::tvm::schedule::TensorDimKey> { + std::size_t operator()(const ::tvm::schedule::TensorDimKey& k) const { + size_t lhs = k.f.hash(); + size_t rhs = static_cast(k.value_index) << 32UL | + static_cast(k.dim); + lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + return lhs; + } +}; +} // namespace std + + namespace tvm { namespace schedule { @@ -28,7 +68,7 @@ ReadGraph CreateReadGraph(const Array& roots) { stack.pop_back(); Array deps; if (op.as()) { - auto fvisit = [&deps, &visited, &stack](const NodeRef& n) { + auto fvisit = [&deps](const NodeRef& n) { auto *call = n.as(); if (call != nullptr && call->func.defined()) { Operation call_op(call->func.node_); @@ -59,7 +99,6 @@ ReadGraph CreateReadGraph(const Array& roots) { return rmap; } - void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_set* visited, @@ -83,5 +122,269 @@ Array PostDFSOrder( return post_order; } +FeedGraph CreateFeedGraph(const ReadGraph& g) { + FeedGraph fg; + for (auto kv : g) { + for (Tensor t : kv.second) { + fg[t].push_back(kv.first); + } + } + return fg; +} + +AttachPath CreateAttachPath(Schedule sch) { + AttachPath ret; + + for (Stage stage : sch->stages) { + if (stage->attach_type == kScanUpdate) { + const Stage& parent = stage->attach_stage; + stage->attach_ivar = + parent->leaf_iter_vars[parent->leaf_iter_vars.size() - 1]; + } + } + + for (Stage stage : sch->stages) { + Array path; + + for (Stage s = stage; s->attach_type == kScope || s->attach_type == kScanUpdate;) { + IterVar attach_ivar = s->attach_ivar; + s = s->attach_stage; + bool start_attach = false; + for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) { + IterVar iv = s->leaf_iter_vars[i - 1]; + if (iv == attach_ivar) start_attach = true; + if (start_attach) path.push_back(iv); + } + CHECK(start_attach) + << "Invalid Schedule: cannot find attach point " << attach_ivar + << " in the schedule of " << s->op; + } + + if (!ret.count(stage->op)) { + ret.Set(stage->op, path); + } + } + return ret; +} + +// graph of push reach relation of tensor dimensions +using ReachGraph = std::unordered_map >; + +ReachGraph GetReachGraph(const Array& ops) { + ReachGraph reach; + std::unordered_set bset; + for (size_t i = 0; i < ops.size(); ++i) { + bset.insert(ops[i].get()); + } + + for (Operation op : ops) { + if (op.as()) { + const auto& update = op.as()->update; + const auto& init = op.as()->init; + for (size_t i = 0; i < update.size(); ++i) { + Tensor t = op.output(i); + for (size_t k = 1; k < update[i]->shape.size(); ++k) { + reach[TensorDimKey(t, k)].emplace_back( + TensorDimKey(update[i], k)); + reach[TensorDimKey(t, k)].emplace_back( + TensorDimKey(init[i], k)); + } + } + } else if (op.as()) { + std::unordered_map vmap; + const auto& axis = op.as()->axis; + Tensor t = op.output(0); + for (size_t i = 0; i < axis.size(); ++i) { + vmap[axis[i]->var.get()] = TensorDimKey(t, i); + reach[TensorDimKey(t, i)] = {}; + } + auto fvisit = [&vmap, &reach, &bset](const NodeRef& n) { + const ir::Call *call = n.as(); + if (call != nullptr && call->func.defined()) { + if (!bset.count(call->func.get())) return; + for (size_t i = 0; i < call->args.size(); ++i) { + TensorDimKey dkey(call, i); + auto fpush = [&dkey, &vmap, &reach](const NodeRef& node) { + const Variable *v = node.as(); + auto it = vmap.find(v); + if (it != vmap.end()) { + reach[it->second].push_back(dkey); + } + }; + ir::PostOrderVisit(call->args[i], fpush); + } + } + }; + ir::PostOrderVisit(op.as()->body, fvisit); + } + } + return reach; +} + +// Get all the operations that forms body of scan +void ScanGetBodyPostDFS_( + Operation op, + const ScanOpNode* scan, + const FeedGraph& feed_graph, + std::unordered_set* visited, + Array* result) { + if (op.get() == scan) return; + bool empty_feed = true; + for (int i = 0; i < op->num_outputs(); ++i) { + auto it = feed_graph.find(op.output(i)); + if (it != feed_graph.end() && it->second.size()) { + empty_feed = false; + for (const Operation& xop : it->second) { + if (visited->count(xop.get())) continue; + visited->insert(xop.get()); + ScanGetBodyPostDFS_(xop, scan, feed_graph, visited, result); + result->push_back(xop); + } + } + } + if (empty_feed && op.get() != scan) { + LOG(FATAL) << "Bad scan body, tensor reads scan_state but not connect to scan"; + } +} + +Array ScanGetBody_( + const ScanOpNode* scan, + const FeedGraph& feed_graph) { + CHECK(scan != nullptr); + std::unordered_set visited; + Array result; + for (Tensor t : scan->state_placeholder) { + ScanGetBodyPostDFS_(t->op, scan, feed_graph, &visited, &result); + } + return result; +} + +Array ScanGetBody(const Operation& scan) { + return ScanGetBody_(scan.as(), + CreateFeedGraph(CreateReadGraph({scan}))); +} + +Map ScanFixPointAnalysis( + const Operation& scan_op, const Array& body) { + const ScanOpNode* scan = scan_op.as(); + CHECK(body[0].get() == scan); + + std::unordered_map exact_reach; + std::unordered_set fail_set; + + for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) { + for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) { + TensorDimKey key(scan->state_placeholder[i], k); + exact_reach[key] = scan->spatial_axis_[sp_idx].get(); + } + } + // merge exact reach + auto f_merge_key = [&exact_reach, &fail_set]( + const TensorDimKey& dst, const TensorDimKey& src) { + auto sit = exact_reach.find(src); + if (sit == exact_reach.end()) return; + auto dit = exact_reach.find(dst); + if (dit == exact_reach.end()) { + exact_reach[dst] = sit->second; + } else { + if (dit->second != sit->second) { + fail_set.insert(dit->second); + fail_set.insert(sit->second); + } + } + }; + // prop exact reach back. + for (size_t i = body.size(); i != 1; --i) { + const Operation& op = body[i - 1]; + if (op.as()) { + const auto& update = op.as()->update; + const auto& init = op.as()->init; + for (size_t i = 0; i < update.size(); ++i) { + Tensor t = op.output(i); + for (size_t k = 1; i < update[i]->shape.size(); ++k) { + f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k)); + f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k)); + } + } + } else if (op.as()) { + std::unordered_map vmap; + const auto& axis = op.as()->axis; + Tensor t = op.output(0); + for (size_t i = 0; i < axis.size(); ++i) { + vmap[axis[i]->var.get()] = TensorDimKey(t, i); + } + auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set]( + const NodeRef& n) { + const ir::Call *call = n.as(); + if (call != nullptr && call->func.defined()) { + for (size_t i = 0; i < call->args.size(); ++i) { + auto it = vmap.find(call->args[i].get()); + TensorDimKey src(call, i); + if (it != vmap.end()) { + f_merge_key(it->second, src); + } else { + if (exact_reach.count(src)) { + fail_set.insert(exact_reach.at(src)); + } + } + } + } + }; + ir::PostOrderVisit(op.as()->body, fvisit); + } + } + ReachGraph reach; + Map ret; + std::unordered_set place_holder_ref; + for (size_t i = 0; i < scan->state_placeholder.size(); ++i) { + for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) { + place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k)); + } + } + + for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) { + for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) { + TensorDimKey key(scan->update[i], k); + TensorDimKey target(scan->state_placeholder[i], k); + IterVar sp_iv = scan->spatial_axis_[sp_idx]; + if (fail_set.count(sp_iv.get()) || + !exact_reach.count(key) || + exact_reach.at(key) != sp_iv.get()) { + ret.Set(sp_iv, make_const(Int(32), 0)); + } else { + // now we proved exact match, need to prove no interference with other graph. + if (reach.size() == 0) reach = GetReachGraph(body); + // do a DFS + std::unordered_set visited; + std::vector stack{key}; + visited.insert(key); + while (!stack.empty()) { + TensorDimKey k = stack.back(); + if (k != target && place_holder_ref.count(k)) break; + stack.pop_back(); + if (!reach.count(k)) { + LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim; + } + + for (TensorDimKey kk : reach.at(k)) { + if (visited.count(kk)) { + continue; + } + visited.insert(kk); + stack.push_back(kk); + } + } + if (!stack.empty()) { + // failed the prove. + ret.Set(sp_iv, make_const(Int(32), 0)); + } else { + ret.Set(sp_iv, make_const(Int(32), 1)); + } + } + } + } + return ret; +} + } // namespace schedule } // namespace tvm diff --git a/src/schedule/graph.h b/src/schedule/graph.h index 5a40c8e4ce0f..4b4b2df6e747 100644 --- a/src/schedule/graph.h +++ b/src/schedule/graph.h @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace tvm { @@ -19,6 +20,16 @@ namespace schedule { */ using ReadGraph = Map >; +/*! + * \brief The map beteen tensor and operation it feeds to + */ +using FeedGraph = std::unordered_map >; + +/*! + * \brief AttachPath maps op-> a list of IterVar + */ +using AttachPath = Map >; + /*! * \brief Get read graph of each operation to all the * Tensors that it directly depends on. @@ -41,6 +52,49 @@ ReadGraph CreateReadGraph(const Array& roots); Array PostDFSOrder( const Array& roots, const ReadGraph& g); +/*! + * \brief Create feedgraph for given Schedule + * \param g The read graph. + * \return The created feedgraph. + */ +FeedGraph CreateFeedGraph(const ReadGraph& g); + +/*! + * \brief Create AttachPath that maps op-> a list of IterVar + * That represents the loop nest op sits in from inner most to outermost + * Also inserts attach_stage for scan updates when needed. + * + * \param sch The schedule. + * \return The attach path. + */ +AttachPath CreateAttachPath(Schedule sch); + +/*! + * \brief Get all operations inside the recursion of scan. + * \param scan The scan node. + * \param feed_graph The feed graph to help analysis. + * \return The body operations, in read dependency order. + */ +Array ScanGetBody_( + const ScanOpNode* scan, const FeedGraph& feed_graph); +// same as ScanGetBody_, but create FeedGraph internally. +Array ScanGetBody(const Operation& scan); + +/*! + * \brief Analyze each spatial dimension of scan's result. + * Give check on whether each dimension is fix point, + * An axis is a fixed point if it only refers back to itself in recursion + * and it is not used in axis of other recursion field. + * + * next_state[t, ..., axis, ...] = f(prev_state[t-1, ...,axis,...] + * + * \param scan The scan node. + * \param body The body of scan, sorted in reverse PostDFSOrder. + * \return Map of spatial_axis -> IntImm + */ +Map ScanFixPointAnalysis( + const Operation& scan, const Array& body); + } // namespace schedule } // namespace tvm diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc new file mode 100644 index 000000000000..9a44a9641d37 --- /dev/null +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -0,0 +1,312 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file schedule_dataflow_rewrite.cc + */ +#include +#include +#include +#include + +namespace tvm { + +// find first occurance location in leaf +template +size_t FindNodeRef(ArrayNode* array_node, const T& v) { + const Node* n = v.get(); + for (size_t i = 0; i < array_node->data.size(); ++i) { + if (array_node->data[i].get() == n) return i; + } + return array_node->data.size(); +} + +using ir::TensorKey; + +// The replacer of cache. +class TensorReplacer : public ir::IRMutator { + public: + explicit TensorReplacer(const std::unordered_map& vmap) + : vmap_(vmap) {} + Expr Mutate_(const ir::Call* op, const Expr& e) { + if (op->call_type == ir::Call::Halide) { + ir::TensorKey key{op->func, op->value_index}; + auto it = vmap_.find(key); + if (it != vmap_.end()) { + Expr ret = ir::Call::make( + op->type, it->second->op->name, op->args, + op->call_type, it->second->op, it->second->value_index); + found = true; + return IRMutator::Mutate_(ret.as(), ret); + } + } + return IRMutator::Mutate_(op, e); + } + + // whether it is found. + bool found{false}; + + private: + const std::unordered_map& vmap_; +}; + +class VarReplacer : public ir::IRMutator { + public: + explicit VarReplacer( + const std::unordered_map& vsub) + : vsub_(vsub) {} + Expr Mutate_(const Variable* op, const Expr& e) { + auto it = vsub_.find(op); + if (it != vsub_.end()) return it->second; + return e; + } + + private: + const std::unordered_map& vsub_; +}; + +// Replace data flow appears in all stages given the tensor change. +// Also update vmap if subsequent dataflow need to be replaced. +void ReplaceDataFlow(const Array& stages, + std::unordered_map* vmap) { + for (Stage s : stages) { + if (s->op.as()) { + const ComputeOpNode* compute = s->op.as(); + TensorReplacer repl(*vmap); + Expr body = repl.Mutate(compute->body); + if (repl.found) { + Operation op = ComputeOpNode::make( + compute->name, compute->axis, body); + (*vmap)[TensorKey{s->op, 0}] = op.output(0); + s->op = op; + } + } else if (s->op.as()) { + const ScanOpNode* scan = s->op.as(); + std::shared_ptr n = + std::make_shared(*scan); + // copy on write semantics ganrantees correctness + for (size_t i = 0; i < n->init.size(); ++i) { + TensorKey key{n->init[i]->op, n->init[i]->value_index}; + if (vmap->count(key)) { + n->init.Set(i, vmap->at(key)); + } + } + for (size_t i = 0; i < n->update.size(); ++i) { + TensorKey key{n->update[i]->op, n->update[i]->value_index}; + if (vmap->count(key)) { + n->update.Set(i, vmap->at(key)); + } + } + if (!n->init.same_as(scan->init) || + !n->update.same_as(scan->update)) { + Operation op(n); + for (int i = 0; i < op->num_outputs(); ++i) { + (*vmap)[TensorKey{s->op, i}] = op.output(i); + } + s->op = op; + } + } else if (s->op.as()) { + } else { + LOG(FATAL) << "unhandled problem"; + } + } +} + +Tensor Schedule::cache_read(const Tensor& tensor, + const std::string& scope, + const Array& readers) { + // create identity mapping. + std::ostringstream os; + os << tensor->op->name; + if (tensor->op->num_outputs() != 1) { + os << ".v" << tensor->value_index; + } + os << "." << scope; + + Tensor cache = compute(tensor->shape, [&tensor](const Array& i) { + return tensor(Array(i.begin(), i.end())); + }, os.str()); + std::unordered_map vsub; + vsub[TensorKey{tensor->op, tensor->value_index}] = cache; + + std::unordered_map vmap; + for (Operation op : readers) { + const ComputeOpNode* compute = op.as(); + CHECK(compute) + << "cache read only take ComputeOp as readers"; + Stage s = operator[](op); + compute = s->op.as(); + + TensorReplacer repl(vsub); + Expr body = repl.Mutate(compute->body); + CHECK(repl.found) + << "Cannot find " << tensor + << " in the body of specified reader " << op; + Operation repl_op = ComputeOpNode::make( + compute->name, compute->axis, body); + vmap[TensorKey{s->op, 0}] = repl_op.output(0); + s->op = repl_op; + } + ReplaceDataFlow((*this)->stages, &vmap); + ArrayNode* stages = (*this)->stages.CopyOnWrite(); + size_t pos = FindNodeRef(stages, operator[](tensor->op)); + Stage cache_stage = Stage(cache->op); + cache_stage.set_scope(scope); + CHECK_LT(pos, stages->data.size()); + stages->data.insert(stages->data.begin() + pos + 1, + cache_stage.node_); + (*this)->stage_map.Set(cache->op, cache_stage); + return cache; +} + +Tensor Schedule::cache_write(const Tensor& tensor, + const std::string& scope) { + Stage orig_stage = operator[](tensor->op); + const ComputeOpNode* compute = tensor->op.as(); + CHECK(compute) + << "cache write only take ComputeOp as writers"; + CHECK_EQ(orig_stage->relations.size(), 0U) + << "Create cache_write before doing split/fuse/reorder"; + compute = orig_stage->op.as(); + CHECK(compute); + Array args; + Array new_axis; + std::unordered_map vsub; + for (IterVar iv : compute->axis) { + args.push_back(iv->var); + IterVar new_iv(iv->dom, iv->var->name_hint + ".c"); + new_axis.push_back(new_iv); + vsub[iv->var.get()] = new_iv->var; + } + VarReplacer repl(vsub); + Expr body = repl.Mutate(compute->body); + Operation cache_op = ComputeOpNode::make( + compute->name + "." + scope, new_axis, body); + Tensor cache_tensor = cache_op.output(0); + Operation orig_new_op = ComputeOpNode::make( + compute->name, compute->axis, + cache_tensor(args)); + + std::unordered_map vmap; + vmap[TensorKey{orig_stage->op, 0}] = orig_new_op.output(0); + ReplaceDataFlow((*this)->stages, &vmap); + + // mutate orig stage + orig_stage->op = orig_new_op; + orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); + orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; + // create schedule for new cached stage. + ArrayNode* stages = (*this)->stages.CopyOnWrite(); + size_t pos = FindNodeRef(stages, orig_stage); + Stage cache_stage = Stage(cache_op); + cache_stage.set_scope(scope); + CHECK_LT(pos, stages->data.size()); + stages->data.insert(stages->data.begin() + pos, + cache_stage.node_); + (*this)->stage_map.Set(cache_op, cache_stage); + return cache_tensor; +} + + +void RebaseNonZeroMinLoop(const Schedule& sch) { + std::unordered_map rebase_map; + std::unordered_map attach_mark; + + for (Stage s : sch->stages) { + if (s->attach_type == kScope) { + attach_mark[s->attach_stage.get()] = 1; + } + if (s->op.as()) { + attach_mark[s.get()] = 1; + } + } + + for (Stage s : sch->stages) { + if (!attach_mark.count(s.get())) continue; + auto root_iter_vars = s->op->root_iter_vars(); + ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite(); + for (IterVar iv : root_iter_vars) { + size_t idx = FindNodeRef(leaf_vars, iv); + if (idx < leaf_vars->data.size()) { + // insert rebase + IterVar rebased(Range(), iv->var->name_hint + ".rb"); + s->relations.push_back(RebaseNode::make(iv, rebased)); + leaf_vars->data[idx] = rebased.node_; + rebase_map[iv] = rebased; + } + } + } + // remap the parent relation + for (Stage s : sch->stages) { + if (s->attach_type != kScope) continue; + if (rebase_map.count(s->attach_ivar)) { + s->attach_ivar = rebase_map.at(s->attach_ivar); + } + } +} + +void SetScanAttach(const Schedule& sch) { // NOLINT(*) + for (Stage stage : sch->stages) { + if (stage->attach_type == kScanUpdate) { + const Stage& parent = stage->attach_stage; + stage->attach_ivar = + parent->leaf_iter_vars[parent->leaf_iter_vars.size() - 1]; + } + } +} + + +void InjectInline(const Schedule& sch) { + std::vector new_body(sch->stages.size()); + // inline all the ops + for (size_t i = sch->stages.size(); i != 0; --i) { + Stage stage = sch->stages[i - 1]; + if (stage->attach_type == kInline) { + stage->attach_type = kInlinedAlready; + Array args; + Expr body; + { + // setup args + const ComputeOpNode* compute = stage->op.as(); + CHECK(compute) + << "can only inline compute op"; + for (auto iv : compute->axis) { + args.push_back(iv->var); + } + body = compute->body; + } + for (size_t j = i; j < sch->stages.size(); ++j) { + Stage s = sch->stages[j]; + const ComputeOpNode* compute = s->op.as(); + if (compute) { + if (!new_body[j].defined()) { + new_body[j] = s->op.as()->body; + } + new_body[j] = ir::Inline(ir::Evaluate::make(new_body[j]), + stage->op, args, body).as()->value; + } + } + } + } + std::unordered_map repl; + // rewrite dataflow + for (size_t i = 0; i < sch->stages.size(); ++i) { + if (new_body[i].defined() && + !new_body[i].same_as(sch->stages[i]->op)) { + const ComputeOpNode* compute = sch->stages[i]->op.as(); + CHECK(compute); + Operation op = ComputeOpNode::make( + compute->name, compute->axis, new_body[i]); + repl[TensorKey{sch->stages[i]->op, 0}] = op.output(0); + Stage s = sch->stages[i]; + s->op = op; + } + } + ReplaceDataFlow(sch->stages, &repl); +} + +void Schedule::normalize() { + RebaseNonZeroMinLoop(*this); + SetScanAttach(*this); + InjectInline(*this); +} + +} // namespace tvm diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index b18ae28e5475..308070a8b702 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2016 by Contributors - * \file schedule.cc + * \file schedule_lang.cc */ #include #include @@ -37,6 +37,10 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) void Split(StageNode* self, IterVar parent, IterVar outer, IterVar inner, Expr factor) { + if (self->attach_type == kScanUpdate) { + CHECK(!parent.same_as(self->all_iter_vars[0])) + << "Cannot split on axis[0] of scan update"; + } ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); size_t pos = FindLeafVar(all_vars, leaf_vars, parent); @@ -83,6 +87,8 @@ Stage& Stage::set_scope(std::string scope) { // NOLINT(*) } Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) + CHECK_NE((*this)->attach_type, kScanUpdate) + << "Cannot specify compute_at for scan updates"; (*this)->attach_type = kScope; (*this)->attach_ivar = scope; (*this)->attach_stage = parent; @@ -93,16 +99,22 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) } } CHECK(found) - << "Cannot find the axis in parent's leaf_iter_vars or outermost_threads"; + << "Cannot find the axis " << scope + << " in parent's leaf_iter_vars or outermost_threads:" + << " parent=" << parent; return *this; } Stage& Stage::compute_inline() { // NOLINT(*) + CHECK_NE((*this)->attach_type, kScanUpdate) + << "Cannot specify compute_at for scan updates"; (*this)->attach_type = kInline; return *this; } Stage& Stage::compute_root() { // NOLINT(*) + CHECK_NE((*this)->attach_type, kScanUpdate) + << "Cannot specify compute_at for scan updates"; (*this)->attach_type = kRoot; return *this; } @@ -128,9 +140,15 @@ Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor } Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*) + StageNode* self = operator->(); + if (self->attach_type == kScanUpdate) { + CHECK(!inner.same_as(self->all_iter_vars[0])) + << "Cannot split on axis[0] of scan update"; + CHECK(!outer.same_as(self->all_iter_vars[0])) + << "Cannot split on axis[0] of scan update"; + } IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused"); *p_target = fused; - StageNode* self = operator->(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); @@ -157,6 +175,10 @@ Stage& Stage::reorder(const Array& order) { // NOLINT(*) std::vector pos; for (size_t i = 0; i < order.size(); ++i) { + if ((*this)->attach_type == kScanUpdate) { + CHECK(!order[i].same_as(self->all_iter_vars[0])) + << "Cannot split on axis[0] of scan update"; + } pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i])); } std::vector > temp; @@ -239,12 +261,25 @@ Schedule::Schedule(Array ops) { stage->is_output = output_set.count(op); n->stages.push_back(stage); n->stage_map.Set(op, stage); + // mark scan updates. + if (op.as()) { + const ScanOpNode* scan = op.as(); + for (size_t i = 0; i < scan->update.size(); ++i) { + Stage s = n->stage_map[scan->update[i]->op]; + s->attach_type = kScanUpdate; + s->attach_stage = stage; + } + } } node_ = std::move(n); } Stage Schedule::operator[](const Operation& op) { - return (*this)->stage_map.at(op); + auto it = (*this)->stage_map.find(op); + CHECK(it != (*this)->stage_map.end()) + << "Cannot find Stage for operator " << op + << " in the schedule"; + return (*it).second; } IterVarRelation SplitNode::make( @@ -274,42 +309,6 @@ IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) { return IterVarRelation(n); } -void Schedule::normalize() { - std::unordered_map rebase_map; - std::unordered_map attach_mark; - - - for (Stage s : (*this)->stages) { - if (s->attach_type == kScope) { - attach_mark[s->attach_stage.get()] = 1; - } - } - - for (Stage s : (*this)->stages) { - if (!attach_mark.count(s.get())) continue; - auto root_iter_vars = s->op->root_iter_vars(); - ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite(); - - for (IterVar iv : root_iter_vars) { - size_t idx = FindNodeRef(leaf_vars, iv); - if (idx < leaf_vars->data.size()) { - // insert rebase - IterVar rebased(Range(), iv->var->name_hint + ".rb"); - s->relations.push_back(RebaseNode::make(iv, rebased)); - leaf_vars->data[idx] = rebased.node_; - rebase_map[iv] = rebased; - } - } - } - // remap the parent relation - for (Stage s : (*this)->stages) { - if (s->attach_type != kScope) continue; - if (rebase_map.count(s->attach_ivar)) { - s->attach_ivar = rebase_map.at(s->attach_ivar); - } - } -} - IterVarAttr::IterVarAttr(IterVarType t) { std::shared_ptr n = std::make_shared(); n->iter_type = t; @@ -323,190 +322,4 @@ TVM_REGISTER_NODE_TYPE(FuseNode); TVM_REGISTER_NODE_TYPE(RebaseNode); TVM_REGISTER_NODE_TYPE(ScheduleNode); -using ir::TensorKey; - -// The replacer of cache. -class TensorReplacer : public ir::IRMutator { - public: - TensorReplacer(const std::unordered_map& vmap) - : vmap_(vmap) {} - Expr Mutate_(const ir::Call* op, const Expr& e) { - if (op->call_type == ir::Call::Halide) { - ir::TensorKey key{op->func, op->value_index}; - auto it = vmap_.find(key); - if (it != vmap_.end()) { - Expr ret = ir::Call::make( - op->type, it->second->op->name, op->args, - op->call_type, it->second->op, it->second->value_index); - found = true; - return IRMutator::Mutate_(ret.as(), ret); - } - } - return IRMutator::Mutate_(op, e); - } - - // whether it is found. - bool found{false}; - - private: - const std::unordered_map& vmap_; -}; - -class VarReplacer : public ir::IRMutator { - public: - explicit VarReplacer( - const std::unordered_map& vsub) - : vsub_(vsub) {} - Expr Mutate_(const Variable* op, const Expr& e) { - auto it = vsub_.find(op); - if (it != vsub_.end()) return it->second; - return e; - } - - private: - const std::unordered_map& vsub_; -}; - -// Replace data flow appears in all stages given the tensor change. -// Also update vmap if subsequent dataflow need to be replaced. -void ReplaceDataFlow(const Array& stages, - std::unordered_map* vmap) { - for (Stage s : stages) { - if (s->op.as()) { - const ComputeOpNode* compute = s->op.as(); - TensorReplacer repl(*vmap); - Expr body = repl.Mutate(compute->body); - if (repl.found) { - Operation op = ComputeOpNode::make( - compute->name, compute->axis, body); - (*vmap)[TensorKey{s->op, 0}] = op.output(0); - s->op = op; - } - } else if (s->op.as()) { - const ScanOpNode* scan = s->op.as(); - std::shared_ptr n = - std::make_shared(*scan); - // copy on write semantics ganrantees correctness - for (size_t i = 0; i < n->init.size(); ++i) { - TensorKey key{n->init[i]->op, n->init[i]->value_index}; - if (vmap->count(key)) { - n->init.Set(i, vmap->at(key)); - } - } - for (size_t i = 0; i < n->update.size(); ++i) { - TensorKey key{n->update[i]->op, n->update[i]->value_index}; - if (vmap->count(key)) { - n->update.Set(i, vmap->at(key)); - } - } - if (!n->init.same_as(scan->init) || - !n->update.same_as(scan->update)) { - Operation op(n); - for (int i = 0; i < op->num_outputs(); ++i) { - (*vmap)[TensorKey{s->op, i}] = op.output(i); - } - s->op = op; - } - } else if (s->op.as()) { - } else { - LOG(FATAL) << "unhandled problem"; - } - } -} - -Tensor Schedule::cache_read(const Tensor& tensor, - const std::string& scope, - const Array& readers) { - // create identity mapping. - std::ostringstream os; - os << tensor->op->name; - if (tensor->op->num_outputs() != 1) { - os << ".v" << tensor->value_index; - } - os << "." << scope; - - Tensor cache = compute(tensor->shape, [&tensor](const Array& i) { - return tensor(Array(i.begin(), i.end())); - }, os.str()); - std::unordered_map vsub; - vsub[TensorKey{tensor->op, tensor->value_index}] = cache; - - std::unordered_map vmap; - for (Operation op : readers) { - const ComputeOpNode* compute = op.as(); - CHECK(compute) - << "cache read only take ComputeOp as readers"; - Stage s = operator[](op); - compute = s->op.as(); - - TensorReplacer repl(vsub); - Expr body = repl.Mutate(compute->body); - CHECK(repl.found) - << "Cannot find " << tensor - << " in the body of specified reader" << op; - Operation repl_op = ComputeOpNode::make( - compute->name, compute->axis, body); - vmap[TensorKey{s->op, 0}] = repl_op.output(0); - s->op = repl_op; - } - ReplaceDataFlow((*this)->stages, &vmap); - ArrayNode* stages = (*this)->stages.CopyOnWrite(); - size_t pos = FindNodeRef(stages, operator[](tensor->op)); - Stage cache_stage = Stage(cache->op); - cache_stage.set_scope(scope); - CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos + 1, - cache_stage.node_); - (*this)->stage_map.Set(cache->op, cache_stage); - return cache; -} - -Tensor Schedule::cache_write(const Tensor& tensor, - const std::string& scope) { - Stage orig_stage = operator[](tensor->op); - const ComputeOpNode* compute = tensor->op.as(); - CHECK(compute) - << "cache write only take ComputeOp as writers"; - CHECK(!orig_stage.is_scheduled()) - << "Create cache_write before doing split/fuse/reorder"; - compute = orig_stage->op.as(); - CHECK(compute); - Array args; - Array new_axis; - std::unordered_map vsub; - for (IterVar iv : compute->axis) { - args.push_back(iv->var); - IterVar new_iv(iv->dom, iv->var->name_hint + ".c"); - new_axis.push_back(new_iv); - vsub[iv->var.get()] = new_iv->var; - } - VarReplacer repl(vsub); - Expr body = repl.Mutate(compute->body); - Operation cache_op = ComputeOpNode::make( - compute->name + "." + scope, new_axis, body); - Tensor cache_tensor = cache_op.output(0); - Operation orig_new_op = ComputeOpNode::make( - compute->name, compute->axis, - cache_tensor(args)); - - std::unordered_map vmap; - vmap[TensorKey{orig_stage->op, 0}] = orig_new_op.output(0); - ReplaceDataFlow((*this)->stages, &vmap); - - // mutate orig stage - orig_stage->op = orig_new_op; - orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); - orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; - // create schedule for new cached stage. - ArrayNode* stages = (*this)->stages.CopyOnWrite(); - size_t pos = FindNodeRef(stages, orig_stage); - Stage cache_stage = Stage(cache_op); - cache_stage.set_scope(scope); - CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos, - cache_stage.node_); - (*this)->stage_map.Set(cache_op, cache_stage); - return cache_tensor; -} - } // namespace tvm diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index aa7c383635ef..4b7c7f886d45 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -369,7 +369,7 @@ Stmt MakeRealize(const ScanOpNode* op, CHECK_EQ(static_cast(t->value_index), i); Halide::Internal::Region bounds; bounds.push_back(tdom); - for (size_t k = 0; k < op->update[i]->shape.size(); ++k, ++sp_idx) { + for (size_t k = 1; k < op->update[i]->shape.size(); ++k, ++sp_idx) { IterVar sp_ax = op->spatial_axis_[sp_idx]; bounds.push_back(dom_map.at(sp_ax)); } @@ -561,6 +561,7 @@ class InjectScanStep : public IRMutator { Stmt InjectInline(const Operation op, Stmt body) { CHECK(body.defined()); + const ComputeOpNode* compute = op.as(); CHECK(compute != nullptr) << "can only inline compute op"; @@ -614,7 +615,7 @@ class SchedulePostProc : public IRMutator { if (it->second.defined()) { Stmt ret = AttrStmt::make( it->second, op->type_key, op->value, op->body); - return this->Mutate_(ret.as(), ret); + return this->Mutate(ret); } else { return this->Mutate(op->body); } @@ -631,7 +632,7 @@ class SchedulePostProc : public IRMutator { Stmt ret = Realize::make( it->second->op, it->second->value_index, op->type, op->bounds, op->condition, op->body); - return this->Mutate_(ret.as(), ret); + return this->Mutate(ret); } else { return this->Mutate(op->body); } @@ -644,11 +645,10 @@ class SchedulePostProc : public IRMutator { TensorKey key{op->func, op->value_index}; auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { - const Tensor& dst = it->second.first; + const Tensor& dst = it->second; Stmt ret = Provide::make( - dst->op, dst->value_index, op->value, - RewriteArgs(it->second.second, op->args)); - return IRMutator::Mutate_(ret.as(), ret); + dst->op, dst->value_index, op->value, op->args); + return this->Mutate(ret); } else { return IRMutator::Mutate_(op, s); } @@ -659,12 +659,11 @@ class SchedulePostProc : public IRMutator { TensorKey key{op->func, op->value_index}; auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { - const Tensor& dst = it->second.first; + const Tensor& dst = it->second; Expr ret = Call::make( - op->type, dst->op->name, - RewriteArgs(it->second.second, op->args), + op->type, dst->op->name, op->args, op->call_type, dst->op, dst->value_index); - return IRMutator::Mutate_(ret.as(), ret); + return this->Mutate(ret); } } return IRMutator::Mutate_(op, e); @@ -685,14 +684,14 @@ class SchedulePostProc : public IRMutator { const ScanOpNode* scan = s->op.as(); for (size_t i = 0; i < scan->update.size(); ++i) { Tensor t = s->origin_op.output(i); - AddReplace(scan->init[i], t, Expr()); - AddReplace(scan->update[i], t, scan->scan_axis->var); - AddReplace(scan->state_placeholder[i], t, Expr()); + AddReplace(scan->init[i], t); + AddReplace(scan->update[i], t); + AddReplace(scan->state_placeholder[i], t); } } else if (!s->op.same_as(s->origin_op)) { Tensor target = s->origin_op.output(0); AddReplace(s->op.output(0), target, - Expr(), target, s->origin_op); + target, s->origin_op); } } } @@ -700,26 +699,17 @@ class SchedulePostProc : public IRMutator { private: void AddReplace(Tensor src, Tensor dst, - Expr head_idx, Tensor repl_realize = Tensor(), Operation repl_op = Operation()) { TensorKey key{src->op, src->value_index}; - replace_buffer_[key] = std::make_pair(dst, head_idx); + replace_buffer_[key] = dst; replace_realize_[key] = repl_realize; replace_op_[src->op.get()] = repl_op; } - Array RewriteArgs(Expr head, Array args) { - if (!head.defined()) return args; - Array new_args{head}; - for (Expr e : args) { - new_args.push_back(e); - } - return new_args; - } // The scan value std::unordered_map var_value_; // buffer replacement - std::unordered_map > replace_buffer_; + std::unordered_map replace_buffer_; // buffere realization to be replaced std::unordered_map replace_realize_; // replace producer consumer. @@ -755,10 +745,13 @@ Stmt ScheduleOps( // reverse the post DFS order. for (size_t i = sch->stages.size(); i != 0; --i) { Stage s = sch->stages[i - 1]; + CHECK_NE(s->attach_type, kInline) + << "call schedule.normalize before scheduleops"; // no need to specify place holder op. if (s->op.as()) continue; if (scan_attach.count(s->op)) { - CHECK(s->attach_type == kNone || s->attach_type == kInline) + CHECK(s->attach_type == kNone || + s->attach_type == kScanUpdate) << "Cannot specify compute_at for scan's init/update"; CHECK(body.defined()); const auto& p = scan_attach.at(s->op); @@ -766,8 +759,8 @@ Stmt ScheduleOps( body = mu.Mutate(body); CHECK(mu.found_attach) << "did not find attachment point for scan.init/update"; - } else if (s->attach_type == kInline) { - body = InjectInline(s->op, body); + } else if (s->attach_type == kInlinedAlready) { + // do nothing } else if (s->attach_type == kRoot || s-> attach_type == kNone) { body = MakePipeline(s, dom_map, body); } else if (s->attach_type == kScope) { diff --git a/tests/python/integration/test_scan.py b/tests/python/integration/test_scan.py index 38cd832f2e43..08adab491639 100644 --- a/tests/python/integration/test_scan.py +++ b/tests/python/integration/test_scan.py @@ -8,8 +8,8 @@ def test_scan(): X = tvm.placeholder((m, n), name="X") s_state = tvm.placeholder((m, n)) s_init = tvm.compute((1, n), lambda _, i: X[0, i]) - s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i]) - res = tvm.scan(t, s_init, s_update, s_state) + s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i]) + res = tvm.scan(s_init, s_update, s_state) # schedule s = tvm.Schedule(res.op) @@ -18,7 +18,7 @@ def test_scan(): thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") _, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x) _, x = s[s_init].split(x, outer=thread_x) - _, x = s[s_update].split(s_update.op.axis[0], factor=num_thread, outer=block_x) + _, x = s[s_update].split(s_update.op.axis[1], factor=num_thread, outer=block_x) _, x = s[s_update].split(x, outer=thread_x) # one line to build the function. diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 3459e80e918c..c5dfb748df7e 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -40,9 +40,8 @@ def test_tensor_scan(): t = tvm.IterVar((1, m), "t") x = tvm.placeholder((m, n)) s = tvm.placeholder((m, n)) - res = tvm.scan(t, - tvm.compute((1, n), lambda _, i: x[0, i]), - tvm.compute((n,), lambda i: s[t-1, i] + x[t, i]), + res = tvm.scan(tvm.compute((1, n), lambda _, i: x[0, i]), + tvm.compute((m, n), lambda t, i: s[t-1, i] + x[t, i]), s) assert tuple(res.shape) == (m, n) diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index e80fb275c561..3e187766f516 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -50,25 +50,30 @@ def test_bound3(): assert(bounds[A1.op.axis[0]].extent.value==32) assert(bounds[A1.op.axis[1]].extent.value==16) +def test_bound_scan(): + m = tvm.Var("m") + n = tvm.Var("n") + X = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x") + s_state = tvm.placeholder((m, n)) + s_init = tvm.compute((1, n), lambda _, i: X[0, i]) + s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i]) + s_scan = tvm.scan(s_init, s_update, s_state) -def test_create_read_graph(): - m = tvm.Var('m') - l = tvm.Var('l') - A = tvm.placeholder((m, l), name='A') - A1 = tvm.compute((m, l), lambda i, j: A[i, j]) - A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3) + assert tuple(s_scan.shape) == (m, n) - g = tvm.schedule.CreateReadGraph([A2.op]) + s = tvm.Schedule(s_scan.op) + XX = s.cache_read(X, "local", s_update) + xo, xi = s[s_update].split(s_update.op.axis[1], factor=4) + s[XX].compute_at(s[s_update], xo) - assert g[A2.op][0] == A1 - assert g[A1.op][0] == A - post_order = tvm.schedule.PostDFSOrder([A2.op], g) - assert(post_order[0] == A.op) - assert(post_order[1] == A1.op) + s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + assert bounds[XX.op.axis[1]].extent.value == 4 if __name__ == "__main__": - test_create_read_graph() + test_bound_scan() test_bound3() test_bound1() test_bound2() diff --git a/tests/python/unittest/test_schedule_graph.py b/tests/python/unittest/test_schedule_graph.py new file mode 100644 index 000000000000..2d1af01d710d --- /dev/null +++ b/tests/python/unittest/test_schedule_graph.py @@ -0,0 +1,101 @@ +import tvm + +def test_scan(): + m = tvm.Var("m") + n = tvm.Var("n") + x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x") + s_state = tvm.placeholder((m, n)) + s_init = tvm.compute((1, n), lambda _, i: x[0, i], name="s_init") + x_trans = tvm.compute((m, n), lambda i, j: x[i, j] + 1, name="x_trans") + s_up1 = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + 1, name="up1") + s_update = tvm.compute((m, n), lambda t, i: s_up1[t, i] + x_trans[t, i], name="update") + s_scan = tvm.scan(s_init, s_update, s_state) + + def test_getbody(): + body = tvm.schedule.ScanGetBody(s_scan.op) + assert set(body) == set([s_scan.op, s_update.op, s_up1.op]) + + def test_attach_path(): + s = tvm.Schedule(s_scan.op) + s[x_trans].compute_at(s[s_update], s_update.op.axis[0]) + apath = tvm.schedule.CreateAttachPath(s) + assert(tuple(apath[s_update.op]) == tuple([s_scan.op.scan_axis])) + assert(tuple(apath[x_trans.op]) == tuple([s_update.op.axis[0], s_scan.op.scan_axis])) + + def test_fix_pt(): + body = tvm.schedule.ScanGetBody(s_scan.op) + fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body) + assert(fxpt[s_scan.spatial_axis_[0]].value != 0) + +def test_scan_fix_point(): + m = tvm.Var("m") + n = tvm.Var("n") + l = tvm.Var("l") + x = tvm.compute((l, m, n), lambda *i: tvm.const(1, "float32"), name="x") + s_state = tvm.placeholder((l, m, n)) + s_init = tvm.compute((1, m, n), lambda _, i, j: x[0, i, j], name="s_init") + + def test_scan0(): + s_update = tvm.compute((l, m, n), + lambda t, i, j: x[t, j, i] + s_state[t-1, i, j], name="update") + s_scan = tvm.scan(s_init, s_update, s_state) + body = tvm.schedule.ScanGetBody(s_scan.op) + fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body) + assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1) + assert(fxpt[s_scan.op.spatial_axis_[1]].value == 1) + + def test_scan1(): + s_update = tvm.compute((l, m, n), + lambda t, i, j: x[t, j, i] + s_state[t-1, j, i], name="update") + s_scan = tvm.scan(s_init, s_update, s_state) + body = tvm.schedule.ScanGetBody(s_scan.op) + fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body) + assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0) + assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0) + + def test_scan3_not_exact_reach(): + s_h1 = tvm.compute((l, n, m), lambda t, j, i: s_state[t-1, i, j], name="h1") + s_h2 = tvm.compute((l, m, n), lambda t, i, j: s_state[t-1, i, 10] * 2, name="h1") + s_update = tvm.compute((l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update") + s_scan = tvm.scan(s_init, s_update, s_state) + body = tvm.schedule.ScanGetBody(s_scan.op) + fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body) + assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1) + assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0) + + def test_scan4_reach_other(): + s_h1 = tvm.compute((l, n, m), lambda t, j, i: s_state[t-1, j, j], name="h1") + s_h2 = tvm.compute((l, m, n), lambda t, i, j: s_state[t-1, i, j] * 2, name="h1") + s_update = tvm.compute((l, m, n), + lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update") + s_scan = tvm.scan(s_init, s_update, s_state) + body = tvm.schedule.ScanGetBody(s_scan.op) + fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body) + assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0) + assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0) + + test_scan0() + test_scan1() + test_scan3_not_exact_reach() + test_scan4_reach_other() + +def test_create_read_graph(): + m = tvm.Var('m') + l = tvm.Var('l') + A = tvm.placeholder((m, l), name='A') + A1 = tvm.compute((m, l), lambda i, j: A[i, j]) + A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3) + + g = tvm.schedule.CreateReadGraph([A2.op]) + + assert g[A2.op][0] == A1 + assert g[A1.op][0] == A + post_order = tvm.schedule.PostDFSOrder([A2.op], g) + assert(post_order[0] == A.op) + assert(post_order[1] == A1.op) + + +if __name__ == "__main__": + test_scan() + test_create_read_graph() + test_scan_fix_point() diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 625bee596414..f24e7ffd1b64 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -43,13 +43,11 @@ def test_schedule2(): def test_schedule_scan(): m = tvm.Var("m") n = tvm.Var("n") - l = tvm.Var("l") - t = tvm.IterVar((1, m), name="t") x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x") s_state = tvm.placeholder((m, n)) s_init = tvm.compute((1, n), lambda _, i: x[0, i]) - s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + x[t, i]) - res = tvm.scan(t, s_init, s_update, s_state) + s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i]) + res = tvm.scan(s_init, s_update, s_state) assert tuple(res.shape) == (m, n) s = tvm.Schedule(res.op) @@ -59,7 +57,6 @@ def test_schedule_scan(): stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt) - def test_auto_inline(): m = tvm.Var('m') n = tvm.Var('n') @@ -71,9 +68,27 @@ def test_auto_inline(): s = tvm.Schedule(T2.op) tvm.schedule.AutoInlineElemWise(s) + s.normalize() bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) +def test_inline_mixed(): + n = tvm.Var('n') + A = tvm.placeholder((n, ), name='A') + A1 = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='A1') + A2 = tvm.compute(A.shape, lambda *i: A1(*i) + 2, name='A2') + C = tvm.compute((n,), lambda i: A2[i] + A1[i], name='C') + + s = tvm.Schedule(C.op) + xo, xi = s[C].split(C.op.axis[0], factor=8) + s[A1].compute_at(s[C], xo) + s[A2].compute_inline() + s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + print(stmt) + + def test_schedule_cache(): m = tvm.Var('m') n = tvm.Var('n') @@ -90,9 +105,10 @@ def test_schedule_cache(): if __name__ == "__main__": + test_inline_mixed() + test_auto_inline() test_schedule_scan() test_schedule0() test_schedule1() test_schedule2() - test_auto_inline() test_schedule_cache()