[PYTHON] Enhance with_attr API, cleanup MakeAPILegacy in testcases #5335

Merged: 1 commit, Apr 15, 2020
31 changes: 31 additions & 0 deletions python/tvm/ir/function.py
@@ -16,6 +16,8 @@
# under the License.
"""Function defintiions."""
from enum import IntEnum
import tvm.runtime

from .expr import RelayExpr
from . import _ffi_api

@@ -34,3 +36,32 @@ def attrs(self):
"""Return the attrs member of the function.
"""
return _ffi_api.BaseFunc_Attrs(self)

def with_attr(self, attr_key_or_dict, attr_value=None):
"""Create a new copy of the function and update the attribute.

Parameters
----------
attr_key_or_dict : Union[str, dict]
The attribute key to use or a dict containing multiple key value pairs.

attr_value : Object
The new attribute value.

Returns
-------
func : Function
A new copy of the function
"""
# make sure we first copy so that we can safely do copy on write
# for multiple updates.
res = _ffi_api.BaseFuncCopy(self)

if isinstance(attr_key_or_dict, dict):
for key, val in attr_key_or_dict.items():
res = _ffi_api.BaseFuncWithAttr(
res._move(), key, tvm.runtime.convert(val))
return res

return _ffi_api.BaseFuncWithAttr(
res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value))
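
As a usage sketch (not part of this diff), the enhanced API accepts either a single key/value pair or a dict of attributes; the function and attribute values below are illustrative only.

# Illustrative sketch: `f` is assumed to be an existing tvm.tir.PrimFunc.
# Single-attribute form, unchanged from before:
f1 = f.with_attr("global_symbol", "main")
# Dict form added by this PR: update several attributes in one call.
f2 = f.with_attr({"global_symbol": "main", "tir.noalias": True})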
19 changes: 0 additions & 19 deletions python/tvm/relay/function.py
@@ -65,22 +65,3 @@ def __call__(self, *args):
Arguments.
"""
return Call(self, args, None, None)

def with_attr(self, attr_key, attr_value):
"""Create a new copy of the function and update the attribute

Parameters
----------
attr_key : str
The attribute key to use.

attr_value : Object
The new attribute value.

Returns
-------
func : Function
A new copy of the function
"""
return _ffi_api.FunctionWithAttr(
self, attr_key, convert(attr_value))
37 changes: 0 additions & 37 deletions python/tvm/testing.py
@@ -168,41 +168,4 @@ def compare_derivative(j, n_der, grad):
x_name, grad.shape, dist, max_diff, avg_diff)


def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to build a Module from statement.

Used for migrating existing test cases only.

Parameters
----------
stmt: Stmt
The input statement.

name: str
The name of the function.

args: list of Buffer or Vars
The function arguments

num_unpacked_args: int
Number of unpacked arguments.

noalias: bool
Whether to allow noalias.

Returns
-------
mod : IRModule
The created IRModule.
"""
assert num_unpacked_args == 0
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: f})
return mod


tvm._ffi._init_api("testing", __name__)
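
With the helper removed, the migrated tests construct the module directly; below is a minimal sketch of the equivalent pattern, with the buffer and statement built the same way as in the tests further down.

import tvm
from tvm import te

# Minimal sketch of the replacement pattern used throughout the updated tests.
dtype = "int64"
n = te.size_var("n")
Ab = tvm.tir.decl_buffer((n,), dtype)
ib = tvm.tir.ir_builder.create()
A = ib.buffer_ptr(Ab)
with ib.for_range(0, n - 1, "i") as i:
    A[i + 1] = A[i] + 1
stmt = ib.get()

# Wrap the statement in a PrimFunc, attach a global symbol, and build the module.
mod = tvm.IRModule.from_expr(
    tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main"))
f = tvm.build(mod, target="llvm")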
19 changes: 0 additions & 19 deletions python/tvm/tir/function.py
@@ -67,22 +67,3 @@ def __init__(self,

self.__init_handle_by_constructor__(
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)

def with_attr(self, attr_key, attr_value):
"""Create a new copy of the function and update the attribute

Parameters
----------
attr_key : str
The attribute key to use.

attr_value : Object
The new attribute value.

Returns
-------
func : Function
A new copy of the function
"""
return _ffi_api.PrimFuncWithAttr(
self, attr_key, tvm.runtime.convert(attr_value))
26 changes: 26 additions & 0 deletions src/ir/function.cc
@@ -23,6 +23,14 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/function.h>
// NOTE: reverse dependency on relay, tir/.
// These dependencies do not happen at the interface level,
// and are only used in a minimal set of cases where they are clearly marked.
//
// Rationale: we call into the type-specific WithAttr functions.
#include <tvm/tir/function.h>
#include <tvm/relay/function.h>


namespace tvm {

@@ -31,4 +39,22 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
return func->attrs;
});

TVM_REGISTER_GLOBAL("ir.BaseFuncCopy")
.set_body_typed([](BaseFunc func) {
return func;
});

TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
.set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc {
if (func->IsInstance<tir::PrimFuncNode>()) {
return WithAttr(Downcast<tir::PrimFunc>(std::move(func)), key, value);
} else if (func->IsInstance<relay::FunctionNode>()) {
return WithAttr(Downcast<relay::Function>(std::move(func)), key, value);
} else {
LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
return func;
}
});


} // namespace tvm
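
Since `ir.BaseFuncWithAttr` dispatches on the concrete function type, the same Python-level `with_attr` call now works for both TIR and Relay functions; a hedged sketch with illustrative attribute names:

# Hedged sketch (not part of the diff): both function kinds share BaseFunc.with_attr,
# which routes through ir.BaseFuncWithAttr to the type-specific C++ WithAttr.
import tvm
from tvm import relay

prim = tvm.tir.PrimFunc([], tvm.tir.Evaluate(tvm.tir.const(0, "int32")))
prim = prim.with_attr("global_symbol", "my_prim_func")

rly = relay.Function([], relay.const(0))
rly = rly.with_attr("Primitive", 1)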
8 changes: 7 additions & 1 deletion src/ir/module.cc
@@ -362,13 +362,19 @@ IRModule IRModule::FromExpr(
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = IRModule(global_funcs, type_definitions);
BaseFunc func;
std::string gv_name = "main";

if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
gv_name = opt.value();
}

} else {
func = relay::Function(relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVar("main");
auto main_gv = GlobalVar(gv_name);
mod->Add(main_gv, func);
return mod;
}
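A hedged Python-side sketch of the behavior this change enables: when the function already carries a "global_symbol" attribute, IRModule.from_expr registers it under that name instead of the default "main"; the names below are illustrative.

# Illustrative sketch: from_expr now respects an existing "global_symbol" attribute.
import tvm

func = tvm.tir.PrimFunc([], tvm.tir.Evaluate(tvm.tir.const(0, "int32")))
func = func.with_attr("global_symbol", "my_func")
mod = tvm.IRModule.from_expr(func)
gv = mod.get_global_var("my_func")  # registered as "my_func", not "main"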
6 changes: 0 additions & 6 deletions src/relay/ir/function.cc
@@ -74,11 +74,5 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< node->attrs << ")";
});

TVM_REGISTER_GLOBAL("relay.ir.FunctionWithAttr")
.set_body_typed(
[](Function func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref);
});

} // namespace relay
} // namespace tvm
6 changes: 0 additions & 6 deletions src/tir/ir/function.cc
@@ -84,11 +84,5 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc")
return PrimFunc(params, body, ret_type, buffer_map, attrs);
});


TVM_REGISTER_GLOBAL("tir.PrimFuncWithAttr")
.set_body_typed([](PrimFunc func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref);
});

} // namespace tir
} // namespace tvm
3 changes: 2 additions & 1 deletion tests/python/unittest/test_runtime_extension.py
@@ -39,7 +39,8 @@ def test_dltensor_compatible():
A[i + 1] = A[i] + 1
stmt = ib.get()

mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange"))
f = tvm.build(mod, target="stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
aview = MyTensorView(a)
7 changes: 5 additions & 2 deletions tests/python/unittest/test_runtime_module_load.py
@@ -57,8 +57,11 @@ def save_object(names):
tvm.tir.Store(Ab.data,
tvm.tir.Load(dtype, Ab.data, i) + 1,
i + 1))
m = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
m = tvm.driver.build(m, target="llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr(
"global_symbol", "main")
)
m = tvm.driver.build(mod, target="llvm")
for name in names:
m.save(name)

12 changes: 8 additions & 4 deletions tests/python/unittest/test_target_codegen_llvm.py
@@ -36,8 +36,11 @@ def test_llvm_intrin():
"int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
body = ib.get()

func = tvm.testing.MakeAPILegacy(body, "prefetch", [A], 0, True)
fcode = tvm.build(func, None, "llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A], body).with_attr(
"global_symbol", "prefetch")
)
fcode = tvm.build(mod, None, "llvm")


def test_llvm_overloaded_intrin():
@@ -111,8 +114,9 @@ def test_llvm_lookup_intrin():
x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
ib.emit(x)
body = ib.get()
func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 0, True)
fcode = tvm.build(func, None, "llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
fcode = tvm.build(mod, None, "llvm")


def test_llvm_large_uintimm():
23 changes: 8 additions & 15 deletions tests/python/unittest/test_target_codegen_static_init.py
@@ -20,17 +20,6 @@
import numpy as np


def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to create a API"""
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI()(mod)


def test_static_callback():
dtype = 'int64'
n = te.size_var('n')
@@ -44,8 +33,11 @@ def test_static_callback():
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
f = tvm.driver.build(fapi, target="llvm")

mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp")
)
f = tvm.driver.build(mod, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
f(a)
@@ -67,8 +59,9 @@ def test_cb(sh, A):
return sh

stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
f = tvm.driver.build(fapi, target="llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
f = tvm.driver.build(mod, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)

Expand Down
34 changes: 14 additions & 20 deletions tests/python/unittest/test_target_codegen_vm_basic.py
@@ -26,18 +26,6 @@ def run_jit(fapi, check):
s = f.get_source()
check(f)


def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to create a API"""
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI()(mod)


def test_stack_vm_basic():
a = tvm.nd.array(np.zeros(10, dtype='float32'))
@tvm.register_func
@@ -48,8 +36,11 @@ def tvm_call_back_get_shape(shape0):
n = te.size_var('n')
Ab = tvm.tir.decl_buffer((n, ), "float32")
stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.testing.MakeAPILegacy(stmt, "print_shape", [Ab], 0, True)
run_jit(fapi, lambda f: f(a))

mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "print_shape"))

run_jit(mod, lambda f: f(a))


@tvm.register_func
@@ -69,12 +60,13 @@ def test_stack_vm_loop():
ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i))

stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f):
f(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
run_jit(fapi, check)
run_jit(mod, check)


def test_stack_vm_cond():
@@ -91,14 +83,15 @@ def test_stack_vm_cond():
A[i + 1] = A[i] + 2

stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "test", [Ab], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
y = np.arange(a.shape[0]) * 2
y[5:] -= 1
np.testing.assert_equal(a.asnumpy(), y)
run_jit(fapi, check)
run_jit(mod, check)

def test_vm_parallel():
dtype = 'int64'
@@ -110,12 +103,13 @@ def test_vm_parallel():
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
run_jit(fapi, check)
run_jit(mod, check)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tir_nodes.py
@@ -277,7 +277,7 @@ def test_prim_func():
assert func.buffer_map[func.params[2]].same_as(b)

assert len(func.buffer_map) == 1
f2 = func.with_attr("calling_conv", 1)
f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True})
assert f2.attrs["calling_conv"].value == 1
assert func.attrs is None
