Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/pr/144'
Browse files Browse the repository at this point in the history
* origin/pr/144:
  Fix flaky qrexec agent tests
  Check for dom0 messages in more agent tests
  • Loading branch information
marmarek committed Apr 14, 2024
2 parents f3a5784 + dfd804f commit daee92e
Showing 1 changed file with 57 additions and 170 deletions.
227 changes: 57 additions & 170 deletions qrexec/tests/socket/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,30 @@ class TestAgentBase(unittest.TestCase):
target_domain = 43
target_port = 1024

def check_dom0(self, dom0):
self.assertEqual(
dom0.recv_message(),
(
qrexec.MSG_CONNECTION_TERMINATED,
struct.pack("<LL", self.target_domain, self.target_port),
),
)

def assertExpectedStdout(self, target, expected_stdout: bytes, *, exit_code=0):
messages = util.sort_messages(target.recv_all_messages())
self.assertListEqual(messages[-3:],
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", exit_code))
])
stdout_entries = []
for msg_type, msg_body in messages[:-3]:
# messages before last are not empty, hence truthy
self.assertTrue(msg_body)
self.assertEqual(msg_type, qrexec.MSG_DATA_STDOUT)
stdout_entries.append(msg_body)

def setUp(self):
self.tempdir = tempfile.mkdtemp()
os.mkdir(os.path.join(self.tempdir, "local-rpc"))
Expand Down Expand Up @@ -157,14 +181,7 @@ def test_just_exec(self):
lambda: os.path.exists(os.path.join(self.tempdir, "new_file")),
"file created",
)

self.assertEqual(
dom0.recv_message(),
(
qrexec.MSG_CONNECTION_TERMINATED,
struct.pack("<LL", self.target_domain, self.target_port),
),
)
self.check_dom0(dom0)

def test_exec_cmdline(self):
self.start_agent()
Expand All @@ -186,24 +203,8 @@ def test_exec_cmdline(self):

target.send_message(qrexec.MSG_DATA_STDIN, b"")

messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"Hello world\n"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)

self.assertEqual(
dom0.recv_message(),
(
qrexec.MSG_CONNECTION_TERMINATED,
struct.pack("<LL", self.target_domain, self.target_port),
),
)
self.assertExpectedStdout(target, b"Hello world\n")
self.check_dom0(dom0)

def test_trigger_service(self):
self.start_agent()
Expand All @@ -229,13 +230,7 @@ def test_trigger_service(self):
)

client.close()
self.assertEqual(
dom0.recv_message(),
(
qrexec.MSG_CONNECTION_TERMINATED,
struct.pack("<LL", self.target_domain, self.target_port),
),
)
self.check_dom0(dom0)

def test_trigger_service_refused(self):
self.start_agent()
Expand Down Expand Up @@ -310,15 +305,6 @@ def execute_qubesrpc(self, service: str, src_domain_name: str):
target.handshake()
return target, dom0

def check_dom0(self, dom0):
self.assertEqual(
dom0.recv_message(),
(
qrexec.MSG_CONNECTION_TERMINATED,
struct.pack("<LL", self.target_domain, self.target_port),
),
)

def make_executable_service(self, *args):
util.make_executable_service(self.tempdir, *args)

Expand All @@ -332,18 +318,10 @@ def test_exec_service(self):
echo "arg: $1, remote domain: $QREXEC_REMOTE_DOMAIN"
""",
)
target, _ = self.execute_qubesrpc("qubes.Service+arg", "domX")
target, dom0 = self.execute_qubesrpc("qubes.Service+arg", "domX")
target.send_message(qrexec.MSG_DATA_STDIN, b"")
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"arg: arg, remote domain: domX\n"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"arg: arg, remote domain: domX\n")
self.check_dom0(dom0)

def test_exec_service_keyword(self):
util.make_executable_service(
Expand All @@ -361,20 +339,11 @@ def test_exec_service_keyword(self):
)
target, dom0 = self.execute_qubesrpc("qubes.Service", "domX")
target.send_message(qrexec.MSG_DATA_STDIN, b"")
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"""arg: , remote domain: domX
self.assertExpectedStdout(target, b"""arg: , remote domain: domX
target name: NONAME
target keyword: NOKEYWORD
target type: ''
"""),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
""")
self.check_dom0(dom0)

def test_exec_service_with_config(self):
Expand All @@ -395,16 +364,7 @@ def test_exec_service_with_config(self):
""")
target, dom0 = self.execute_qubesrpc("qubes.Service+arg", "domX")
target.send_message(qrexec.MSG_DATA_STDIN, b"")
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"arg: arg, remote domain: domX\n"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"arg: arg, remote domain: domX\n")
self.check_dom0(dom0)

def test_wait_for_session(self):
Expand Down Expand Up @@ -468,20 +428,9 @@ def _test_wait_for_session(self, config_name, service_name="qubes.Service", argu
# Do not send EOF. Shell read doesn't need it, and this checks that
# qrexec does not wait for EOF on stdin before sending the exit code
# from the remote process.
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(
qrexec.MSG_DATA_STDOUT,
b"arg: " + argument.encode("ascii", "strict")
+ b", remote domain: domX, input: stdin data\n",
),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
expected_stdout = (b"arg: " + argument.encode("ascii", "strict")
+ b", remote domain: domX, input: stdin data\n")
self.assertExpectedStdout(target, expected_stdout)
self.check_dom0(dom0)

def test_exec_service_fail(self):
Expand Down Expand Up @@ -618,16 +567,7 @@ def test_exec_null_argument_finds_service_for_empty_argument(self):
)
target, dom0 = self.execute_qubesrpc("qubes.Service", "domX")
target.send_message(qrexec.MSG_DATA_STDIN, b"")
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"specific service: qubes.Service\n"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"specific service: qubes.Service\n")
self.check_dom0(dom0)

def test_socket_null_argument_finds_service_for_empty_argument(self):
Expand Down Expand Up @@ -793,10 +733,10 @@ def execute(self, cmd: str):

target = self.connect_target()
target.handshake()
return target
return target, dom0

def test_stdin_stderr(self):
target = self.execute('echo "stdout"; echo "stderr" >&2')
target, dom0 = self.execute('echo "stdout"; echo "stderr" >&2')
target.send_message(qrexec.MSG_DATA_STDIN, b"")

messages = target.recv_all_messages()
Expand All @@ -812,7 +752,7 @@ def test_stdin_stderr(self):
)

def test_pass_stdin(self):
target = self.execute("cat")
target, dom0 = self.execute("cat")

target.send_message(qrexec.MSG_DATA_STDIN, b"data 1")
self.assertEqual(
Expand All @@ -825,24 +765,16 @@ def test_pass_stdin(self):
)

target.send_message(qrexec.MSG_DATA_STDIN, b"")
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"")

def test_close_stdin_early(self):
# Make sure that we cover the error on writing stdin into living
# process.
target = self.execute(
target, dom0 = self.execute(
"""
read
exec <&-
echo closed stdin
echo "closed stdin"
sleep 1
"""
)
Expand All @@ -853,15 +785,8 @@ def test_close_stdin_early(self):
target.send_message(qrexec.MSG_DATA_STDIN, b"data 2\n")
target.send_message(qrexec.MSG_DATA_STDIN, b"")

messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"")
self.check_dom0(dom0)

def test_buffer_stdin(self):
# Test to trigger WRITE_STDIN_BUFFERED.
Expand All @@ -877,7 +802,7 @@ def test_buffer_stdin(self):

fifo = os.path.join(self.tempdir, "fifo")
os.mkfifo(fifo)
target = self.execute("read <{}; cat".format(fifo))
target, dom0 = self.execute("read <{}; cat".format(fifo))

for i in range(0, data_size, msg_size):
msg = data[i : i + msg_size]
Expand All @@ -888,32 +813,11 @@ def test_buffer_stdin(self):
with open(fifo, "a") as f:
f.write("end\n")
f.flush()

messages = []
received_data = b""
while len(received_data) < data_size:
message_type, message = target.recv_message()
if message_type != qrexec.MSG_DATA_STDOUT:
messages.append((message_type, message))
else:
self.assertEqual(message_type, qrexec.MSG_DATA_STDOUT)
received_data += message

self.assertEqual(len(received_data), data_size)
self.assertEqual(received_data, data)

messages += target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, data)
self.check_dom0(dom0)

def test_close_stdout_stderr_early(self):
target = self.execute(
target, dom0 = self.execute(
"""\
read
echo closing stdout
Expand Down Expand Up @@ -946,9 +850,10 @@ def test_close_stdout_stderr_early(self):
target.recv_message(),
(qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", 42)),
)
self.check_dom0(dom0)

def test_stdio_socket(self):
target = self.execute(
target, dom0 = self.execute(
"""\
kill -USR1 $QREXEC_AGENT_PID
echo hello world >&0
Expand All @@ -963,21 +868,13 @@ def test_stdio_socket(self):
target.send_message(qrexec.MSG_DATA_STDIN, b"stdin\n")
target.send_message(qrexec.MSG_DATA_STDIN, b"")

messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b"received: stdin\n"),
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.assertExpectedStdout(target, b"received: stdin\n")
self.check_dom0(dom0)

def test_exit_before_closing_streams(self):
fifo = os.path.join(self.tempdir, "fifo")
os.mkfifo(fifo)
target = self.execute(
target, dom0 = self.execute(
"""\
# duplicate original stdin to fd 3, because bash will
# close original stdin in child process
Expand Down Expand Up @@ -1009,18 +906,8 @@ def test_exit_before_closing_streams(self):
with open(fifo, "a") as f:
f.write("end\n")
f.flush()
self.assertEqual(
target.recv_message(), (qrexec.MSG_DATA_STDOUT, b"child exiting\n")
)
messages = target.recv_all_messages()
self.assertListEqual(
util.sort_messages(messages),
[
(qrexec.MSG_DATA_STDOUT, b""),
(qrexec.MSG_DATA_STDERR, b""),
(qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", 42)),
],
)
self.assertExpectedStdout(target, b"child exiting\n", exit_code=42)
self.check_dom0(dom0)


@unittest.skipIf(os.environ.get("SKIP_SOCKET_TESTS"), "socket tests not set up")
Expand Down

0 comments on commit daee92e

Please sign in to comment.