Skip to content

Commit

Permalink
[TIR][REFACTOR] std::string -> String Migration in TIR nodes (apache#…
Browse files Browse the repository at this point in the history
…5596)

* [TIR][REFACTOR] std::string -> String Migration for Var node and SizeVar Node

* update json_compact.py
  • Loading branch information
cchung100m authored and Trevor Morris committed Jun 9, 2020
1 parent a7ab91e commit e03458b
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 21 deletions.
2 changes: 1 addition & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
* \file tvm/tir/stmt.h
* \brief TIR statements.
*/
// Acknowledgement: Mnay low-level stmts originate from Halide.
// Acknowledgement: Many low-level stmts originate from Halide.
#ifndef TVM_TIR_STMT_H_
#define TVM_TIR_STMT_H_

Expand Down
12 changes: 6 additions & 6 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class VarNode : public PrimExprNode {
* \brief The hint to the variable name.
* \note Each variable is uniquely identified by its address.
*/
std::string name_hint;
String name_hint;
/*!
* \brief type annotaion of the variable.
*
Expand Down Expand Up @@ -92,19 +92,19 @@ class Var : public PrimExpr {
* \param name_hint variable name
* \param dtype data type
*/
TVM_DLL explicit Var(std::string name_hint = "v", DataType dtype = DataType::Int(32));
TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32));
/*!
* \brief Constructor which provides a more detailed type annotation.
* \param name_hint variable name.
* \param type_annotation The type annotation.
*/
TVM_DLL explicit Var(std::string name_hint, Type type_annotation);
TVM_DLL explicit Var(String name_hint, Type type_annotation);
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
* \return the new Var copy
*/
TVM_DLL Var copy_with_suffix(const std::string& suffix) const;
TVM_DLL Var copy_with_suffix(const String& suffix) const;
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
Expand Down Expand Up @@ -138,7 +138,7 @@ class SizeVar : public Var {
* \param name_hint variable name
* \param t data type
*/
TVM_DLL explicit SizeVar(std::string name_hint = "s", DataType t = DataType::Int(32));
TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32));
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
Expand Down Expand Up @@ -178,7 +178,7 @@ enum IterVarType : int {
/*!
* \brief The IterVar itself is a thread-index
* of a fixed thread launching group.
* Note that this is already assumed to be paralellized.
* Note that this is already assumed to be parallelized.
*
* Disallow: split/fuse/vectorize/parallel
*/
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def _convert(item, nodes):
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequential": _rename("transform.Sequential"),
# TIR
"Variable": _update_tir_var("tir.Var"),
"SizeVar": _update_tir_var("tir.SizeVar"),
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
"SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
}
return create_updater(node_map, "0.6", "0.7")

Expand Down
2 changes: 1 addition & 1 deletion src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ Doc TIRTextPrinter::AllocVar(const Var& var) {
if (it != memo_var_.end()) {
return it->second;
}
std::string name = var->name_hint;
std::string name = var->name_hint.operator std::string();
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "v" + name;
}
Expand Down
6 changes: 3 additions & 3 deletions src/tir/ir/data_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Layout::Layout(const Array<IterVar>& axes) {
}
CHECK_EQ(axis->var.get()->name_hint.size(), 1)
<< "Invalid layout axis " << axis->var.get()->name_hint;
char c = axis->var.get()->name_hint[0];
char c = axis->var.get()->name_hint.operator std::string()[0];
CHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c;
repr << axis->var.get()->name_hint;
}
Expand Down Expand Up @@ -127,15 +127,15 @@ Layout::Layout(const std::string& name) { // NOLINT(*)
// validate layout
std::vector<bool> exist_axis(256, false);
for (const IterVar& v : node->axes) {
auto axis_str = v->var.get()->name_hint;
auto axis_str = v->var.get()->name_hint.operator std::string();
CHECK_EQ(axis_str.size(), 1);
char axis = axis_str[0];
CHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z'));
CHECK(!exist_axis[axis]) << "Invalid layout " << name << ": duplicate axis " << axis;
exist_axis[axis] = true;
}
for (const IterVar& v : node->axes) {
char axis = v->var.get()->name_hint[0];
char axis = v->var.get()->name_hint.operator std::string()[0];
if (axis >= 'a' && axis <= 'z') {
CHECK(exist_axis[axis - 'a' + 'A'])
<< "Invalid layout " << name << ": missing axis " << std::toupper(axis);
Expand Down
15 changes: 7 additions & 8 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,50 +34,49 @@
namespace tvm {
namespace tir {

Var::Var(std::string name_hint, DataType dtype) {
Var::Var(String name_hint, DataType dtype) {
auto n = make_object<VarNode>();
n->name_hint = std::move(name_hint);
n->dtype = std::move(dtype);
data_ = std::move(n);
}

Var::Var(std::string name_hint, Type type_annotation) {
Var::Var(String name_hint, Type type_annotation) {
auto n = make_object<VarNode>();
n->name_hint = std::move(name_hint);
n->dtype = GetRuntimeDataType(type_annotation);
n->type_annotation = std::move(type_annotation);
data_ = std::move(n);
}

Var Var::copy_with_suffix(const std::string& suffix) const {
Var Var::copy_with_suffix(const String& suffix) const {
const VarNode* node = get();
ObjectPtr<VarNode> new_ptr;
if (auto* ptr = this->as<SizeVarNode>()) {
new_ptr = make_object<SizeVarNode>(*ptr);
} else {
new_ptr = make_object<VarNode>(*node);
}
new_ptr->name_hint += suffix;

new_ptr->name_hint = new_ptr->name_hint.operator std::string() + suffix.operator std::string();
return Var(new_ptr);
}

SizeVar::SizeVar(std::string name_hint, DataType dtype) {
SizeVar::SizeVar(String name_hint, DataType dtype) {
auto n = make_object<SizeVarNode>();
n->name_hint = std::move(name_hint);
n->dtype = std::move(dtype);
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](std::string name_hint, runtime::TVMArgValue type) {
TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type) {
if (type.IsObjectRef<Type>()) {
return Var(name_hint, type.operator Type());
} else {
return Var(name_hint, type.operator DataType());
}
});

TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](std::string s, DataType t) {
TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t) {
return SizeVar(s, t);
});

Expand Down

0 comments on commit e03458b

Please sign in to comment.