diff --git a/python/tvm/contrib/util.py b/python/tvm/contrib/util.py index 2ebe175e8160..e980e5520802 100644 --- a/python/tvm/contrib/util.py +++ b/python/tvm/contrib/util.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Common system utilities""" +import atexit import os import tempfile import shutil @@ -23,27 +24,46 @@ except ImportError: fcntl = None - class TempDirectory(object): """Helper object to manage temp directory during testing. Automatically removes the directory when it went out of scope. """ + + TEMPDIRS = set() + @classmethod + def remove_tempdirs(cls): + temp_dirs = getattr(cls, 'TEMPDIRS', None) + if temp_dirs is None: + return + + for path in temp_dirs: + shutil.rmtree(path, ignore_errors=True) + + cls.TEMPDIRS = None + def __init__(self, custom_path=None): if custom_path: os.mkdir(custom_path) self.temp_dir = custom_path else: self.temp_dir = tempfile.mkdtemp() - self._rmtree = shutil.rmtree + + self.TEMPDIRS.add(self.temp_dir) def remove(self): """Remote the tmp dir""" if self.temp_dir: - self._rmtree(self.temp_dir, ignore_errors=True) + shutil.rmtree(self.temp_dir, ignore_errors=True) + self.TEMPDIRS.remove(self.temp_dir) self.temp_dir = None def __del__(self): + temp_dirs = getattr(self, 'TEMPDIRS', None) + if temp_dirs is None: + # Do nothing if the atexit hook has already run. + return + self.remove() def relpath(self, name): @@ -72,6 +92,9 @@ def listdir(self): return os.listdir(self.temp_dir) +atexit.register(TempDirectory.remove_tempdirs) + + def tempdir(custom_path=None): """Create temp dir which deletes the contents when exit.