Skip to content

Commit

Permalink
Fix flaky qrexec agent tests
Browse files Browse the repository at this point in the history
Various tests assumed that qrexec would deliver stdout in a single
message.  Qrexec does not make this guarantee: calls to write(2),
send(2), sendmsg(2), etc are not guaranteed to correspond 1-to-1
to MSG_DATA_STDOUT messages on the vchan.  This caused
https://gitlab.com/QubesOS/qubes-core-qrexec/-/jobs/6616564043 to
wrongly fail, even thouggh the code is correct.

Fix this problem by concatenating the payloads of all stdout messages
into a single bytes object before comparing with the expected stdout
value.  Also add a utility function for this, saving a lot of code in
tests.
  • Loading branch information
DemiMarie committed Apr 13, 2024
1 parent edc80de commit dfd804f
Showing 1 changed file with 29 additions and 129 deletions.
158 changes: 29 additions & 129 deletions qrexec/tests/socket/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@ def check_dom0(self, dom0):
),
)

def assertExpectedStdout(self, target, expected_stdout: bytes, *, exit_code=0):
messages = util.sort_messages(target.recv_all_messages())
self.assertListEqual(messages[-3:],
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", exit_code))
])
stdout_entries = []
for msg_type, msg_body in messages[:-3]:
# messages before last are not empty, hence truthy
self.assertTrue(msg_body)
self.assertEqual(msg_type, qrexec.MSG_DATA_STDOUT)
stdout_entries.append(msg_body)

def setUp(self):
self.tempdir = tempfile.mkdtemp()
os.mkdir(os.path.join(self.tempdir, "local-rpc"))
Expand Down Expand Up @@ -188,16 +203,7 @@ def test_exec_cmdline(self):

target.send_message(qrexec.MSG_DATA_STDIN, b"")

messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"Hello world\n"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"Hello world\n")
self.check_dom0(dom0)

def test_trigger_service(self):
Expand Down Expand Up @@ -314,16 +320,7 @@ def test_exec_service(self):
)
target, dom0 = self.execute_qubesrpc("qubes.Service+arg", "domX")
target.send_message(qrexec.MSG_DATA_STDIN, b"")
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"arg: arg, remote domain: domX\n"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"arg: arg, remote domain: domX\n")
self.check_dom0(dom0)

def test_exec_service_keyword(self):
Expand All @@ -342,20 +339,11 @@ def test_exec_service_keyword(self):
)
target, dom0 = self.execute_qubesrpc("qubes.Service", "domX")
target.send_message(qrexec.MSG_DATA_STDIN, b"")
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"""arg: , remote domain: domX
self.assertExpectedStdout(target, b"""arg: , remote domain: domX
target name: NONAME
target keyword: NOKEYWORD
target type: ''
"""),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
""")
self.check_dom0(dom0)

def test_exec_service_with_config(self):
Expand All @@ -376,16 +364,7 @@ def test_exec_service_with_config(self):
""")
target, dom0 = self.execute_qubesrpc("qubes.Service+arg", "domX")
target.send_message(qrexec.MSG_DATA_STDIN, b"")
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"arg: arg, remote domain: domX\n"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"arg: arg, remote domain: domX\n")
self.check_dom0(dom0)

def test_wait_for_session(self):
Expand Down Expand Up @@ -449,20 +428,9 @@ def _test_wait_for_session(self, config_name, service_name="qubes.Service", argu
# Do not send EOF. Shell read doesn't need it, and this checks that
# qrexec does not wait for EOF on stdin before sending the exit code
# from the remote process.
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(
qrexec.MSG_DATA_STDOUT,
b"arg: " + argument.encode("ascii", "strict")
+ b", remote domain: domX, input: stdin data\n",
),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
expected_stdout = (b"arg: " + argument.encode("ascii", "strict")
+ b", remote domain: domX, input: stdin data\n")
self.assertExpectedStdout(target, expected_stdout)
self.check_dom0(dom0)

def test_exec_service_fail(self):
Expand Down Expand Up @@ -599,16 +567,7 @@ def test_exec_null_argument_finds_service_for_empty_argument(self):
)
target, dom0 = self.execute_qubesrpc("qubes.Service", "domX")
target.send_message(qrexec.MSG_DATA_STDIN, b"")
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"specific service: qubes.Service\n"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"specific service: qubes.Service\n")
self.check_dom0(dom0)

def test_socket_null_argument_finds_service_for_empty_argument(self):
Expand Down Expand Up @@ -806,16 +765,7 @@ def test_pass_stdin(self):
)

target.send_message(qrexec.MSG_DATA_STDIN, b"")
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.check_dom0(dom0)
self.assertExpectedStdout(target, b"")

def test_close_stdin_early(self):
# Make sure that we cover the error on writing stdin into living
Expand All @@ -835,15 +785,7 @@ def test_close_stdin_early(self):
target.send_message(qrexec.MSG_DATA_STDIN, b"data 2\n")
target.send_message(qrexec.MSG_DATA_STDIN, b"")

messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"")
self.check_dom0(dom0)

def test_buffer_stdin(self):
Expand Down Expand Up @@ -871,29 +813,7 @@ def test_buffer_stdin(self):
with open(fifo, "a") as f:
f.write("end\n")
f.flush()

messages = []
received_data = b""
while len(received_data) < data_size:
message_type, message = target.recv_message()
if message_type != qrexec.MSG_DATA_STDOUT:
messages.append((message_type, message))
else:
self.assertEqual(message_type, qrexec.MSG_DATA_STDOUT)
received_data += message

self.assertEqual(len(received_data), data_size)
self.assertEqual(received_data, data)

messages += target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, data)
self.check_dom0(dom0)

def test_close_stdout_stderr_early(self):
Expand Down Expand Up @@ -948,16 +868,7 @@ def test_stdio_socket(self):
target.send_message(qrexec.MSG_DATA_STDIN, b"stdin\n")
target.send_message(qrexec.MSG_DATA_STDIN, b"")

messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"received: stdin\n"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"received: stdin\n")
self.check_dom0(dom0)

def test_exit_before_closing_streams(self):
Expand Down Expand Up @@ -995,18 +906,7 @@ def test_exit_before_closing_streams(self):
with open(fifo, "a") as f:
f.write("end\n")
f.flush()
self.assertEqual(
target.recv_message(), (qrexec.MSG_DATA_STDOUT, b"child exiting\n")
)
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", 42)),
],
)
self.assertExpectedStdout(target, b"child exiting\n", exit_code=42)
self.check_dom0(dom0)


Expand Down

0 comments on commit dfd804f

Please sign in to comment.