[Misc] Add pytest marker to opt-out of global test cleanup #3863

Merged 1 commit on Apr 5, 2024
6 changes: 5 additions & 1 deletion tests/conftest.py
@@ -56,11 +56,15 @@ def cleanup():


 @pytest.fixture()
-def should_do_global_cleanup_after_test() -> bool:
+def should_do_global_cleanup_after_test(request) -> bool:
     """Allow subdirectories to skip global cleanup by overriding this fixture.
     This can provide a ~10x speedup for non-GPU unit tests since they don't need
     to initialize torch.
     """
+
+    if request.node.get_closest_marker("skip_global_cleanup"):
+        return False
+
     return True


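The marker complements the mechanism already noted in the fixture's docstring: a subdirectory can still opt out of cleanup wholesale by shadowing the fixture in its own conftest.py. A minimal sketch of that pattern, with a hypothetical directory name:

# tests/some_subdir/conftest.py (hypothetical path)
import pytest


@pytest.fixture()
def should_do_global_cleanup_after_test() -> bool:
    # Every test collected under this directory skips the global cleanup.
    return False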
3 changes: 3 additions & 0 deletions tests/spec_decode/test_batch_expansion.py
@@ -7,6 +7,7 @@


 @pytest.mark.parametrize('num_target_seq_ids', [100])
+@pytest.mark.skip_global_cleanup
 def test_create_target_seq_id_iterator(num_target_seq_ids: int):
     """Verify all new sequence ids are greater than all input
     seq ids.
@@ -27,6 +28,7 @@ def test_create_target_seq_id_iterator(num_target_seq_ids: int):


 @pytest.mark.parametrize('k', [1, 2, 6])
+@pytest.mark.skip_global_cleanup
 def test_get_token_ids_to_score(k: int):
     """Verify correct tokens are selected for scoring.
     """
@@ -53,6 +55,7 @@ def test_get_token_ids_to_score(k: int):


 @pytest.mark.parametrize('k', [1, 2, 6])
+@pytest.mark.skip_global_cleanup
 def test_create_single_target_seq_group_metadata(k: int):
     """Verify correct creation of a batch-expanded seq group metadata.
     """
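A registration step for the new marker is not shown in this diff; without one, recent pytest versions emit PytestUnknownMarkWarning for unregistered marks. A hedged sketch of one conventional fix, which would live in tests/conftest.py or the project's pytest configuration (an assumption, not part of this change):

def pytest_configure(config):
    # Register the custom marker so pytest does not warn about it.
    config.addinivalue_line(
        "markers",
        "skip_global_cleanup: skip the global cleanup fixture after this test")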
6 changes: 3 additions & 3 deletions tests/spec_decode/test_spec_decode_worker.py
@@ -487,7 +487,7 @@ def test_empty_input_batch(k: int, batch_size: int):
 **execute_model_data.to_dict())


-@torch.inference_mode()
+@pytest.mark.skip_global_cleanup
 def test_init_device():
     """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
     well as other GPU initialization.
@@ -537,7 +537,7 @@ def test_init_cache_engine():
 @pytest.mark.parametrize('available_cpu_blocks', [500])
 @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
 @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
-@torch.inference_mode()
+@pytest.mark.skip_global_cleanup
 def test_profile_num_available_blocks(available_gpu_blocks: int,
                                       available_cpu_blocks: int,
                                       target_cache_block_size_bytes: int,
@@ -584,7 +584,7 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
 @pytest.mark.parametrize('target_cache_block_size_bytes',
                          [2 * 2 * 4096, 2 * 2 * 8192])
 @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
-@torch.inference_mode()
+@pytest.mark.skip_global_cleanup
 def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
                                        target_cache_block_size_bytes: int,
                                        draft_kv_size_bytes: int):
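Since the fixture looks the marker up with request.node.get_closest_marker(), which searches the test itself and then its enclosing class and module, a whole test module could also opt out with a single pytestmark assignment. A hypothetical sketch:

import pytest

# Applies the marker to every test collected from this module, so
# get_closest_marker("skip_global_cleanup") matches for each of them.
pytestmark = pytest.mark.skip_global_cleanup


def test_pure_python_logic():
    # Runs without triggering the post-test torch/GPU cleanup.
    assert sorted([3, 1, 2]) == [1, 2, 3]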