Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hexagon] Improved ergonomics of HexagonLauncher in unit tests. #10581

Merged
merged 9 commits into from
Mar 25, 2022
25 changes: 17 additions & 8 deletions python/tvm/contrib/hexagon/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,28 +195,37 @@ def start_session(self) -> Session:
"timeout": 0,
"key": self.HEXAGON_REMOTE_DEVICE_KEY,
}
return Session(hexagon_remote_kw)
return Session(self, hexagon_remote_kw)

def load_module(self, module_name: Union[str, pathlib.Path], session: Session):
def load_module(self, module: Union[str, pathlib.Path], session: Session):
Lunderberg marked this conversation as resolved.
Show resolved Hide resolved
"""Load TVM module.

Parameters
----------
module_name : str or pathlib.Path
Name of the module to load. It must be either a bare file name
(without any path components), or a full path in the remote
system. If it is a file name, the file must be placed in the
remote workspace.
module : Union[str, pathlib.Path, tvm.runtime.Module]

The module to load. If `module` is a
`tvm.runtime.Module`, it will be uploaded to the remote
session and loaded.

If the object passed is a string or pathlib.Path, it must
be either a bare file name (without any path components),
or a full path in the remote system. If it is a file name,
the file must already have been uploaded to the remote,
and be placed in the remote workspace.

session : Session

Remote session. The session must be established (via __enter__)
prior to calling this function.

Returns
-------
TVMModule :
TVM module object.

"""
return session.load_module(module_name)
return session.load_module(module)

def get_graph_executor(
self, graph_json: str, module_name: Union[str, pathlib.Path], session: Session
Expand Down
58 changes: 55 additions & 3 deletions python/tvm/contrib/hexagon/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

import os
import pathlib
import tempfile
from typing import Union

import tvm
from tvm import rpc as _rpc


Expand All @@ -37,10 +40,12 @@ class Session:

Lunderberg marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
launcher: "HexagonLauncherRPC",
csullivan marked this conversation as resolved.
Show resolved Hide resolved
remote_kw: dict,
session_name: str = "hexagon-rpc",
remote_stack_size_bytes: int = 128 * 1024,
):
self._launcher = launcher
self._session_name = session_name
self._remote_stack_size_bytes = remote_stack_size_bytes
self._remote_kw = remote_kw
Expand Down Expand Up @@ -74,6 +79,53 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, exc_traceback):
pass

def load_module(self, path: Union[str, pathlib.Path]):
assert isinstance(path, (str, pathlib.Path)), "Invalid path type:" + str(type(path))
return self._rpc.get_function("tvm.hexagon.load_module")(str(path))
def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str):
"""Upload a local file to the remote workspace.

Parameters
----------
local_path : str or pathlib.Path
Path to the local file to be copied.
remote_filename : str
Name of the file in the remote workspace.
"""
self._launcher.upload(local_path, remote_filename)

def load_module(self, module: Union[str, pathlib.Path, tvm.IRModule]):
"""Load TVM module.

Parameters
----------
module : Union[str, pathlib.Path, tvm.runtime.Module]

The module to load. If `module` is a
`tvm.runtime.Module`, it will be uploaded to the remote
session and loaded.

If the object passed is a string or pathlib.Path, it must
be either a bare file name (without any path components),
or a full path in the remote system. If it is a file name,
the file must already have been uploaded to the remote,
and be placed in the remote workspace.

session : Session

Remote session. The session must be established (via __enter__)
prior to calling this function.

Returns
-------
TVMModule :
TVM module object.
"""
if isinstance(module, tvm.runtime.Module):
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = pathlib.Path(temp_dir)
binary_name = "test_binary.so"
binary_path = temp_dir / binary_name
module.save(str(binary_path))
self.upload(binary_path, binary_name)
module = binary_name

assert isinstance(module, (str, pathlib.Path)), "Invalid path type:" + str(type(module))
return self._rpc.get_function("tvm.hexagon.load_module")(str(module))
132 changes: 122 additions & 10 deletions tests/python/contrib/test_hexagon/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@
values from testing parameters """

import os
import random
import socket
from typing import Optional

import pytest

import tvm
from tvm import rpc
import tvm.rpc.tracker
from tvm.contrib.hexagon.build import HexagonLauncher

HEXAGON_TOOLCHAIN = "HEXAGON_TOOLCHAIN"
TVM_TRACKER_HOST = "TVM_TRACKER_HOST"
Expand Down Expand Up @@ -59,27 +64,134 @@ def requires_hexagon_toolchain(*args):


@tvm.testing.fixture
def android_serial_number() -> str:
return os.getenv(ANDROID_SERIAL_NUMBER, default=None)
def android_serial_number() -> Optional[str]:
serial = os.getenv(ANDROID_SERIAL_NUMBER, default="")
# Setting ANDROID_SERIAL_NUMBER to an empty string should be
# equivalent to having it unset.
if not serial.strip():
serial = None
return serial


@tvm.testing.fixture
def tvm_tracker_host() -> str:
return os.getenv(TVM_TRACKER_HOST, default=None)
# NOTE on server ports:
# These tests use different port numbers for the RPC server (7070 + ...).
# The reason is that an RPC session cannot be gracefully closed without
# triggering TIME_WAIT state on the server socket. This prevents another
# server to bind to the same port until the wait time elapses.

listen_port_min = 2000 # Well above the privileged ports (1024 or lower)
listen_port_max = 9000 # Below the search range end (port_end=9199) of RPC server
previous_port = [None]
csullivan marked this conversation as resolved.
Show resolved Hide resolved

@tvm.testing.fixture
def tvm_tracker_port() -> int:
port = os.getenv(TVM_TRACKER_PORT, default=None)
port = int(port) if port else None

def get_free_port():
# https://stackoverflow.com/a/52872579/2689797
def is_port_in_use(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0

if previous_port[0] is None:
port = random.randint(listen_port_min, listen_port_max)
else:
port = previous_port[0] + 1

while is_port_in_use(port):
port = port + 1 if port < listen_port_max else listen_port_min

previous_port[0] = port
return port


@pytest.fixture(scope="session")
def _tracker_info() -> (str, int):
env_tracker_host = os.getenv(TVM_TRACKER_HOST, default="")
env_tracker_port = os.getenv(TVM_TRACKER_PORT, default="")

if env_tracker_host or env_tracker_port:
# A tracker is already running, and we should connect to it
# when running tests.
assert env_tracker_host, "TVM_TRACKER_PORT is defined, but TVM_TRACKER_HOST is not"
assert env_tracker_port, "TVM_TRACKER_HOST is defined, but TVM_TRACKER_PORT is not"
env_tracker_port = int(env_tracker_port)

try:
tvm.rpc.connect_tracker(env_tracker_host, env_tracker_port)
except RuntimeError as exc:
message = (
"Could not connect to external tracker "
"specified by $TVM_TRACKER_HOST and $TVM_TRACKER_PORT "
f"({env_tracker_host}:{env_tracker_port})"
)
raise RuntimeError(message) from exc

yield (env_tracker_host, env_tracker_port)

else:
# No tracker is provided to the tests, so we should start one
# for the tests to use.
tracker = tvm.rpc.tracker.Tracker("127.0.0.1", get_free_port())
try:
yield (tracker.host, tracker.port)
finally:
tracker.terminate()


@pytest.fixture(scope="session")
def tvm_tracker_host(_tracker_info) -> str:
host, port = _tracker_info
return host


@pytest.fixture(scope="session")
def tvm_tracker_port(_tracker_info) -> int:
host, port = _tracker_info
return port


@tvm.testing.fixture
def rpc_server_port() -> int:
return get_free_port()


@tvm.testing.fixture
def adb_server_socket() -> str:
return os.getenv(ADB_SERVER_SOCKET, default="tcp:5037")


@tvm.testing.fixture
def hexagon_launcher(request, android_serial_number, rpc_server_port, adb_server_socket):
if android_serial_number is None:
yield None
else:
# Requesting these fixtures sets up a local tracker, if one
# hasn't been provided to us. Delaying the evaluation of
# these fixtures avoids starting a tracker unless necessary.
tvm_tracker_host = request.getfixturevalue("tvm_tracker_host")
tvm_tracker_port = request.getfixturevalue("tvm_tracker_port")

rpc_info = {
"rpc_tracker_host": tvm_tracker_host,
"rpc_tracker_port": tvm_tracker_port,
"rpc_server_port": rpc_server_port,
"adb_server_socket": adb_server_socket,
}
launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info)
launcher.start_server()
try:
yield launcher
finally:
launcher.stop_server()


@tvm.testing.fixture
def hexagon_session(hexagon_launcher):
if hexagon_launcher is None:
yield None
else:
with hexagon_launcher.start_session() as session:
yield session


# If the execution aborts while an RPC server is running, the python
# code that is supposed to shut it dowm will never execute. This will
# keep pytest from terminating (indefinitely), so add a cleanup
Expand Down
47 changes: 16 additions & 31 deletions tests/python/contrib/test_hexagon/test_cache_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def intrin_func(ins, outs):


@requires_hexagon_toolchain
def test_cache_read_write(
android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket
):
def test_cache_read_write(hexagon_session):
size = 128
outer_shape = (size,)
factor = 16
Expand Down Expand Up @@ -105,37 +103,24 @@ def test_cache_read_write(
func = tvm.build(
s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy"
)
temp = utils.tempdir()
dso_binary = "test_binary.so"
dso_binary_path = temp.relpath(dso_binary)
func.save(dso_binary_path)

if not android_serial_number:
if hexagon_session is None:
pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")

rpc_info = {
"rpc_tracker_host": tvm_tracker_host,
"rpc_tracker_port": tvm_tracker_port,
"rpc_server_port": 7070,
"adb_server_socket": adb_server_socket,
}
launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info)
launcher.upload(dso_binary_path, dso_binary)
launcher.start_server()

with launcher.start_session() as sess:
mod = launcher.load_module(dso_binary, sess)
xt = tvm.nd.array(
np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device
)
yt = tvm.nd.array(
np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device
)
zt = tvm.nd.array(
np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device
)
mod["dmacpy"](xt, yt, zt)
launcher.stop_server()
mod = hexagon_session.load_module(func)
xt = tvm.nd.array(
np.random.randint(low=-128, high=127, size=size, dtype=x.dtype),
device=hexagon_session.device,
)
yt = tvm.nd.array(
np.random.randint(low=-128, high=127, size=size, dtype=y.dtype),
device=hexagon_session.device,
)
zt = tvm.nd.array(
np.random.randint(low=-128, high=127, size=size, dtype=z.dtype),
device=hexagon_session.device,
)
mod["dmacpy"](xt, yt, zt)

ref = xt.numpy() + yt.numpy()
np.testing.assert_equal(zt.numpy(), ref)
Loading