Skip to content

Commit

Permalink
[µTVM] Avoid use of builtin math functions (apache#6630)
Browse files Browse the repository at this point in the history
  • Loading branch information
areusch authored and Tushar Dey committed Oct 13, 2020
1 parent cb73cd5 commit 8adb52c
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 20 deletions.
46 changes: 36 additions & 10 deletions python/tvm/micro/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,50 @@ def path(self):
RUNTIME_SRC_REGEX = re.compile(r"^.*\.cc?$", re.IGNORECASE)


_COMMON_CFLAGS = ["-Wall", "-Werror"]


_CRT_DEFAULT_OPTIONS = {
"ccflags": ["-std=c++11"],
"ldflags": ["-std=gnu++14"],
"cflags": ["-std=c11"] + _COMMON_CFLAGS,
"ccflags": ["-std=c++11"] + _COMMON_CFLAGS,
"ldflags": ["-std=c++11"],
"include_dirs": [
f"{TVM_ROOT_DIR}/include",
f"{TVM_ROOT_DIR}/3rdparty/dlpack/include",
f"{TVM_ROOT_DIR}/3rdparty/libcrc/include",
f"{TVM_ROOT_DIR}/3rdparty/dmlc-core/include",
f"{CRT_ROOT_DIR}/include",
],
"profile": {"common": ["-Wno-unused-variable"]},
}


_CRT_GENERATED_LIB_OPTIONS = copy.copy(_CRT_DEFAULT_OPTIONS)


# Disable due to limitation in the TVM C codegen, which generates lots of local variable
# declarations at the top of generated code without caring whether they're used.
# Example:
# void* arg0 = (((TVMValue*)args)[0].v_handle);
# int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
_CRT_GENERATED_LIB_OPTIONS["cflags"].append("-Wno-unused-variable")


# Many TVM-intrinsic operators (i.e. expf, in particular)
_CRT_GENERATED_LIB_OPTIONS["cflags"].append("-fno-builtin")


def default_options(target_include_dir):
"""Return default opts passed to Compile commands."""
bin_opts = copy.deepcopy(_CRT_DEFAULT_OPTIONS)
bin_opts["include_dirs"].append(target_include_dir)
lib_opts = copy.deepcopy(bin_opts)
lib_opts["profile"]["common"].append("-Werror")
lib_opts["cflags"] = ["-Wno-error=incompatible-pointer-types"]
return {"bin_opts": bin_opts, "lib_opts": lib_opts}


def build_static_runtime(workspace, compiler, module, lib_opts=None, bin_opts=None):
def build_static_runtime(
workspace, compiler, module, lib_opts=None, bin_opts=None, generated_lib_opts=None
):
"""Build the on-device runtime, statically linking the given modules.
Parameters
Expand All @@ -102,11 +121,15 @@ def build_static_runtime(workspace, compiler, module, lib_opts=None, bin_opts=No
module : IRModule
Module to statically link.
lib_opts : dict
Extra kwargs passed to library(),
lib_opts : Optional[dict]
The `options` parameter passed to compiler.library().
bin_opts : Optional[dict]
The `options` parameter passed to compiler.binary().
bin_opts : dict
Extra kwargs passed to binary(),
generated_lib_opts : Optional[dict]
The `options` parameter passed to compiler.library() when compiling the generated TVM C
source module.
Returns
-------
Expand All @@ -115,6 +138,9 @@ def build_static_runtime(workspace, compiler, module, lib_opts=None, bin_opts=No
"""
lib_opts = _CRT_DEFAULT_OPTIONS if lib_opts is None else lib_opts
bin_opts = _CRT_DEFAULT_OPTIONS if bin_opts is None else bin_opts
generated_lib_opts = (
_CRT_GENERATED_LIB_OPTIONS if generated_lib_opts is None else generated_lib_opts
)

mod_build_dir = workspace.relpath(os.path.join("build", "module"))
os.makedirs(mod_build_dir)
Expand All @@ -136,7 +162,7 @@ def build_static_runtime(workspace, compiler, module, lib_opts=None, bin_opts=No

libs.append(compiler.library(lib_build_dir, lib_srcs, lib_opts))

libs.append(compiler.library(mod_build_dir, [mod_src_path], lib_opts))
libs.append(compiler.library(mod_build_dir, [mod_src_path], generated_lib_opts))

runtime_build_dir = workspace.relpath(f"build/runtime")
os.makedirs(runtime_build_dir)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/micro/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def flash(self, micro_binary):
[micro_binary.abspath(micro_binary.binary_file)]
)
return transport.DebugWrapperTransport(
debugger=gdb_wrapper, transport=gdb_wrapper.Transport()
debugger=gdb_wrapper, transport=gdb_wrapper.transport()
)

return transport.SubprocessTransport([micro_binary.abspath(micro_binary.binary_file)])
10 changes: 5 additions & 5 deletions python/tvm/micro/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def start(self):
self.did_terminate = threading.Event()
self.old_signal = signal.signal(signal.SIGINT, signal.SIG_IGN)
self.popen = subprocess.Popen(**kwargs)
threading.Thread(target=self._WaitRestoreSignal).start()
threading.Thread(target=self._wait_restore_signal).start()

def stop(self):
self.did_terminate.set()
Expand Down Expand Up @@ -130,14 +130,14 @@ def _wait_for_process_death(self):
self.stdout.close()

def start(self):
to_return = super(GdbTransportDebugger, self).Start()
to_return = super(GdbTransportDebugger, self).start()
threading.Thread(target=self._wait_for_process_death, daemon=True).start()
return to_return

def stop(self):
self.stdin.close()
self.stdout.close()
super(GdbTransportDebugger, self).Stop()
super(GdbTransportDebugger, self).stop()

class _Transport(_transport.Transport):
def __init__(self, gdb_transport_debugger):
Expand Down Expand Up @@ -189,11 +189,11 @@ def popen_kwargs(self):
def start(self):
if self.wrapping_context_manager is not None:
self.wrapping_context_manager.__enter__()
super(GdbRemoteDebugger, self).Start()
super(GdbRemoteDebugger, self).start()

def stop(self):
try:
super(GdbRemoteDebugger, self).Stop()
super(GdbRemoteDebugger, self).stop()
finally:
if self.wrapping_context_manager is not None:
self.wrapping_context_manager.__exit__(None, None, None)
6 changes: 3 additions & 3 deletions python/tvm/micro/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,12 @@ def __init__(self, debugger, transport):
self.debugger.on_terminate_callbacks.append(self.transport.close)

def open(self):
self.debugger.Start()
self.debugger.start()

try:
self.transport.open()
except Exception:
self.debugger.Stop()
self.debugger.stop()
raise

def write(self, data):
Expand All @@ -232,7 +232,7 @@ def read(self, n):

def close(self):
self.transport.close()
self.debugger.Stop()
self.debugger.stop()


TransportContextManager = typing.ContextManager[Transport]
6 changes: 5 additions & 1 deletion src/runtime/crt/common/func_registry.c
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,12 @@ tvm_crt_error_t TVMMutableFuncRegistry_Set(TVMMutableFuncRegistry* reg, const ch
idx++;
}

if (reg_name_ptr > ((const char*)reg->registry.funcs)) {
return kTvmErrorFunctionRegistryFull;
}

size_t name_len = strlen(name);
ssize_t names_bytes_remaining = ((const char*)reg->registry.funcs) - reg_name_ptr;
size_t names_bytes_remaining = ((const char*)reg->registry.funcs) - reg_name_ptr;
if (idx >= reg->max_functions || name_len + 1 > names_bytes_remaining) {
return kTvmErrorFunctionRegistryFull;
}
Expand Down
1 change: 1 addition & 0 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts) {
declared_globals_.clear();
decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n";
decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n";
decl_stream << "#include <math.h>\n";
decl_stream << "void* " << module_name_ << " = NULL;\n";
CodeGenC::Init(output_ssa);
}
Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,24 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) {
assert (out.asnumpy() == np.array([6, 10])).all()


def test_std_math_functions():
"""Verify that standard math functions can be used."""
workspace = tvm.micro.Workspace()
A = tvm.te.placeholder((2,), dtype="float32", name="A")
B = tvm.te.compute(A.shape, lambda i: tvm.te.exp(A[i]), name="B")
s = tvm.te.create_schedule(B.op)

with _make_sess_from_op(workspace, "myexpf", s, [A, B]) as sess:
A_data = tvm.nd.array(np.array([2.0, 3.0], dtype="float32"), ctx=sess.context)
B_data = tvm.nd.array(np.array([2.0, 3.0], dtype="float32"), ctx=sess.context)
lib = sess.get_system_lib()
func = lib["myexpf"]
func(A_data, B_data)
np.testing.assert_allclose(B_data.asnumpy(), np.array([7.389056, 20.085537]))


if __name__ == "__main__":
test_compile_runtime()
test_reset()
test_graph_runtime()
test_std_math_functions()

0 comments on commit 8adb52c

Please sign in to comment.