diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 0a20db6a0a632..5eaa1909a539f 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -45,10 +45,10 @@ class LayoutAxis { static const LayoutAxis& Get(const tir::IterVar& itvar); // Get the singleton LayoutAxis using name[0] (size of name must be 1). - static const LayoutAxis& make(const std::string& name); + static const LayoutAxis& make(const String& name); inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; } - inline std::string name() const { return std::string(1, name_); } + inline String name() const { return String(std::string(1, name_)); } // if current axis is primal, switch the axis to its subordinate one, // else switch to the primal. @@ -88,7 +88,7 @@ class Layout; class LayoutNode : public Object { public: /*! \brief string representation of layout, "" for scalar. */ - std::string name; + String name; /*! \brief specify each axis of the layout, * in which the variable name is the name of the axis. * The IterVar's extent indicates the size of the axis, @@ -102,7 +102,7 @@ class LayoutNode : public Object { v->Visit("axes", &axes); } - TVM_DLL static Layout make(const std::string& layout); + TVM_DLL static Layout make(const String& layout); static constexpr const char* _type_key = "Layout"; TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object); @@ -128,7 +128,8 @@ class Layout : public ObjectRef { explicit Layout(const Array& axes); /*! \brief construct from a string */ - Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) + Layout(const char* name) : Layout(String(name)) {} // NOLINT(*) + Layout(const std::string& name) : Layout(String(name)) {} // NOLINT(*) /*! * \brief construct from a string. @@ -138,7 +139,7 @@ class Layout : public ObjectRef { * indicates the split dimension. * return undefined layout if "__undef__" is passed. */ - Layout(const std::string& name); // NOLINT(*) + Layout(const String& name); // NOLINT(*) /*! * \brief access the internal node container @@ -206,16 +207,17 @@ class Layout : public ObjectRef { inline Layout ExpandPrimal(const Layout& dst_layout) { Layout new_src_layout; // 1) Find the axis which are missing in the current layout. Make them the prefix. - std::string new_src_layout_str = ""; + String new_src_layout_str = ""; for (auto dst_axis : dst_layout->axes) { if (LayoutAxis::Get(dst_axis).IsPrimal()) { if (!this->Contains(LayoutAxis::Get(dst_axis))) { - new_src_layout_str += dst_axis->var->name_hint; + new_src_layout_str.operator std::string() += + dst_axis->var->name_hint.operator std::string(); } } } // 2) Now, add the primal axis of the current layout. - new_src_layout_str += this->name(); + new_src_layout_str.operator std::string() += this->name().operator std::string(); new_src_layout = Layout(new_src_layout_str); return new_src_layout; } @@ -269,7 +271,7 @@ class Layout : public ObjectRef { } /*! \return the string description of the layout */ - inline std::string name() const { + inline String name() const { if (!defined()) return "__undef__"; return operator->()->name; } diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 934a3a1ed1740..296726bf335ba 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -61,14 +61,14 @@ const LayoutAxis& LayoutAxis::Get(const char name) { } const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) { - const std::string axis = itvar->var.get()->name_hint; + const String axis = itvar->var.get()->name_hint; CHECK_EQ(axis.size(), 1) << "Invalid layout axis " << axis; - return LayoutAxis::Get(axis[0]); + return LayoutAxis::Get(axis.operator std::string()[0]); } -const LayoutAxis& LayoutAxis::make(const std::string& name) { +const LayoutAxis& LayoutAxis::make(const String& name) { CHECK_EQ(name.length(), 1) << "Invalid axis " << name; - return LayoutAxis::Get(name[0]); + return LayoutAxis::Get(name.operator std::string()[0]); } Layout::Layout(const Array& axes) { @@ -90,7 +90,7 @@ Layout::Layout(const Array& axes) { data_ = std::move(node); } -Layout::Layout(const std::string& name) { // NOLINT(*) +Layout::Layout(const String& name) { // NOLINT(*) if (name == "__undef__") return; auto node = make_object(); @@ -100,12 +100,12 @@ Layout::Layout(const std::string& name) { // NOLINT(*) // parse layout string int32_t factor = 0; - for (char c : name) { + for (char c : name.operator std::string()) { if (c >= 'A' && c <= 'Z') { CHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " before dimension " << c; - std::string shape_name("_shape"); - shape_name.insert(0, 1, c); + String shape_name("_shape"); + shape_name.operator std::string().insert(0, 1, c); IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), tir::kDataPar); node->axes.push_back(axis); @@ -144,7 +144,7 @@ Layout::Layout(const std::string& name) { // NOLINT(*) data_ = std::move(node); } -Layout LayoutNode::make(const std::string& layout) { return Layout(layout); } +Layout LayoutNode::make(const String& layout) { return Layout(layout); } Layout Layout::SubLayout(size_t pos, size_t len) const { if (!defined() || pos > ndim()) return Layout::Undef(); @@ -160,7 +160,7 @@ Layout Layout::SubLayout(size_t pos, size_t len) const { Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const { if (!defined()) return Layout::Undef(); - const std::string& name = operator->()->name; + const String& name = operator->()->name; const auto axes = operator->()->axes; CHECK(target_pos <= this->ndim()) << "Invalid split position " << target_pos << " for layout " << name; @@ -367,12 +367,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed(LayoutNode::make); -TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int { +TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, String axis) -> int { return layout.IndexOf(LayoutAxis::make(axis)); }); TVM_REGISTER_GLOBAL("tir.LayoutFactorOf") - .set_body_typed([](Layout layout, std::string axis) -> int { + .set_body_typed([](Layout layout, String axis) -> int { return layout.FactorOf(LayoutAxis::make(axis)); }); @@ -380,7 +380,7 @@ TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int { return layout.ndim(); }); -TVM_REGISTER_GLOBAL("tir.LayoutGetItem").set_body_typed([](Layout layout, int idx) -> std::string { +TVM_REGISTER_GLOBAL("tir.LayoutGetItem").set_body_typed([](Layout layout, int idx) -> String { const LayoutAxis& axis = layout[idx]; return axis.name(); });