Skip to content

Commit

Permalink
add const char* constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Apr 10, 2020
1 parent bc5b203 commit 936ae45
Show file tree
Hide file tree
Showing 20 changed files with 29 additions and 42 deletions.
10 changes: 9 additions & 1 deletion include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,15 @@ class String : public ObjectRef {
* \note If user passes const reference, it will trigger copy. If it's rvalue,
* it will be moved into other.
*/
explicit String(std::string other);
String(std::string other); // NOLINT(*)

/*!
* \brief Construct a new String object
*
* \param other a char array.
*/
String(const char* other) // NOLINT(*)
: String(std::string(other)) {}

/*!
* \brief Change the value the reference object points to.
Expand Down
6 changes: 0 additions & 6 deletions src/node/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,6 @@ class JSONAttrGetter : public AttrVisitor {
node_->data.push_back(
node_index_->at(const_cast<Object*>(kv.second.get())));
}
} else if (node->IsInstance<StringObj>()) {
node_->data.push_back(node_index_->at(node));
} else {
// recursively index normal object.
reflection_->VisitAttrs(node, this);
Expand Down Expand Up @@ -375,10 +373,6 @@ class JSONAttrSetter : public AttrVisitor {
n->data[node_->keys[i]]
= ObjectRef(node_list_->at(node_->data[i]));
}
} else if (node->IsInstance<StringObj>()) {
StringObj* n = static_cast<StringObj*>(node);
auto saved = node_list_->at(node_->data[0]);
saved = runtime::GetObjectPtr<StringObj>(n);
} else {
reflection_->VisitAttrs(node, this);
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ class RelayBuildModule : public runtime::ModuleNode {
}

Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{runtime::String("main")};
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));

// Run all dialect legalization passes.
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/alter_op_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ Pass AlterOpLayout() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
};
return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout")
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ Pass AnnotateTarget(const std::string& target) {
return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
{runtime::String("InferType")});
{"InferType"});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
}

Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/canonicalize_cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ Pass CanonicalizeCast() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeCast(f));
};
return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/canonicalize_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ Pass CanonicalizeOps() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeOps(f));
};
return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/combine_parallel_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,7 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/combine_parallel_dense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ Pass CombineParallelDense(uint64_t min_num_branches) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelDense(f, min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/combine_parallel_op_batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ Pass CombineParallelOpBatch(const std::string& op_name,
batch_op_name,
min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch")
Expand Down
4 changes: 1 addition & 3 deletions src/relay/transforms/convert_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,7 @@ Pass ConvertLayout(const std::string& desired_layout) {
return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout));
};
return CreateFunctionPass(
pass_func, 3, "ConvertLayout",
{runtime::String("InferType"),
runtime::String("CanonicalizeOps")});
pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"});
}

TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/device_annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,7 @@ Pass RewriteAnnotatedOps(int fallback_device) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::RewriteAnnotatedOps(f, fallback_device));
};
return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/eliminate_common_subexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
};
return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/fast_math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ Pass FastMath() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FastMath(f));
};
return CreateFunctionPass(pass_func, 4, "FastMath",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.FastMath")
Expand Down
6 changes: 2 additions & 4 deletions src/relay/transforms/fold_scale_axis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -960,8 +960,7 @@ Pass ForwardFoldScaleAxis() {
return Downcast<Function>(
relay::fold_scale_axis::ForwardFoldScaleAxis(f));
};
return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
Expand All @@ -973,8 +972,7 @@ Pass BackwardFoldScaleAxis() {
return Downcast<Function>(
relay::fold_scale_axis::BackwardFoldScaleAxis(f));
};
return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis")
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,7 @@ Pass FuseOps(int fuse_opt_level) {
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
return Downcast<Function>(FuseOps(f, opt_level, m));
};
return CreateFunctionPass(pass_func, 1, "FuseOps",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.FuseOps")
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
};
return CreateFunctionPass(pass_func, 1, "Legalize", {runtime::String("InferType")});
return CreateFunctionPass(pass_func, 1, "Legalize", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize);
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/simplify_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,7 @@ Pass SimplifyInference() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(SimplifyInference(f));
};
return CreateFunctionPass(pass_func, 0, "SimplifyInference",
{runtime::String("InferType")});
return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference")
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/container_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ TEST(String, empty) {
using namespace std;
String s{"hello"};
CHECK_EQ(s.empty(), false);
s = "";
s = std::string("");
CHECK_EQ(s.empty(), true);
}

Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@

def check_json_roundtrip(node):
json_str = tvm.ir.save_json(node)
print(node)
back = tvm.ir.load_json(json_str)
print(back)
assert tvm.ir.structural_equal(back, node, map_free_vars=True)


Expand Down Expand Up @@ -99,11 +97,13 @@ def test_function():
type_params = tvm.runtime.convert([])
fn = relay.Function(params, body, ret_type, type_params)
fn = fn.with_attr("test_attribute", "value")
fn = fn.with_attr("test_attribute1", "value1")
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
assert fn.attrs["test_attribute"] == "value"
assert fn.attrs["test_attribute1"] == "value1"
str(fn)
check_json_roundtrip(fn)

Expand Down

0 comments on commit 936ae45

Please sign in to comment.