diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index e5259fbc0da8..39e68b8333ff 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -280,6 +280,9 @@ def set_params(self, params): def set_attribute(self, name, ref): return _expr.FunctionSetAttr(self, name, ref) + def get_attribute(self, name): + return _expr.FunctionGetAttr(self, name) + @register_relay_node class Call(ExprWithOp): diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 89395bb742c1..0292a6c2bb05 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -360,6 +360,12 @@ TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr") return FunctionSetAttr(func, name, ref); }); +TVM_REGISTER_GLOBAL("relay._expr.FunctionGetAttr") +.set_body_typed( + [](Function func, std::string name) { + return FunctionGetAttr(func, name); +}); + TVM_REGISTER_GLOBAL("relay._make.Any") .set_body_typed([]() { return Any::make(); }); diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index bdda72ca8702..b7d7eb9f389c 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -168,10 +168,12 @@ def test_function(): body = relay.Tuple(tvm.convert([])) type_params = tvm.convert([]) fn = relay.Function(params, body, ret_type, type_params) + fn = fn.set_attribute("test_attribute", tvm.tir.StringImm("value")) assert fn.params == params assert fn.body == body assert fn.type_params == type_params assert fn.span == None + assert fn.get_attribute("test_attribute") == "value" str(fn) check_json_roundtrip(fn)