Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][REFACTOR] std::string -> String Migration in TIR nodes #5596

Merged
merged 2 commits into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
* \file tvm/tir/stmt.h
* \brief TIR statements.
*/
// Acknowledgement: Mnay low-level stmts originate from Halide.
// Acknowledgement: Many low-level stmts originate from Halide.
#ifndef TVM_TIR_STMT_H_
#define TVM_TIR_STMT_H_

Expand Down
12 changes: 6 additions & 6 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class VarNode : public PrimExprNode {
* \brief The hint to the variable name.
* \note Each variable is uniquely identified by its address.
*/
std::string name_hint;
String name_hint;
/*!
* \brief type annotaion of the variable.
*
Expand Down Expand Up @@ -92,19 +92,19 @@ class Var : public PrimExpr {
* \param name_hint variable name
* \param dtype data type
*/
TVM_DLL explicit Var(std::string name_hint = "v", DataType dtype = DataType::Int(32));
TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32));
/*!
* \brief Constructor which provides a more detailed type annotation.
* \param name_hint variable name.
* \param type_annotation The type annotation.
*/
TVM_DLL explicit Var(std::string name_hint, Type type_annotation);
TVM_DLL explicit Var(String name_hint, Type type_annotation);
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
* \return the new Var copy
*/
TVM_DLL Var copy_with_suffix(const std::string& suffix) const;
TVM_DLL Var copy_with_suffix(const String& suffix) const;
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
Expand Down Expand Up @@ -138,7 +138,7 @@ class SizeVar : public Var {
* \param name_hint variable name
* \param t data type
*/
TVM_DLL explicit SizeVar(std::string name_hint = "s", DataType t = DataType::Int(32));
TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32));
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
Expand Down Expand Up @@ -178,7 +178,7 @@ enum IterVarType : int {
/*!
* \brief The IterVar itself is a thread-index
* of a fixed thread launching group.
* Note that this is already assumed to be paralellized.
* Note that this is already assumed to be parallelized.
*
* Disallow: split/fuse/vectorize/parallel
*/
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def _convert(item, nodes):
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequential": _rename("transform.Sequential"),
# TIR
"Variable": _update_tir_var("tir.Var"),
"SizeVar": _update_tir_var("tir.SizeVar"),
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
"SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
}
return create_updater(node_map, "0.6", "0.7")

Expand Down
2 changes: 1 addition & 1 deletion src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ Doc TIRTextPrinter::AllocVar(const Var& var) {
if (it != memo_var_.end()) {
return it->second;
}
std::string name = var->name_hint;
std::string name = var->name_hint.operator std::string();
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "v" + name;
}
Expand Down
6 changes: 3 additions & 3 deletions src/tir/ir/data_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Layout::Layout(const Array<IterVar>& axes) {
}
CHECK_EQ(axis->var.get()->name_hint.size(), 1)
<< "Invalid layout axis " << axis->var.get()->name_hint;
char c = axis->var.get()->name_hint[0];
char c = axis->var.get()->name_hint.operator std::string()[0];
CHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c;
repr << axis->var.get()->name_hint;
}
Expand Down Expand Up @@ -127,15 +127,15 @@ Layout::Layout(const std::string& name) { // NOLINT(*)
// validate layout
std::vector<bool> exist_axis(256, false);
for (const IterVar& v : node->axes) {
auto axis_str = v->var.get()->name_hint;
auto axis_str = v->var.get()->name_hint.operator std::string();
CHECK_EQ(axis_str.size(), 1);
char axis = axis_str[0];
CHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z'));
CHECK(!exist_axis[axis]) << "Invalid layout " << name << ": duplicate axis " << axis;
exist_axis[axis] = true;
}
for (const IterVar& v : node->axes) {
char axis = v->var.get()->name_hint[0];
char axis = v->var.get()->name_hint.operator std::string()[0];
if (axis >= 'a' && axis <= 'z') {
CHECK(exist_axis[axis - 'a' + 'A'])
<< "Invalid layout " << name << ": missing axis " << std::toupper(axis);
Expand Down
15 changes: 7 additions & 8 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,50 +34,49 @@
namespace tvm {
namespace tir {

Var::Var(std::string name_hint, DataType dtype) {
Var::Var(String name_hint, DataType dtype) {
auto n = make_object<VarNode>();
n->name_hint = std::move(name_hint);
n->dtype = std::move(dtype);
data_ = std::move(n);
}

Var::Var(std::string name_hint, Type type_annotation) {
Var::Var(String name_hint, Type type_annotation) {
auto n = make_object<VarNode>();
n->name_hint = std::move(name_hint);
n->dtype = GetRuntimeDataType(type_annotation);
n->type_annotation = std::move(type_annotation);
data_ = std::move(n);
}

Var Var::copy_with_suffix(const std::string& suffix) const {
Var Var::copy_with_suffix(const String& suffix) const {
const VarNode* node = get();
ObjectPtr<VarNode> new_ptr;
if (auto* ptr = this->as<SizeVarNode>()) {
new_ptr = make_object<SizeVarNode>(*ptr);
} else {
new_ptr = make_object<VarNode>(*node);
}
new_ptr->name_hint += suffix;

new_ptr->name_hint = new_ptr->name_hint.operator std::string() + suffix.operator std::string();
return Var(new_ptr);
}

SizeVar::SizeVar(std::string name_hint, DataType dtype) {
SizeVar::SizeVar(String name_hint, DataType dtype) {
auto n = make_object<SizeVarNode>();
n->name_hint = std::move(name_hint);
n->dtype = std::move(dtype);
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](std::string name_hint, runtime::TVMArgValue type) {
TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type) {
if (type.IsObjectRef<Type>()) {
return Var(name_hint, type.operator Type());
} else {
return Var(name_hint, type.operator DataType());
}
});

TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](std::string s, DataType t) {
TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t) {
return SizeVar(s, t);
});

Expand Down