Skip to content

Commit

Permalink
fix: mock HostMetricsLogger in entrypoint tests (#70)
Browse files: browse the repository at this point in the history
Signed-off-by: Jericho Tolentino <[email protected]>
Signed-off-by: Graeme McHale <[email protected]>
  • Loading branch information
jericht authored and gmchale79 committed Nov 2, 2023
1 parent 30a6a1c commit 1eed0fe
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 95 deletions.
111 changes: 61 additions & 50 deletions src/deadline_worker_agent/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from __future__ import annotations

from logging import Logger
from logging import Logger, getLogger
from threading import Timer
from typing import Any

import os
import psutil

module_logger = getLogger(__name__)


class HostMetricsLogger:
"""Context manager that regularly logs host metrics"""
Expand Down Expand Up @@ -39,58 +41,67 @@ def log_metrics(self):
Queries information about the host machine and logs the information as a space-delimited
line of the form: <label> <value> ...
"""
memory = psutil.virtual_memory()
swap = psutil.swap_memory()
disk = psutil.disk_usage(os.sep)

# On Windows it may be necessary to issue diskperf -y command from cmd.exe first in order to enable IO counters
disk_counters = psutil.disk_io_counters(nowrap=True)
if disk_counters is None:
disk_read = disk_write = "NOT_AVAILABLE"
elif not (hasattr(disk_counters, "read_bytes") and hasattr(disk_counters, "write_bytes")):
# TODO: Support disk speed on NetBSD and OpenBSD
disk_read = disk_write = "NOT_SUPPORTED"
try:
cpu_percent = psutil.cpu_percent()
memory = psutil.virtual_memory()
swap = psutil.swap_memory()
disk = psutil.disk_usage(os.sep)
disk_counters = psutil.disk_io_counters(nowrap=True)
network = psutil.net_io_counters(nowrap=True)
except Exception as e:
module_logger.warning(
f"Failed to get host metrics. Skipping host metrics log message. Error: {e}"
)
else:
disk_read = str(round(disk_counters.read_bytes / self.interval_s))
disk_write = str(round(disk_counters.write_bytes / self.interval_s))
# On Windows it may be necessary to issue diskperf -y command from cmd.exe first in order to enable IO counters
if disk_counters is None:
disk_read = disk_write = "NOT_AVAILABLE"
elif not (
hasattr(disk_counters, "read_bytes") and hasattr(disk_counters, "write_bytes")
):
# TODO: Support disk speed on NetBSD and OpenBSD
disk_read = disk_write = "NOT_SUPPORTED"
else:
disk_read = str(round(disk_counters.read_bytes / self.interval_s))
disk_write = str(round(disk_counters.write_bytes / self.interval_s))

# We need to poll network IO to get rate
network = psutil.net_io_counters(nowrap=True)
if network is None:
network_sent = network_recv = "NOT_AVAILABLE"
else:
if self._prev_network:
network_sent_bps = round(
(network.bytes_sent - self._prev_network.bytes_sent) / self.interval_s
)
network_recv_bps = round(
(network.bytes_recv - self._prev_network.bytes_recv) / self.interval_s
)
# We need to poll network IO to get rate
if network is None:
network_sent = network_recv = "NOT_AVAILABLE"
else:
network_sent_bps = network_recv_bps = 0
network_sent = str(network_sent_bps)
network_recv = str(network_recv_bps)
self._prev_network = network

stats = {
"cpu-usage-percent": str(psutil.cpu_percent()),
"memory-total-bytes": str(memory.total),
"memory-used-bytes": str(memory.total - memory.available),
"memory-used-percent": str(memory.percent),
"swap-used-bytes": str(swap.used),
"total-disk-bytes": str(disk.total),
"total-disk-used-bytes": str(disk.used),
"total-disk-used-percent": str(round(disk.used / disk.total, ndigits=1)),
"user-disk-available-bytes": str(disk.free),
"network-sent-bytes-per-second": network_sent,
"network-recv-bytes-per-second": network_recv,
"disk-read-bytes-per-second": disk_read,
"disk-write-bytes-per-second": disk_write,
}

# Output as space-delimited "key value" pairs for consumption by Cloudwatch to use as metrics
self.logger.info(" ".join(" ".join(kvp) for kvp in stats.items()))
self._set_timer()
if self._prev_network:
network_sent_bps = round(
(network.bytes_sent - self._prev_network.bytes_sent) / self.interval_s
)
network_recv_bps = round(
(network.bytes_recv - self._prev_network.bytes_recv) / self.interval_s
)
else:
network_sent_bps = network_recv_bps = 0
network_sent = str(network_sent_bps)
network_recv = str(network_recv_bps)
self._prev_network = network

stats = {
"cpu-usage-percent": str(cpu_percent),
"memory-total-bytes": str(memory.total),
"memory-used-bytes": str(memory.total - memory.available),
"memory-used-percent": str(memory.percent),
"swap-used-bytes": str(swap.used),
"total-disk-bytes": str(disk.total),
"total-disk-used-bytes": str(disk.used),
"total-disk-used-percent": str(round(disk.used / disk.total, ndigits=1)),
"user-disk-available-bytes": str(disk.free),
"network-sent-bytes-per-second": network_sent,
"network-recv-bytes-per-second": network_recv,
"disk-read-bytes-per-second": disk_read,
"disk-write-bytes-per-second": disk_write,
}

# Output as space-delimited "key value" pairs for consumption by Cloudwatch to use as metrics
self.logger.info(" ".join(" ".join(kvp) for kvp in stats.items()))
finally:
self._set_timer()

def _set_timer(self) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/deadline_worker_agent/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
worker_persistence_dir: Path,
worker_logs_dir: Path | None,
host_metrics_logging: bool,
host_metrics_logging_interval_seconds: float | None,
host_metrics_logging_interval_seconds: float | None = None,
) -> None:
self._deadline_client = deadline_client
self._s3_client = s3_client
Expand Down
91 changes: 51 additions & 40 deletions test/unit/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
import deadline_worker_agent.metrics as metrics_mod


@pytest.fixture(autouse=True)
def mock_psutil_module() -> Generator[MagicMock, None, None]:
    """Replace ``psutil`` in the metrics module with a mock for every test.

    The fixture is autouse, so no test in this module calls the real psutil.
    This guards against errors such as a KeyError raised from
    psutil.virtual_memory().

    Yields:
        The mock installed in place of ``metrics_mod.psutil``.
    """
    psutil_patcher = patch.object(metrics_mod, "psutil")
    with psutil_patcher as mocked_psutil:
        yield mocked_psutil


class TestHostMetricsLogger:
BYTES_PATTERN = r"[0-9]+(?:\.[0-9]+)?"
PERCENT_PATTERN = r"[0-9]{1,3}(?:\.[0-9]+)?"
Expand All @@ -28,11 +35,11 @@ def host_metrics_logger(self, logger: MagicMock) -> HostMetricsLogger:

def test_enter(self, host_metrics_logger: HostMetricsLogger):
# GIVEN
with patch.object(host_metrics_logger, "_set_timer") as mock_set_timer:
with patch.object(host_metrics_logger, "log_metrics") as mock_log_metrics:
# WHEN
with host_metrics_logger:
# THEN
mock_set_timer.assert_called_once()
mock_log_metrics.assert_called_once()

@pytest.mark.parametrize("timer_exists", [True, False])
def test_exit(
Expand All @@ -44,9 +51,10 @@ def test_exit(
timer = MagicMock()

# WHEN
with host_metrics_logger:
if timer_exists:
host_metrics_logger._timer = timer
with patch.object(host_metrics_logger, "__enter__"):
with host_metrics_logger:
if timer_exists:
host_metrics_logger._timer = timer

# THEN
if timer_exists:
Expand Down Expand Up @@ -140,7 +148,7 @@ def disk_io_counters(self) -> tuple:
)
return dioc(123, 321, 123123, 321321, 100, 200)

@pytest.fixture
@pytest.fixture(autouse=True)
def mock_psutil(
self,
virtual_memory: tuple,
Expand Down Expand Up @@ -276,42 +284,45 @@ def test_network_rate_not_available(
assert re.search(r"network-sent-bytes-per-second NOT_AVAILABLE", log_line)
assert re.search(r"network-recv-bytes-per-second NOT_AVAILABLE", log_line)

def test_log_metrics_correct_encoding(self, caplog: pytest.LogCaptureFixture) -> None:
# GIVEN
DECIMAL_NUMBER_PATTERN = r"\d+(?:\.\d+)?"
EXPECTED_LOG_MESSAGE_PATTERN = " ".join(
# fmt: off
[
"cpu-usage-percent", DECIMAL_NUMBER_PATTERN,
"memory-total-bytes", DECIMAL_NUMBER_PATTERN,
"memory-used-bytes", DECIMAL_NUMBER_PATTERN,
"memory-used-percent", DECIMAL_NUMBER_PATTERN,
"swap-used-bytes", DECIMAL_NUMBER_PATTERN,
"total-disk-bytes", DECIMAL_NUMBER_PATTERN,
"total-disk-used-bytes", DECIMAL_NUMBER_PATTERN,
"total-disk-used-percent", DECIMAL_NUMBER_PATTERN,
"user-disk-available-bytes", DECIMAL_NUMBER_PATTERN,
"network-sent-bytes-per-second", rf"(?:{DECIMAL_NUMBER_PATTERN}|NOT_AVAILABLE)",
"network-recv-bytes-per-second", rf"(?:{DECIMAL_NUMBER_PATTERN}|NOT_AVAILABLE)",
"disk-read-bytes-per-second", rf"(?:{DECIMAL_NUMBER_PATTERN}|NOT_AVAILABLE|NOT_SUPPORTED)",
"disk-write-bytes-per-second", rf"(?:{DECIMAL_NUMBER_PATTERN}|NOT_AVAILABLE|NOT_SUPPORTED)",
]
# fmt: on
)
logger = logging.getLogger(__name__)
caplog.set_level(0, logger.name)
host_metrics_logger = HostMetricsLogger(logger=logger, interval_s=1)
def test_log_metrics_correct_encoding(
self,
caplog: pytest.LogCaptureFixture,
) -> None:
# GIVEN
DECIMAL_NUMBER_PATTERN = r"\d+(?:\.\d+)?"
EXPECTED_LOG_MESSAGE_PATTERN = " ".join(
# fmt: off
[
"cpu-usage-percent", DECIMAL_NUMBER_PATTERN,
"memory-total-bytes", DECIMAL_NUMBER_PATTERN,
"memory-used-bytes", DECIMAL_NUMBER_PATTERN,
"memory-used-percent", DECIMAL_NUMBER_PATTERN,
"swap-used-bytes", DECIMAL_NUMBER_PATTERN,
"total-disk-bytes", DECIMAL_NUMBER_PATTERN,
"total-disk-used-bytes", DECIMAL_NUMBER_PATTERN,
"total-disk-used-percent", DECIMAL_NUMBER_PATTERN,
"user-disk-available-bytes", DECIMAL_NUMBER_PATTERN,
"network-sent-bytes-per-second", rf"(?:{DECIMAL_NUMBER_PATTERN}|NOT_AVAILABLE)",
"network-recv-bytes-per-second", rf"(?:{DECIMAL_NUMBER_PATTERN}|NOT_AVAILABLE)",
"disk-read-bytes-per-second", rf"(?:{DECIMAL_NUMBER_PATTERN}|NOT_AVAILABLE|NOT_SUPPORTED)",
"disk-write-bytes-per-second", rf"(?:{DECIMAL_NUMBER_PATTERN}|NOT_AVAILABLE|NOT_SUPPORTED)",
]
# fmt: on
)
logger = logging.getLogger(__name__)
caplog.set_level(0, logger.name)
host_metrics_logger = HostMetricsLogger(logger=logger, interval_s=1)

# WHEN
with (
# We don't want to actually create/start a timer
patch.object(metrics_mod, "Timer"),
):
host_metrics_logger.log_metrics()
# WHEN
with (
# We don't want to actually create/start a timer
patch.object(metrics_mod, "Timer"),
):
host_metrics_logger.log_metrics()

# THEN
assert len(caplog.messages) == 1
assert re.match(EXPECTED_LOG_MESSAGE_PATTERN, caplog.messages[0])
# THEN
assert len(caplog.messages) == 1
assert re.match(EXPECTED_LOG_MESSAGE_PATTERN, caplog.messages[0])


def get_first_and_only_call_arg(mock: MagicMock) -> Any:
Expand Down
7 changes: 3 additions & 4 deletions test/unit/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def worker(
# This is unused, but declaring it as a dependency fixture ensures we mock the scheduler class
# before we instantiate the Worker instance within this fixture body
mock_scheduler_cls: MagicMock,
) -> Worker:
) -> Generator[Worker, None, None]:
with patch.object(worker_mod, "HostMetricsLogger"):
return Worker(
yield Worker(
farm_id=farm_id,
deadline_client=client,
boto_session=boto_session,
Expand All @@ -70,8 +70,7 @@ def worker(
cleanup_session_user_processes=True,
worker_persistence_dir=Path("/var/lib/deadline"),
worker_logs_dir=worker_logs_dir,
host_metrics_logging=True,
host_metrics_logging_interval_seconds=60,
host_metrics_logging=False,
)


Expand Down

0 comments on commit 1eed0fe

Please sign in to comment.