Skip to content

Commit

Permalink
[Relay] Expose FunctionGetAttr to Python (#4905)
Browse files Browse the repository at this point in the history
* [Relay] Expose FunctionGetAttr to Python

* add test

Co-authored-by: Jon Soifer <[email protected]>
  • Loading branch information
soiferj and jonso4 authored Feb 18, 2020
1 parent 9d64654 commit 41835d1
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ def set_params(self, params):
def set_attribute(self, name, ref):
return _expr.FunctionSetAttr(self, name, ref)

def get_attribute(self, name):
return _expr.FunctionGetAttr(self, name)


@register_relay_node
class Call(ExprWithOp):
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,12 @@ TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr")
return FunctionSetAttr(func, name, ref);
});

TVM_REGISTER_GLOBAL("relay._expr.FunctionGetAttr")
.set_body_typed(
[](Function func, std::string name) {
return FunctionGetAttr(func, name);
});

TVM_REGISTER_GLOBAL("relay._make.Any")
.set_body_typed([]() { return Any::make(); });

Expand Down
2 changes: 2 additions & 0 deletions tests/python/relay/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,12 @@ def test_function():
body = relay.Tuple(tvm.convert([]))
type_params = tvm.convert([])
fn = relay.Function(params, body, ret_type, type_params)
fn = fn.set_attribute("test_attribute", tvm.tir.StringImm("value"))
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
assert fn.get_attribute("test_attribute") == "value"
str(fn)
check_json_roundtrip(fn)

Expand Down

0 comments on commit 41835d1

Please sign in to comment.