diff --git a/tests/conftest.py b/tests/conftest.py index 17ffd281c..59146963d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import gc + import pytest import torch @@ -20,6 +22,13 @@ def pytest_runtest_call(item): raise +@pytest.hookimpl(trylast=True) +def pytest_runtest_teardown(item, nextitem): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + @pytest.fixture(scope="session") def requires_cuda() -> bool: cuda_available = torch.cuda.is_available()