Skip to content

Commit

Permalink
[TIR][REFACTIR] Update TIR nodes std::string->String. (#5793)
Browse files Browse the repository at this point in the history
This PR updates the remaining TIR node's member to use
String instead of std::string.
  • Loading branch information
tqchen authored Jun 13, 2020
1 parent 59f5cbe commit 8578096
Show file tree
Hide file tree
Showing 31 changed files with 102 additions and 68 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ namespace tvm {
class ConstructorNode : public RelayExprNode {
public:
/*! \brief The name (only a hint) */
std::string name_hint;
String name_hint;
/*! \brief Input to the constructor. */
Array<Type> inputs;
/*! \brief The datatype the constructor will construct. */
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
* struct MyAttrs : public tvm::AttrsNode<MyAttrs> {
* float learning_rate;
* int num_hidden;
* std::string name;
* String name;
* // declare attribute fields in header file
* TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") {
* TVM_ATTR_FIELD(num_hidden).set_lower_bound(1);
Expand Down Expand Up @@ -106,11 +106,11 @@ struct AttrError : public dmlc::Error {
class AttrFieldInfoNode : public Object {
public:
/*! \brief name of the field */
std::string name;
String name;
/*! \brief type docstring information in str. */
std::string type_info;
String type_info;
/*! \brief detailed description of the type */
std::string description;
String description;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/env_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ namespace tvm {
class EnvFuncNode : public Object {
public:
/*! \brief Unique name of the global function */
std::string name;
String name;
/*! \brief The internal packed function */
runtime::PackedFunc func;
/*! \brief constructor */
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,21 @@ class OpAttrMap;
class OpNode : public RelayExprNode {
public:
/*! \brief name of the operator */
std::string name;
String name;
/*! \brief the type of the operator */
mutable FuncType op_type;
/*!
* \brief detailed description of the operator
* This can be used to generate docstring automatically for the operator.
*/
std::string description;
String description;
/* \brief Information of input arguments to the operator */
Array<AttrFieldInfo> arguments;
/*!
* \brief The type key of the attribute field
* This can be empty, in which case it defaults to anything.
*/
std::string attrs_type_key;
String attrs_type_key;
/*!
* \brief attribute type index,
* this field varies in each run and is not exposed to frontend.
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,10 +253,10 @@ class PassInfoNode : public Object {
int opt_level;

/*! \brief The name of an optimization/analysis pass. */
std::string name;
String name;

/*! \brief The passes that are required to perform the current pass. */
Array<runtime::String> required;
Array<String> required;

PassInfoNode() = default;

Expand Down Expand Up @@ -407,7 +407,7 @@ class Sequential : public Pass {
*/
TVM_DLL Pass
CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level, const String& name, const Array<runtime::String>& required);
int opt_level, String name, Array<runtime::String> required);

/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/node/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class ReflectionVTable {
* \return The corresponding attribute value.
* \note This function will throw an exception if the object does not contain the field.
*/
TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const std::string& attr_name) const;
TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const String& attr_name) const;

/*!
* \brief List all the fields in the object.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/op_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class OpImplementationNode : public Object {
/*! \brief Schedule function */
FTVMSchedule fschedule;
/*! \brief Name of the implementation */
std::string name;
String name;
/*! \brief Priority level */
int plevel;

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ using Sequential = tvm::transform::Sequential;
*/
TVM_DLL Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level, const String& name, const tvm::Array<runtime::String>& required);
int opt_level, String name, tvm::Array<String> required);

/*! \brief Remove expressions which does not effect the program result.
*
Expand Down
30 changes: 26 additions & 4 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@
#include <utility>
#include <vector>

namespace llvm {
// String to llvm object compatibility.
class StringRef;
} // namespace llvm

namespace tvm {

struct ObjectEqual;
Expand Down Expand Up @@ -1161,7 +1166,14 @@ class String : public ObjectRef {
* \param other The value for the new String
*
*/
inline String operator=(std::string other);
inline String& operator=(std::string other);

/*!
* \brief Change the value the reference object points to.
*
* \param other The value for the new String
*/
inline String& operator=(const char* other);

/*!
* \brief Compare is less than other std::string
Expand Down Expand Up @@ -1304,12 +1316,20 @@ class String : public ObjectRef {
const char* data() const { return get()->data; }

/*!
* \brief Convert String to an std::sting object
* \brief Convert String to an std::string object
*
* \return std::string
*/
operator std::string() const { return std::string{get()->data, size()}; }

// LLVM compatibility function, implemented in src/target/llvm/llvm_common.h
/*!
* \brief Convert String to an llvm::StringRef object
*
* \return llvm::StringRef
*/
inline operator llvm::StringRef() const;

/*!
* \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String
* \param val The value to be checked
Expand Down Expand Up @@ -1382,12 +1402,14 @@ inline String::String(std::string other) {
data_ = std::move(ptr);
}

inline String String::operator=(std::string other) {
inline String& String::operator=(std::string other) {
String replace{std::move(other)};
data_.swap(replace.data_);
return Downcast<String>(*this);
return *this;
}

inline String& String::operator=(const char* other) { return operator=(std::string(other)); }

inline String operator+(const std::string lhs, const String& rhs) {
return lhs + rhs.operator std::string();
}
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ class BufferNode : public Object {
PrimExpr elem_offset;
// Meta data
/*! \brief optional name of the buffer */
std::string name;
String name;
/*! \brief storage scope of the buffer, if other than global */
std::string scope;
String scope;
/*! \brief Alignment requirement of data pointer in bytes. */
int data_alignment;
/*!
Expand Down Expand Up @@ -134,7 +134,7 @@ class Buffer : public ObjectRef {
// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
TVM_DLL Buffer(Var ptr, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
PrimExpr elem_offset, std::string name, std::string scope, int data_alignment,
PrimExpr elem_offset, String name, String scope, int data_alignment,
int offset_factor, BufferType buffer_type);

/*!
Expand Down Expand Up @@ -187,7 +187,7 @@ class Buffer : public ObjectRef {
* \sa Buffer for complete constructor.
*/
TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
std::string name = "buffer");
String name = "buffer");

/*!
* \brief Base node for data producers.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class LayoutAxis {
class LayoutNode : public Object {
public:
/*! \brief string representation of layout, "" for scalar. */
std::string name;
String name;
/*! \brief specify each axis of the layout,
* in which the variable name is the name of the axis.
* The IterVar's extent indicates the size of the axis,
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ using FloatImmNode = tvm::FloatImmNode;
class StringImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
std::string value;
String value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
Expand All @@ -74,7 +74,7 @@ class StringImmNode : public PrimExprNode {
*/
class StringImm : public PrimExpr {
public:
TVM_DLL StringImm(std::string value);
TVM_DLL StringImm(String value);
TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode);
};

Expand Down Expand Up @@ -889,7 +889,7 @@ class CallNode : public PrimExprNode {
PureIntrinsic = 5
};
/*! \brief The name of the function/intrinsic. */
std::string name;
String name;
/*! \brief The arguments. */
Array<PrimExpr> args;
/*! \brief Type of calls. */
Expand Down Expand Up @@ -958,7 +958,7 @@ class Call : public PrimExpr {
public:
using CallType = CallNode::CallType;

TVM_DLL Call(DataType dtype, std::string name, Array<PrimExpr> args, CallType call_type);
TVM_DLL Call(DataType dtype, String name, Array<PrimExpr> args, CallType call_type);
TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
};

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class AttrStmtNode : public StmtNode {
/*! \brief this is attribute about certain node */
ObjectRef node;
/*! \brief the type key of the attribute */
std::string attr_key;
String attr_key;
/*! \brief The attribute value, value is well defined at current scope. */
PrimExpr value;
/*! \brief The body statement to be executed */
Expand Down Expand Up @@ -144,7 +144,7 @@ class AttrStmtNode : public StmtNode {
*/
class AttrStmt : public Stmt {
public:
TVM_DLL AttrStmt(ObjectRef node, std::string type_key, PrimExpr value, Stmt body);
TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
};
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ using tvm::transform::Sequential;
*/
TVM_DLL Pass CreatePrimFuncPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level, const std::string& name, const tvm::Array<runtime::String>& required);
int opt_level, String name, tvm::Array<String> required);

/*!
* \brief Inject prefetch instructions into stmt.
Expand Down Expand Up @@ -88,7 +88,7 @@ TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = f
* Expr pad_value)
* \return The pass.
*/
TVM_DLL Pass InjectCopyIntrin(std::string pragma_key, runtime::PackedFunc fintrin);
TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin);

/*!
* \brief Detect and insert sync points to co-processor.
Expand All @@ -103,7 +103,7 @@ TVM_DLL Pass CoProcSync();
* \param attr_key The attribute key to be checked.
* \return The pass.
*/
TVM_DLL Pass LiftAttrScope(std::string attr_key);
TVM_DLL Pass LiftAttrScope(String attr_key);

/*!
* \brief partition loops in the stmt.
Expand Down Expand Up @@ -222,7 +222,7 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
*
* \return The pass.
*/
TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);
TVM_DLL Pass RemapThreadAxis(Map<String, IterVar> axis_map);

/*!
* \brief Lower custom datatypes.
Expand Down Expand Up @@ -260,7 +260,7 @@ TVM_DLL Pass SkipAssert();
* \param storage_scope The storage scope considered.
* \return The pass.
*/
TVM_DLL Pass ThreadSync(std::string storage_scope);
TVM_DLL Pass ThreadSync(String storage_scope);

/*!
* \brief Lower cross thread alleduce.
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class IterVarNode : public Object {
* \brief additional tag on the iteration variable,
* set this if this is binded already to a known thread tag.
*/
std::string thread_tag;
String thread_tag;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dom", &dom);
Expand Down Expand Up @@ -278,7 +278,7 @@ class IterVarNode : public Object {
*/
class IterVar : public ObjectRef {
public:
TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, std::string thread_tag = "");
TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = "");
/*!
* \return the corresponding var in the IterVar.
*/
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def _convert(item, nodes):
"relay.IncompleteType": _rename("IncompleteType"),
"relay.TypeRelation": _rename("TypeRelation"),
"relay.TypeCall": _rename("TypeCall"),
"relay.Constructor": [_update_from_std_str("name_hint")],
"relay.Module": _rename("IRModule"),
"relay.SourceName": _rename("SourceName"),
"relay.Span": _rename("Span"),
Expand All @@ -137,6 +138,11 @@ def _convert(item, nodes):
# TIR
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
"SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
"StringImm": [_update_from_std_str("value")],
"Call": [_update_from_std_str("name")],
"AttrStmt": [_update_from_std_str("attr_key")],
"Layout": [_update_from_std_str("name")],
"Buffer": [_update_from_std_str("name"), _update_from_std_str("scope")],
}
return create_updater(node_map, "0.6", "0.7")

Expand Down
5 changes: 2 additions & 3 deletions src/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,8 @@ ObjectPtr<Object> CreateOp(const std::string& name) {
return Op2ObjectPtr::Get(op);
}

TVM_REGISTER_NODE_TYPE(OpNode).set_creator(CreateOp).set_repr_bytes([](const Object* n) {
return static_cast<const OpNode*>(n)->name;
});
TVM_REGISTER_NODE_TYPE(OpNode).set_creator(CreateOp).set_repr_bytes(
[](const Object* n) -> std::string { return static_cast<const OpNode*>(n)->name; });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, ReprPrinter* p) {
Expand Down
5 changes: 2 additions & 3 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,16 +376,15 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c
}

Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level, const String& name,
const tvm::Array<runtime::String>& required) {
int opt_level, String name, tvm::Array<String> required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return ModulePass(pass_func, pass_info);
}

TVM_REGISTER_NODE_TYPE(PassInfoNode);

TVM_REGISTER_GLOBAL("transform.PassInfo")
.set_body_typed([](int opt_level, String name, tvm::Array<runtime::String> required) {
.set_body_typed([](int opt_level, String name, tvm::Array<String> required) {
return PassInfo(opt_level, name, required);
});

Expand Down
Loading

0 comments on commit 8578096

Please sign in to comment.