diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 50b406b2c030..083f87f89bc9 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -360,7 +360,15 @@ class String : public ObjectRef { * \note If user passes const reference, it will trigger copy. If it's rvalue, * it will be moved into other. */ - explicit String(std::string other); + String(std::string other); // NOLINT(*) + + /*! + * \brief Construct a new String object + * + * \param other a char array. + */ + String(const char* other) // NOLINT(*) + : String(std::string(other)) {} /*! * \brief Change the value the reference object points to. diff --git a/src/node/serialization.cc b/src/node/serialization.cc index cae9fdbffb02..ee6072d77c1c 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -280,8 +280,6 @@ class JSONAttrGetter : public AttrVisitor { node_->data.push_back( node_index_->at(const_cast(kv.second.get()))); } - } else if (node->IsInstance()) { - node_->data.push_back(node_index_->at(node)); } else { // recursively index normal object. reflection_->VisitAttrs(node, this); @@ -375,10 +373,6 @@ class JSONAttrSetter : public AttrVisitor { n->data[node_->keys[i]] = ObjectRef(node_list_->at(node_->data[i])); } - } else if (node->IsInstance()) { - StringObj* n = static_cast(node); - auto saved = node_list_->at(node_->data[0]); - saved = runtime::GetObjectPtr(n); } else { reflection_->VisitAttrs(node, this); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index bc0685b5d905..6e6faf9274ef 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -272,7 +272,7 @@ class RelayBuildModule : public runtime::ModuleNode { } Array pass_seqs; - Array entry_functions{runtime::String("main")}; + Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Run all dialect legalization passes. diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc index 59cf9f98288f..aab0b3a30a7c 100644 --- a/src/relay/transforms/alter_op_layout.cc +++ b/src/relay/transforms/alter_op_layout.cc @@ -125,8 +125,7 @@ Pass AlterOpLayout() { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::alter_op_layout::AlterOpLayout(f)); }; - return CreateFunctionPass(pass_func, 3, "AlterOpLayout", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout") diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 65e17fcefddb..44ef35a285f5 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -226,7 +226,7 @@ Pass AnnotateTarget(const std::string& target) { return Downcast(relay::annotate_target::AnnotateTarget(f, target)); }; auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", - {runtime::String("InferType")}); + {"InferType"}); return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); } diff --git a/src/relay/transforms/canonicalize_cast.cc b/src/relay/transforms/canonicalize_cast.cc index 4b35ba219b67..ebcbd578b5f0 100644 --- a/src/relay/transforms/canonicalize_cast.cc +++ b/src/relay/transforms/canonicalize_cast.cc @@ -133,8 +133,7 @@ Pass CanonicalizeCast() { [=](Function f, IRModule m, PassContext pc) { return Downcast(CanonicalizeCast(f)); }; - return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast") diff --git a/src/relay/transforms/canonicalize_ops.cc b/src/relay/transforms/canonicalize_ops.cc index 44140a902a2f..1d3111b29d7d 100644 --- a/src/relay/transforms/canonicalize_ops.cc +++ b/src/relay/transforms/canonicalize_ops.cc @@ -74,8 +74,7 @@ Pass CanonicalizeOps() { [=](Function f, IRModule m, PassContext pc) { return Downcast(CanonicalizeOps(f)); }; - return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps") diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index 3c8eea04d28f..af6b1353f5ac 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -220,8 +220,7 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) { [=](Function f, IRModule m, PassContext pc) { return Downcast(CombineParallelConv2D(f, min_num_branches)); }; - return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D") diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index 2dc8321e517b..1278020ac735 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -80,8 +80,7 @@ Pass CombineParallelDense(uint64_t min_num_branches) { [=](Function f, IRModule m, PassContext pc) { return Downcast(CombineParallelDense(f, min_num_branches)); }; - return CreateFunctionPass(pass_func, 4, "CombineParallelDense", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense") diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc index f63f169be408..361565ef11d7 100644 --- a/src/relay/transforms/combine_parallel_op_batch.cc +++ b/src/relay/transforms/combine_parallel_op_batch.cc @@ -193,8 +193,7 @@ Pass CombineParallelOpBatch(const std::string& op_name, batch_op_name, min_num_branches)); }; - return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch") diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index d43a0851e099..dbb2c38e3f27 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -133,9 +133,7 @@ Pass ConvertLayout(const std::string& desired_layout) { return Downcast(relay::convert_op_layout::ConvertLayout(f, desired_layout)); }; return CreateFunctionPass( - pass_func, 3, "ConvertLayout", - {runtime::String("InferType"), - runtime::String("CanonicalizeOps")}); + pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"}); } TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout); diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index 9955ef6ee7d2..908ba87a8c52 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -573,8 +573,7 @@ Pass RewriteAnnotatedOps(int fallback_device) { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::RewriteAnnotatedOps(f, fallback_device)); }; - return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation") diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc index 696e83a7db53..68c59f5ea2ef 100644 --- a/src/relay/transforms/eliminate_common_subexpr.cc +++ b/src/relay/transforms/eliminate_common_subexpr.cc @@ -91,8 +91,7 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) { [=](Function f, IRModule m, PassContext pc) { return Downcast(EliminateCommonSubexpr(f, fskip)); }; - return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr") diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc index 668982e561e8..8234dea5e075 100644 --- a/src/relay/transforms/fast_math.cc +++ b/src/relay/transforms/fast_math.cc @@ -70,8 +70,7 @@ Pass FastMath() { [=](Function f, IRModule m, PassContext pc) { return Downcast(FastMath(f)); }; - return CreateFunctionPass(pass_func, 4, "FastMath", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.FastMath") diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 11325f6526b8..cfe74bfd8ef1 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -960,8 +960,7 @@ Pass ForwardFoldScaleAxis() { return Downcast( relay::fold_scale_axis::ForwardFoldScaleAxis(f)); }; - return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis") @@ -973,8 +972,7 @@ Pass BackwardFoldScaleAxis() { return Downcast( relay::fold_scale_axis::BackwardFoldScaleAxis(f)); }; - return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis") diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index cdd29394a204..f646042962f0 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -980,8 +980,7 @@ Pass FuseOps(int fuse_opt_level) { int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; return Downcast(FuseOps(f, opt_level, m)); }; - return CreateFunctionPass(pass_func, 1, "FuseOps", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.FuseOps") diff --git a/src/relay/transforms/legalize.cc b/src/relay/transforms/legalize.cc index d13cc493650a..0b5c671ab7f6 100644 --- a/src/relay/transforms/legalize.cc +++ b/src/relay/transforms/legalize.cc @@ -101,7 +101,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::legalize::Legalize(f, legalize_map_attr_name)); }; - return CreateFunctionPass(pass_func, 1, "Legalize", {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 1, "Legalize", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize); diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc index b33799a26b43..d349fdddeeea 100644 --- a/src/relay/transforms/simplify_inference.cc +++ b/src/relay/transforms/simplify_inference.cc @@ -204,8 +204,7 @@ Pass SimplifyInference() { [=](Function f, IRModule m, PassContext pc) { return Downcast(SimplifyInference(f)); }; - return CreateFunctionPass(pass_func, 0, "SimplifyInference", - {runtime::String("InferType")}); + return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference") diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index f1198e727401..063247db09b6 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -261,7 +261,7 @@ TEST(String, empty) { using namespace std; String s{"hello"}; CHECK_EQ(s.empty(), false); - s = ""; + s = std::string(""); CHECK_EQ(s.empty(), true); } diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index f182cbaae631..5a71023e5d60 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -25,9 +25,7 @@ def check_json_roundtrip(node): json_str = tvm.ir.save_json(node) - print(node) back = tvm.ir.load_json(json_str) - print(back) assert tvm.ir.structural_equal(back, node, map_free_vars=True) @@ -99,11 +97,13 @@ def test_function(): type_params = tvm.runtime.convert([]) fn = relay.Function(params, body, ret_type, type_params) fn = fn.with_attr("test_attribute", "value") + fn = fn.with_attr("test_attribute1", "value1") assert fn.params == params assert fn.body == body assert fn.type_params == type_params assert fn.span == None assert fn.attrs["test_attribute"] == "value" + assert fn.attrs["test_attribute1"] == "value1" str(fn) check_json_roundtrip(fn)