diff --git a/launcher/sdw_updater_gui/Updater.py b/launcher/sdw_updater_gui/Updater.py index 4dcb51fe..c53c5320 100644 --- a/launcher/sdw_updater_gui/Updater.py +++ b/launcher/sdw_updater_gui/Updater.py @@ -19,7 +19,7 @@ try: import qubesadmin - qubes = qubesadmin.Qubes() + qubes = qubesadmin.Qubes() # pragma: no cover except ImportError: qubes = None @@ -505,13 +505,13 @@ def _force_shutdown_vm(vm): ) return False - return _wait_for_is_running(vm, False) + return _wait_for(vm, (lambda vm: qubes.domains[vm].is_running() is False)) -def _wait_for_is_running(vm, expected, timeout=60, interval=0.2): +def _wait_for(vm, condition, timeout=60, interval=0.2): """ - Poll for a VM to enter the given is_running state, and give up after a - timeout is reached. + Poll for a VM to enter the state using the function condition() + that must return True when the state is reached Return value: - True if the VM reached the expected state @@ -520,19 +520,21 @@ def _wait_for_is_running(vm, expected, timeout=60, interval=0.2): start_time = time.time() stop_time = start_time + timeout while time.time() < stop_time: - state = qubes.domains[vm].is_running() + # Evaluate condition before time measurement to include its runtime + condition_reached = condition(vm) elapsed = time.time() - start_time - if state == expected: + if condition_reached: sdlog.info( - "VM '{}' entered expected state (is_running() is {}) " - "after {:.2f} seconds of polling.".format(vm, expected, elapsed) + "VM '{}' entered expected state after {:.2f} seconds of " + "polling.".format(vm, elapsed) ) + return True time.sleep(interval) sdlog.error( - "VM '{}' did not enter expected state (is_running() is {}) " - "in the provided timeout of {} seconds.".format(vm, expected, timeout) + "VM '{}' did not enter expected state in the provided timeout of " + "{} seconds.".format(vm, timeout) ) return False diff --git a/launcher/tests/test_updater.py b/launcher/tests/test_updater.py index 8e805ce3..f7d23407 100644 --- a/launcher/tests/test_updater.py +++ b/launcher/tests/test_updater.py @@ -1,8 +1,8 @@ import json import os import pytest -import time import subprocess +import re from importlib.machinery import SourceFileLoader from datetime import datetime, timedelta from tempfile import TemporaryDirectory @@ -57,6 +57,13 @@ "sd-viewer": UpdateStatus.UPDATES_REQUIRED, } +VM_POLLING_REGEX_SUCCESS = ( + r"VM '.*' entered expected state after (.*) seconds of polling." +) +VM_POLLING_REGEX_FAILURE = ( + r"VM '.*' did not enter expected state in the provided timeout of (.*) seconds." +) + def test_updater_vms_present(): assert len(updater.current_templates) == 9 @@ -507,20 +514,34 @@ def test_apply_updates_dom0_failure(mocked_info, mocked_error, mocked_call): @mock.patch("Updater.sdlog.error") @mock.patch("Updater.sdlog.info") -def test_vm_polling(mocked_info, mocked_error): - def mock_api(results): - for r in results: - yield r - time.sleep(0.1) +def test_vm_polling_success(mocked_info, mocked_error): + poll_results = mock.MagicMock(side_effect=(False, False, True)) + assert updater._wait_for("sys-net", poll_results, interval=0.1, timeout=1) is True + assert mocked_info.called + info_string = mocked_info.call_args[0][0] + match = re.search(VM_POLLING_REGEX_SUCCESS, info_string) + assert match is not None + elapsed = float(match.group(1)) + # With a sleep interval of 0.1, at least 0.2 seconds should pass before we + # get the expected result. + assert elapsed >= 0.20 + assert not mocked_error.called - with mock.patch("Updater.qubes") as mocked_qubes: - mocked_qubes.domains = {"sys-net": mock.MagicMock()} - mocked_qubes.domains["sys-net"].is_running = mock.MagicMock( - side_effect=mock_api((True, True, False)) - ) - assert updater._wait_for_is_running("sys-net", False, timeout=1) is True - assert mocked_info.called - assert not mocked_error.called + +@mock.patch("Updater.sdlog.error") +@mock.patch("Updater.sdlog.info") +def test_vm_polling_failure(mocked_info, mocked_error): + poll_results = mock.MagicMock(side_effect=(False, False, False)) + assert ( + updater._wait_for("sys-net", poll_results, interval=0.1, timeout=0.3) is False + ) + assert not mocked_info.called + assert mocked_error.called + error_string = mocked_error.call_args[0][0] + match = re.search(VM_POLLING_REGEX_FAILURE, error_string) + assert match is not None + timeout = float(match.group(1)) + assert timeout == 0.30 @pytest.mark.parametrize("vm", current_templates.keys())