From 1e849b756cef87664d3bc2fec1c792cf6f3584ae Mon Sep 17 00:00:00 2001 From: Neo Chien Date: Thu, 28 May 2020 16:56:06 +0800 Subject: [PATCH] [TIR][REFACTOR] std::string -> String Migration in TIR nodes (#5596) * [TIR][REFACTOR] std::string -> String Migration for Var node and SizeVar Node * update json_compact.py --- include/tvm/tir/stmt.h | 2 +- include/tvm/tir/var.h | 12 ++++++------ python/tvm/ir/json_compact.py | 4 ++-- src/printer/tir_text_printer.cc | 2 +- src/tir/ir/data_layout.cc | 6 +++--- src/tir/ir/expr.cc | 15 +++++++-------- 6 files changed, 20 insertions(+), 21 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index e1fef552b84a..bbc37febee71 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -20,7 +20,7 @@ * \file tvm/tir/stmt.h * \brief TIR statements. */ -// Acknowledgement: Mnay low-level stmts originate from Halide. +// Acknowledgement: Many low-level stmts originate from Halide. #ifndef TVM_TIR_STMT_H_ #define TVM_TIR_STMT_H_ diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index a89c665b9377..4db462daa5da 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -50,7 +50,7 @@ class VarNode : public PrimExprNode { * \brief The hint to the variable name. * \note Each variable is uniquely identified by its address. */ - std::string name_hint; + String name_hint; /*! * \brief type annotaion of the variable. * @@ -92,19 +92,19 @@ class Var : public PrimExpr { * \param name_hint variable name * \param dtype data type */ - TVM_DLL explicit Var(std::string name_hint = "v", DataType dtype = DataType::Int(32)); + TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32)); /*! * \brief Constructor which provides a more detailed type annotation. * \param name_hint variable name. * \param type_annotation The type annotation. */ - TVM_DLL explicit Var(std::string name_hint, Type type_annotation); + TVM_DLL explicit Var(String name_hint, Type type_annotation); /*! * \brief Make a new copy of var with same type, append suffix * \param suffix The suffix to be appended. * \return the new Var copy */ - TVM_DLL Var copy_with_suffix(const std::string& suffix) const; + TVM_DLL Var copy_with_suffix(const String& suffix) const; /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. @@ -138,7 +138,7 @@ class SizeVar : public Var { * \param name_hint variable name * \param t data type */ - TVM_DLL explicit SizeVar(std::string name_hint = "s", DataType t = DataType::Int(32)); + TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32)); /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. @@ -178,7 +178,7 @@ enum IterVarType : int { /*! * \brief The IterVar itself is a thread-index * of a fixed thread launching group. - * Note that this is already assumed to be paralellized. + * Note that this is already assumed to be parallelized. * * Disallow: split/fuse/vectorize/parallel */ diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 2abfd81188fd..9d90685d9a60 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -130,8 +130,8 @@ def _convert(item, nodes): "relay.ModulePass": _rename("transform.ModulePass"), "relay.Sequential": _rename("transform.Sequential"), # TIR - "Variable": _update_tir_var("tir.Var"), - "SizeVar": _update_tir_var("tir.SizeVar"), + "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")], + "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")], } return create_updater(node_map, "0.6", "0.7") diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 0bcc1488cba8..4d22cbb68a9b 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -567,7 +567,7 @@ Doc TIRTextPrinter::AllocVar(const Var& var) { if (it != memo_var_.end()) { return it->second; } - std::string name = var->name_hint; + std::string name = var->name_hint.operator std::string(); if (name.length() == 0 || !std::isalpha(name[0])) { name = "v" + name; } diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 23e13edadde5..6c389825501b 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -82,7 +82,7 @@ Layout::Layout(const Array& axes) { } CHECK_EQ(axis->var.get()->name_hint.size(), 1) << "Invalid layout axis " << axis->var.get()->name_hint; - char c = axis->var.get()->name_hint[0]; + char c = axis->var.get()->name_hint.operator std::string()[0]; CHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c; repr << axis->var.get()->name_hint; } @@ -127,7 +127,7 @@ Layout::Layout(const std::string& name) { // NOLINT(*) // validate layout std::vector exist_axis(256, false); for (const IterVar& v : node->axes) { - auto axis_str = v->var.get()->name_hint; + auto axis_str = v->var.get()->name_hint.operator std::string(); CHECK_EQ(axis_str.size(), 1); char axis = axis_str[0]; CHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z')); @@ -135,7 +135,7 @@ Layout::Layout(const std::string& name) { // NOLINT(*) exist_axis[axis] = true; } for (const IterVar& v : node->axes) { - char axis = v->var.get()->name_hint[0]; + char axis = v->var.get()->name_hint.operator std::string()[0]; if (axis >= 'a' && axis <= 'z') { CHECK(exist_axis[axis - 'a' + 'A']) << "Invalid layout " << name << ": missing axis " << std::toupper(axis); diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 569415546ec8..8b9a8e2f7812 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -34,14 +34,14 @@ namespace tvm { namespace tir { -Var::Var(std::string name_hint, DataType dtype) { +Var::Var(String name_hint, DataType dtype) { auto n = make_object(); n->name_hint = std::move(name_hint); n->dtype = std::move(dtype); data_ = std::move(n); } -Var::Var(std::string name_hint, Type type_annotation) { +Var::Var(String name_hint, Type type_annotation) { auto n = make_object(); n->name_hint = std::move(name_hint); n->dtype = GetRuntimeDataType(type_annotation); @@ -49,7 +49,7 @@ Var::Var(std::string name_hint, Type type_annotation) { data_ = std::move(n); } -Var Var::copy_with_suffix(const std::string& suffix) const { +Var Var::copy_with_suffix(const String& suffix) const { const VarNode* node = get(); ObjectPtr new_ptr; if (auto* ptr = this->as()) { @@ -57,19 +57,18 @@ Var Var::copy_with_suffix(const std::string& suffix) const { } else { new_ptr = make_object(*node); } - new_ptr->name_hint += suffix; - + new_ptr->name_hint = new_ptr->name_hint.operator std::string() + suffix.operator std::string(); return Var(new_ptr); } -SizeVar::SizeVar(std::string name_hint, DataType dtype) { +SizeVar::SizeVar(String name_hint, DataType dtype) { auto n = make_object(); n->name_hint = std::move(name_hint); n->dtype = std::move(dtype); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](std::string name_hint, runtime::TVMArgValue type) { +TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type) { if (type.IsObjectRef()) { return Var(name_hint, type.operator Type()); } else { @@ -77,7 +76,7 @@ TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](std::string name_hint, runtime: } }); -TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](std::string s, DataType t) { +TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t) { return SizeVar(s, t); });