From 1c256f48a415e7c775cbf2a892a3d8ca29e3d25d Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 12 Jun 2020 17:23:05 -0700 Subject: [PATCH] [TIR][REFACTOR] Cleanup unused classes (#5789) --- include/tvm/arith/bound.h | 8 ++----- include/tvm/te/operation.h | 8 ++++--- include/tvm/te/tensor.h | 5 ++-- include/tvm/tir/expr.h | 34 --------------------------- include/tvm/tir/var.h | 2 -- src/arith/domain_touched.cc | 6 ++--- src/contrib/hybrid/codegen_hybrid.cc | 2 +- src/te/schedule/graph.cc | 12 +++++----- src/tir/transforms/inject_prefetch.cc | 2 +- 9 files changed, 21 insertions(+), 58 deletions(-) diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index df1a9e7c7a43..12b91cc033e5 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -32,13 +32,9 @@ #include namespace tvm { -// forward delcare Tensor -namespace te { -class Tensor; -} namespace arith { -using tir::Domain; +using tir::Region; using tir::Stmt; using tir::Var; using tir::VarNode; @@ -82,7 +78,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond, * \param consider_stores If stores are considered. * \return The domain that covers all the calls or provides within the given statement. */ -Domain DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads, +Region DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads, bool consider_stores); } // namespace arith diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 4b7037aec7dc..dbd07fa4cf69 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -53,7 +53,7 @@ struct TensorDom { /*! * \brief Base class of all operation nodes */ -class OperationNode : public tir::FunctionBaseNode { +class OperationNode : public Object { public: /*! \brief optional name of the operation */ std::string name; @@ -61,8 +61,10 @@ class OperationNode : public tir::FunctionBaseNode { std::string tag; /*! \brief additional attributes of the operation*/ Map attrs; - /*! \return name of the operation */ - const std::string& func_name() const final { return name; } + // virtual destructor. + virtual ~OperationNode() {} + /*! \return number of outputs */ + virtual int num_outputs() const = 0; /*! * \return The list of iteration variable at root * \note root_iter_vars decides the shape of the outputs. diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 0c4af4bc636a..2f9fa2f534c5 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -42,13 +42,14 @@ using namespace tvm::tir; // internal node container for Operation class OperationNode; +class Tensor; /*! \brief Operation that produces tensors */ -class Operation : public tir::FunctionRef { +class Operation : public ObjectRef { public: /*! \brief default constructor */ Operation() {} - explicit Operation(ObjectPtr n) : FunctionRef(n) {} + explicit Operation(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 423f09e1b984..4b6b28d52ee9 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -870,40 +870,6 @@ class Let : public PrimExpr { TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode); }; -// Call node, represent a function call or a multi-dimensional array load. -// -// TODO(tvm-team): -// Refactor call with more explicit property registrations. -// rather than calling a string symbol. -// We should move most information into function itself and remove name. - -/*! \brief Base node of internal functions. */ -class FunctionBaseNode : public Object { - public: - /*! \brief virtual destructor */ - virtual ~FunctionBaseNode() {} - /*! \return the name of the function */ - virtual const std::string& func_name() const = 0; - /*! \return the number of outputs of this function */ - virtual int num_outputs() const = 0; - - // fall back to pointer equality now before refactor. - bool SEqualReduce(const FunctionBaseNode* other, SEqualReducer equal) const { - return this == other; - } - - void SHashReduce(SHashReducer hash_reduce) const {} - - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; -}; - -/*! \brief reference to a function */ -class FunctionRef : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(FunctionRef, ObjectRef, FunctionBaseNode); -}; - /*! * \brief Call node. */ diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 363bf6b2eb19..9f098248b836 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -226,8 +226,6 @@ enum IterVarType : int { kTensorized = 8 }; -using Domain = Array; - /*! * \brief An iteration variable representing an iteration * over a one dimensional interval. diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 0ac4a893a77f..b44d9f7ff1f5 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -40,9 +40,9 @@ class BufferTouchedDomain final : public StmtExprVisitor { BufferTouchedDomain(const Buffer& buffer, bool consider_loads, bool consider_stores) : buffer_(buffer), consider_loads_(consider_loads), consider_stores_(consider_stores) {} - Domain Find(const Stmt& stmt) { + Region Find(const Stmt& stmt) { operator()(stmt); - Domain ret; + Region ret; Range none; for (size_t i = 0; i < bounds_.size(); ++i) { ret.push_back(arith::Union(bounds_[i]).cover_range(none)); @@ -107,7 +107,7 @@ class BufferTouchedDomain final : public StmtExprVisitor { std::unordered_map dom_map_; }; -Domain DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, +Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, bool consider_stores) { return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt); } diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index e9ec585de164..e08f39f8135d 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -414,7 +414,7 @@ std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) { if (id_map_.count(key)) { return id_map_[key]; } - std::string name_hint = tensor->op->func_name(); + std::string name_hint = tensor->op->name; if (tensor->op->num_outputs() > 1) { name_hint += "_v" + std::to_string(tensor->value_index); } diff --git a/src/te/schedule/graph.cc b/src/te/schedule/graph.cc index 62557ed8573f..09e899581d14 100644 --- a/src/te/schedule/graph.cc +++ b/src/te/schedule/graph.cc @@ -36,15 +36,15 @@ namespace tvm { namespace te { // key to specific tensor dimension. struct TensorDimKey { - tir::FunctionRef f; + Operation op; int value_index; int dim; TensorDimKey() {} - TensorDimKey(const Tensor& t, int dim) : f(t->op), value_index(t->value_index), dim(dim) {} + TensorDimKey(const Tensor& t, int dim) : op(t->op), value_index(t->value_index), dim(dim) {} TensorDimKey(const Tensor& t, size_t dim) - : f(t->op), value_index(t->value_index), dim(static_cast(dim)) {} + : op(t->op), value_index(t->value_index), dim(static_cast(dim)) {} inline bool operator==(const TensorDimKey& other) const { - return f == other.f && value_index == other.value_index && dim == other.dim; + return op == other.op && value_index == other.value_index && dim == other.dim; } inline bool operator!=(const TensorDimKey& other) const { return !operator==(other); } }; @@ -55,7 +55,7 @@ namespace std { template <> struct hash<::tvm::te::TensorDimKey> { std::size_t operator()(const ::tvm::te::TensorDimKey& k) const { - size_t lhs = ::tvm::ObjectPtrHash()(k.f); + size_t lhs = ::tvm::ObjectPtrHash()(k.op); size_t rhs = static_cast(k.value_index) << 16UL | static_cast(k.dim); lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); return lhs; @@ -378,7 +378,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { if (k != target && place_holder_ref.count(k)) break; stack.pop_back(); if (!reach.count(k)) { - LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim; + LOG(FATAL) << "cannot find reach of " << k.op << "-" << k.dim; } for (TensorDimKey kk : reach.at(k)) { diff --git a/src/tir/transforms/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc index 3b626f0108a1..9c27a71929c5 100644 --- a/src/tir/transforms/inject_prefetch.cc +++ b/src/tir/transforms/inject_prefetch.cc @@ -45,7 +45,7 @@ class PrefetchInjector : public StmtMutator { if (op && op->attr_key == attr::prefetch_scope) { Buffer buffer = Downcast(op->node); CHECK_NE(loop_nest_.size(), 0U); - Domain domain = DomainTouched(op->body, buffer, true, false); + Region domain = DomainTouched(op->body, buffer, true, false); Region region; auto iter_var = loop_nest_.back().get();