Skip to content

Commit

Permalink
fix: crash when host has multiple NVIDIA GPUs (#435)
Browse files Browse the repository at this point in the history
Signed-off-by: Josh Usiskin <[email protected]>
  • Loading branch information
jusiskin authored Oct 9, 2024
1 parent f1580b1 commit 760118c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
20 changes: 14 additions & 6 deletions src/deadline_worker_agent/startup/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _get_gpu_count(*, verbose: bool = True) -> int:
"""
try:
output = subprocess.check_output(
["nvidia-smi", "--query-gpu=count", "--format=csv,noheader"]
["nvidia-smi", "--query-gpu=count", "-i=0", "--format=csv,noheader"]
)
except FileNotFoundError:
if verbose:
Expand Down Expand Up @@ -110,7 +110,7 @@ def _get_gpu_memory(*, verbose: bool = True) -> int:
The total GPU memory available on the machine.
"""
try:
output = subprocess.check_output(
output_bytes = subprocess.check_output(
["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader"]
)
except FileNotFoundError:
Expand All @@ -131,10 +131,18 @@ def _get_gpu_memory(*, verbose: bool = True) -> int:
if verbose:
_logger.warning("Could not detect GPU memory, unexpected error running nvidia-smi")
return 0
else:
if verbose:
_logger.info("Total GPU Memory: %s", output.decode().strip())
return int(output.decode().strip().replace("MiB", ""))
output = output_bytes.decode().strip()

mem_per_gpu: list[int] = []
for line in output.splitlines():
mem_mib = int(line.replace("MiB", ""))
mem_per_gpu.append(mem_mib)

min_memory = min(mem_per_gpu)

if verbose:
_logger.info("Minimum total memory of all GPUs: %s", min_memory)
return min_memory


def capability_type(capability_name_str: str) -> Literal["amount", "attr"]:
Expand Down
25 changes: 23 additions & 2 deletions test/unit/startup/test_capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def test_get_gpu_count(

# THEN
check_output_mock.assert_called_once_with(
["nvidia-smi", "--query-gpu=count", "--format=csv,noheader"]
["nvidia-smi", "--query-gpu=count", "-i=0", "--format=csv,noheader"]
)
assert result == 2

Expand Down Expand Up @@ -274,7 +274,7 @@ def test_get_gpu_count_nvidia_smi_error(

# THEN
check_output_mock.assert_called_once_with(
["nvidia-smi", "--query-gpu=count", "--format=csv,noheader"]
["nvidia-smi", "--query-gpu=count", "-i=0", "--format=csv,noheader"]
)

assert result == expected_result
Expand All @@ -301,6 +301,27 @@ def test_get_gpu_memory(
)
assert result == 6800

@patch.object(capabilities_mod.subprocess, "check_output")
def test_get_multi_gpu_memory(
    self,
    check_output_mock: MagicMock,
) -> None:
    """
    Verify that _get_gpu_memory returns the smallest per-GPU total memory
    when nvidia-smi reports more than one GPU on the host.
    """
    # GIVEN nvidia-smi output listing two GPUs, one line per GPU.
    # NOTE(review): the second line has no space before "MiB" — presumably
    # exercising tolerant parsing of both spacing forms; confirm intent.
    nvidia_smi_output = b"6800 MiB\n1200MiB"
    check_output_mock.return_value = nvidia_smi_output

    # WHEN
    gpu_memory = capabilities_mod._get_gpu_memory()

    # THEN
    # nvidia-smi is queried once for the total memory of every GPU
    check_output_mock.assert_called_once_with(
        ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader"]
    )
    # the minimum across both GPUs (1200 < 6800) is reported
    assert gpu_memory == 1200

@pytest.mark.parametrize(
("exception", "expected_result"),
(
Expand Down

0 comments on commit 760118c

Please sign in to comment.