diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index 8ad47acfe989..cdde0cbed1d5 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -45,12 +45,39 @@ def create_shared(output, The compiler command. """ if sys.platform == "darwin" or sys.platform.startswith("linux"): - _linux_compile(output, objects, options, cc) + _linux_compile(output, objects, options, cc, compile_shared=True) elif sys.platform == "win32": _windows_shared(output, objects, options) else: raise ValueError("Unsupported platform") + +def create_executable(output, + objects, + options=None, + cc="g++"): + """Create executable binary. + + Parameters + ---------- + output : str + The target executable. + + objects : List[str] + List of object files. + + options : List[str] + The list of additional options string. + + cc : Optional[str] + The compiler command. + """ + if sys.platform == "darwin" or sys.platform.startswith("linux"): + _linux_compile(output, objects, options, cc) + else: + raise ValueError("Unsupported platform") + + def get_target_by_dump_machine(compiler): """ Functor of get_target_triple that can get the target triple using compiler. @@ -164,9 +191,10 @@ def _fcompile(outputs, objects, options=None): return _fcompile -def _linux_compile(output, objects, options, compile_cmd="g++"): +def _linux_compile(output, objects, options, + compile_cmd="g++", compile_shared=False): cmd = [compile_cmd] - if output.endswith(".so") or output.endswith(".dylib"): + if compile_shared or output.endswith(".so") or output.endswith(".dylib"): cmd += ["-shared", "-fPIC"] if sys.platform == "darwin": cmd += ["-undefined", "dynamic_lookup"] @@ -185,6 +213,7 @@ def _linux_compile(output, objects, options, compile_cmd="g++"): if proc.returncode != 0: msg = "Compilation error:\n" msg += py_str(out) + msg += "\nCommand line: " + " ".join(cmd) raise RuntimeError(msg) diff --git a/tests/python/contrib/test_binutil.py b/tests/python/contrib/test_binutil.py index 3106e73136fa..3aa0583b2816 100644 --- a/tests/python/contrib/test_binutil.py +++ b/tests/python/contrib/test_binutil.py @@ -43,7 +43,7 @@ def make_binary(): tmp_obj = tmp_dir.relpath("obj.obj") with open(tmp_source, "w") as f: f.write(prog) - cc.create_shared(tmp_obj, tmp_source, [], + cc.create_executable(tmp_obj, tmp_source, [], cc="{}gcc".format(TOOLCHAIN_PREFIX)) prog_bin = bytearray(open(tmp_obj, "rb").read()) return prog_bin diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 17321bdeb293..dfbb3c55227b 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -26,7 +26,7 @@ import pytest import numpy as np from tvm import rpc -from tvm.contrib import util +from tvm.contrib import util, cc from tvm.rpc.tracker import Tracker @@ -142,7 +142,7 @@ def check(remote): # Test minrpc server. temp = util.tempdir() minrpc_exec = temp.relpath("minrpc") - tvm.rpc.with_minrpc("g++")(minrpc_exec, []) + tvm.rpc.with_minrpc(cc.create_executable)(minrpc_exec, []) check(rpc.PopenSession(minrpc_exec)) # minrpc on the remote server = rpc.Server("localhost") @@ -208,7 +208,7 @@ def check_minrpc(): temp = util.tempdir() f = tvm.build(s, [A, B], "llvm --system-lib", name="myadd") path_minrpc = temp.relpath("dev_lib.minrpc") - f.export_library(path_minrpc, rpc.with_minrpc("g++")) + f.export_library(path_minrpc, rpc.with_minrpc(cc.create_executable)) with pytest.raises(RuntimeError): rpc.PopenSession("filenotexist")