Commit eb69d68
[Misc] [CI/Build] Speed up block manager CPU-only unit tests ~10x by opting-out of GPU cleanup (#3783)
cadedaniel authored Apr 2, 2024
1 parent 7d4e1b8 commit eb69d68
Showing 3 changed files with 25 additions and 18 deletions.
14 changes: 12 additions & 2 deletions tests/conftest.py
@@ -55,10 +55,20 @@ def cleanup():
     torch.cuda.empty_cache()
 
 
+@pytest.fixture()
+def should_do_global_cleanup_after_test() -> bool:
+    """Allow subdirectories to skip global cleanup by overriding this fixture.
+    This can provide a ~10x speedup for non-GPU unit tests since they don't need
+    to initialize torch.
+    """
+    return True
+
+
 @pytest.fixture(autouse=True)
-def cleanup_fixture():
+def cleanup_fixture(should_do_global_cleanup_after_test: bool):
     yield
-    cleanup()
+    if should_do_global_cleanup_after_test:
+        cleanup()
 
 
 @pytest.fixture(scope="session")
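
For readers less familiar with pytest, fixtures are resolved by name from the nearest conftest.py outward, so a conftest.py deeper in the tree can shadow a fixture defined at the suite root. A minimal, self-contained sketch of that override mechanism (all names below are hypothetical and not part of the vLLM code):

# conftest.py at the suite root (hypothetical)
import pytest


@pytest.fixture()
def heavy_teardown_enabled() -> bool:
    # Suite-wide default: opt in to the expensive teardown.
    return True


@pytest.fixture(autouse=True)
def maybe_heavy_teardown(heavy_teardown_enabled: bool):
    yield  # the test body runs here
    if heavy_teardown_enabled:
        print("expensive teardown")  # stands in for the real cleanup work

# A subdirectory's conftest.py opts out by redefining the flag fixture:
#
# @pytest.fixture()
# def heavy_teardown_enabled() -> bool:
#     return False

Because the autouse fixture only consumes the flag, the opt-out lives in one conftest.py per directory and individual tests never have to mention it.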
12 changes: 12 additions & 0 deletions tests/core/block/conftest.py
@@ -0,0 +1,12 @@
+import pytest
+
+
+@pytest.fixture()
+def should_do_global_cleanup_after_test() -> bool:
+    """Disable the global cleanup fixture for tests in this directory. This
+    provides a ~10x speedup for unit tests that don't load a model to GPU.
+
+    This requires that tests in this directory clean up after themselves if they
+    use the GPU.
+    """
+    return False
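
As the docstring warns, a test under tests/core/block that does touch the GPU must now release resources itself. A hedged sketch of what such a test could look like; the test name and body are invented for illustration, and only the cleanup import mirrors the real suite:

import torch

from tests.conftest import cleanup  # the suite's shared teardown helper


def test_hypothetical_gpu_path():
    try:
        block = torch.zeros(1024, device="cuda")
        assert block.sum().item() == 0.0
    finally:
        # The autouse cleanup is disabled in this directory, so a
        # GPU-touching test tears down its own state explicitly.
        cleanup()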
17 changes: 1 addition & 16 deletions tests/core/block/e2e/conftest.py
@@ -1,25 +1,10 @@
-import contextlib
-import gc
-
 import pytest
-import ray
-import torch
 
+from tests.conftest import cleanup
 from vllm import LLM
-from vllm.model_executor.parallel_utils.parallel_state import (
-    destroy_model_parallel)
 from vllm.model_executor.utils import set_random_seed
 
 
-def cleanup():
-    destroy_model_parallel()
-    with contextlib.suppress(AssertionError):
-        torch.distributed.destroy_process_group()
-    gc.collect()
-    torch.cuda.empty_cache()
-    ray.shutdown()
-
-
 @pytest.fixture
 def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                            baseline_llm_kwargs, seed):
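
The deleted local cleanup() duplicated the helper in tests/conftest.py, whose tail is visible as context in the first hunk above, so importing it keeps the e2e teardown behavior identical while leaving a single definition. For reference, a sketch of that shared helper reconstructed from the duplicate removed here; the canonical body lives in tests/conftest.py and may differ in detail:

import contextlib
import gc

import ray
import torch

from vllm.model_executor.parallel_utils.parallel_state import (
    destroy_model_parallel)


def cleanup():
    # Tear down model-parallel groups, the torch process group, Python
    # garbage, the CUDA caching allocator, and any Ray workers.
    destroy_model_parallel()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.cuda.empty_cache()
    ray.shutdown()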
