Skip to content

Commit

Permalink
[RELAY] Fix segfault in pretty print when ObjectRef is null (apache#5681
Browse files Browse the repository at this point in the history
)

* [RELAY] Fix segfault in pretty print when ObjectRef is null

Encountered when pretty printing module with function attribute equal to NullValue<ObjectRef>().

Change-Id: I2e7b304859f03038730ba9c3b9db41ebd3e1fbb5

* Add test case

Change-Id: I579b20da3f5d49054823392be80aaf78a055f596
  • Loading branch information
lhutton1 authored and Trevor Morris committed Jun 18, 2020
1 parent c86e9bf commit 40d66b1
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ Doc RelayTextPrinter::PrintScope(const ObjectRef& node) {
}

Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) {
if (node->IsInstance<BaseFuncNode>() && !node->IsInstance<relay::FunctionNode>()) {
if (node.defined() && node->IsInstance<BaseFuncNode>() &&
!node->IsInstance<relay::FunctionNode>()) {
// Temporarily skip non-relay functions.
// TODO(tvm-team) enhance the code to work for all functions
} else if (node.as<ExprNode>()) {
Expand All @@ -105,8 +106,8 @@ Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) {
}

Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
bool is_non_relay_func =
node->IsInstance<BaseFuncNode>() && !node->IsInstance<relay::FunctionNode>();
bool is_non_relay_func = node.defined() && node->IsInstance<BaseFuncNode>() &&
!node->IsInstance<relay::FunctionNode>();
if (node.as<ExprNode>() && !is_non_relay_func) {
return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as<TypeNode>()) {
Expand Down
10 changes: 10 additions & 0 deletions tests/python/relay/test_ir_text_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,15 @@ def @main[A]() -> fn (A, List[A]) -> List[A] {
assert main_def_str.strip() in mod_str


def test_null_attribute():
x = relay.var("x")
y = relay.var("y")
z = relay.Function([x], y)
z = z.with_attr("TestAttribute", None)
txt = astext(z)
assert "TestAttribute=(nullptr)" in txt


if __name__ == "__main__":
do_print[0] = True
test_lstm()
Expand All @@ -262,3 +271,4 @@ def @main[A]() -> fn (A, List[A]) -> List[A] {
test_variable_name()
test_call_node_order()
test_unapplied_constructor()
test_null_attribute()

0 comments on commit 40d66b1

Please sign in to comment.