Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CONTRIB][CC] Enhance cc.cross_compiler #4817

Merged
merged 2 commits into from
Feb 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 42 additions & 27 deletions python/tvm/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,38 +87,23 @@ def get_target_triple():
create_shared.get_target_triple = get_target_by_dump_machine(
"g++" if sys.platform == "darwin" or sys.platform.startswith("linux") else None)

def build_create_shared_func(options=None, compile_cmd="g++"):
"""Build create_shared function with particular default options and compile_cmd.

Parameters
----------
options : List[str]
The list of additional options string.
def cross_compiler(compile_func,
options=None,
output_format=None,
get_target_triple=None):
"""Create a cross compiler function by specializing compile_func with options.

compile_cmd : Optional[str]
The compiler command.

Returns
-------
create_shared_wrapper : Callable[[str, str, Optional[str]], None]
A compilation function that can be passed to export_library or to autotvm.LocalBuilder.
"""
def create_shared_wrapper(output, objects, options=options, compile_cmd=compile_cmd):
create_shared(output, objects, options, compile_cmd)
create_shared_wrapper.output_format = create_shared.output_format
create_shared_wrapper.get_target_triple = get_target_by_dump_machine(compile_cmd)
return create_shared_wrapper
This function can be used to construct compile functions that
can be passed to AutoTVM measure or export_library.


def cross_compiler(compile_func, base_options=None, output_format="so", get_target_triple=None):
"""Create a cross compiler function.

Parameters
----------
compile_func : Callable[[str, str, Optional[str]], None]
compile_func : Union[str, Callable[[str, str, Optional[str]], None]]
Function that performs the actual compilation

base_options : Optional[List[str]]
options : Optional[List[str]]
List of additional optional string.

output_format : Optional[str]
Expand All @@ -131,14 +116,44 @@ def cross_compiler(compile_func, base_options=None, output_format="so", get_targ
-------
fcompile : Callable[[str, str, Optional[str]], None]
A compilation function that can be passed to export_library.

Examples
--------
.. code-block:: python

from tvm.contrib import cc, ndk
# export using arm gcc
mod = build_runtime_module()
mod.export_library(path_dso,
cc.cross_compiler("arm-linux-gnueabihf-gcc"))
# specialize ndk compilation options.
specialized_ndk = cc.cross_compiler(
ndk.create_shared,
["--sysroot=/path/to/sysroot", "-shared", "-fPIC", "-lm"])
mod.export_library(path_dso, specialized_ndk)
"""
if base_options is None:
base_options = []
base_options = [] if options is None else options
kwargs = {}

# handle case where compile_func is the name of the cc
if isinstance(compile_func, str):
kwargs = {"cc" : compile_func}
compile_func = create_shared


def _fcompile(outputs, objects, options=None):
all_options = base_options
if options is not None:
all_options += options
compile_func(outputs, objects, options=all_options)
compile_func(outputs, objects, options=all_options, **kwargs)

if not output_format and hasattr(compile_func, "output_format"):
output_format = compile_func.output_format
output_format = output_format if output_format else "so"

if not get_target_triple and hasattr(compile_func, "get_target_triple"):
get_target_triple = compile_func.get_target_triple

_fcompile.output_format = output_format
_fcompile.get_target_triple = get_target_triple
return _fcompile
Expand Down
7 changes: 4 additions & 3 deletions tests/python/unittest/test_module_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def check_device(device):
raise ValueError("Unsupported platform")

path_dso = temp.relpath("dev_lib.so")
f.export_library(path_dso)
# test cross compiler function
f.export_library(path_dso, cc.cross_compiler("g++"))

f1 = tvm.module.load(path_dso)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
Expand All @@ -134,8 +135,8 @@ def check_stackvm(device):
name = "myadd_%s" % device
f = tvm.build(s, [A, B], device, "stackvm", name=name)
path_dso = temp.relpath("dev_lib.stackvm")
#f.export_library(path_dso)
#f1 = tvm.module.load(path_dso)
f.export_library(path_dso)
f1 = tvm.module.load(path_dso)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f(a, b)
Expand Down