Skip to content

Commit

Permalink
[RPC] Better handle tempdir if subprocess killed. (apache#3574)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalint13 authored and wweic committed Sep 6, 2019
1 parent 4b7d9f1 commit bba856d
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 13 deletions.
17 changes: 13 additions & 4 deletions python/tvm/contrib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ class TempDirectory(object):
Automatically removes the directory when it goes out of scope.
"""
def __init__(self, custom_path=None):
    """Create the directory backing this object.

    Parameters
    ----------
    custom_path : str, optional
        Exact path to create and use as the directory. When omitted
        (or empty), a fresh uniquely named directory is created with
        ``tempfile.mkdtemp()`` instead.

    Raises
    ------
    OSError
        If ``custom_path`` is given and ``os.mkdir`` cannot create it,
        for example because the path already exists.
    """
    if custom_path:
        # os.mkdir (not makedirs) is deliberate here: an already existing
        # path raises instead of being silently adopted.
        os.mkdir(custom_path)
        self.temp_dir = custom_path
    else:
        self.temp_dir = tempfile.mkdtemp()
    # NOTE(review): rmtree is stored on the instance, presumably so cleanup
    # does not depend on the shutil module global at teardown time --
    # confirm against remove().
    self._rmtree = shutil.rmtree

def remove(self):
Expand Down Expand Up @@ -69,15 +73,20 @@ def listdir(self):
return os.listdir(self.temp_dir)


def tempdir(custom_path=None):
    """Create a temp dir whose contents are deleted on exit.

    Parameters
    ----------
    custom_path : str, optional
        Manually specify the exact temp dir path. When omitted, the
        TempDirectory picks its own path.

    Returns
    -------
    temp : TempDirectory
        The temp directory object
    """
    return TempDirectory(custom_path)


class FileLock(object):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def load(path, fmt=""):
_cc.create_shared(path + ".so", path)
path += ".so"
elif path.endswith(".tar"):
tar_temp = _util.tempdir()
tar_temp = _util.tempdir(custom_path=path.replace('.tar', ''))
_tar.untar(path, tar_temp.temp_dir)
files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
_cc.create_shared(path + ".so", files)
Expand Down
20 changes: 12 additions & 8 deletions python/tvm/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,12 @@

logger = logging.getLogger('RPCServer')

def _server_env(load_library):
def _server_env(load_library, work_path=None):
"""Server environment function; returns the temp dir."""
temp = util.tempdir()
if work_path:
temp = work_path
else:
temp = util.tempdir()

# pylint: disable=unused-variable
@register_func("tvm.rpc.server.workpath")
Expand All @@ -76,16 +79,15 @@ def load_module(file_name):
temp.libs = libs
return temp


def _serve_loop(sock, addr, load_library, work_path=None):
    """Serve one RPC connection on ``sock`` until the server loop returns.

    Parameters
    ----------
    sock : socket
        Connected socket; only its file descriptor is handed to the
        server loop.
    addr : object
        Peer address, used only for the completion log message.
    load_library : object
        Forwarded unchanged to ``_server_env``.
    work_path : optional
        Working directory object supplied by the caller and forwarded to
        ``_server_env``. When given, the caller owns it and must remove
        it; when omitted, the directory obtained from ``_server_env`` is
        removed here.
    """
    sockfd = sock.fileno()
    temp = _server_env(load_library, work_path)
    try:
        base._ServerLoop(sockfd)
    finally:
        # Remove only a directory this function caused to be created, and
        # do so even if the server loop raises, so it is not leaked.
        if not work_path:
            temp.remove()
    logger.info("Finish serving %s", addr)


def _parse_server_opt(opts):
# parse client options
ret = {}
Expand Down Expand Up @@ -196,9 +198,10 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2):
raise exc

# step 3: serving
work_path = util.tempdir()
logger.info("connection from %s", addr)
server_proc = multiprocessing.Process(target=_serve_loop,
args=(conn, addr, load_library))
args=(conn, addr, load_library, work_path))
server_proc.deamon = True
server_proc.start()
# close from our side.
Expand All @@ -208,6 +211,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2):
if server_proc.is_alive():
logger.info("Timeout in RPC session, kill..")
server_proc.terminate()
work_path.remove()


def _connect_proxy_loop(addr, key, load_library):
Expand Down

0 comments on commit bba856d

Please sign in to comment.