Skip to content

Commit

Permalink
Check for dom0 messages in more agent tests
Browse files Browse the repository at this point in the history
Use a utility function for doing so, instead of open-coding the checks.
Some code goes away.
  • Loading branch information
DemiMarie committed Apr 13, 2024
1 parent 2be9adc commit edc80de
Showing 1 changed file with 29 additions and 42 deletions.
71 changes: 29 additions & 42 deletions qrexec/tests/socket/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ class TestAgentBase(unittest.TestCase):
target_domain = 43
target_port = 1024

def check_dom0(self, dom0):
self.assertEqual(
dom0.recv_message(),
(
qrexec.MSG_CONNECTION_TERMINATED,
struct.pack("<LL", self.target_domain, self.target_port),
),
)

def setUp(self):
self.tempdir = tempfile.mkdtemp()
os.mkdir(os.path.join(self.tempdir, "local-rpc"))
Expand Down Expand Up @@ -157,14 +166,7 @@ def test_just_exec(self):
lambda: os.path.exists(os.path.join(self.tempdir, "new_file")),
"file created",
)

self.assertEqual(
dom0.recv_message(),
(
qrexec.MSG_CONNECTION_TERMINATED,
struct.pack("<LL", self.target_domain, self.target_port),
),
)
self.check_dom0(dom0)

def test_exec_cmdline(self):
self.start_agent()
Expand Down Expand Up @@ -196,14 +198,7 @@ def test_exec_cmdline(self):
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)

self.assertEqual(
dom0.recv_message(),
(
qrexec.MSG_CONNECTION_TERMINATED,
struct.pack("<LL", self.target_domain, self.target_port),
),
)
self.check_dom0(dom0)

def test_trigger_service(self):
self.start_agent()
Expand All @@ -229,13 +224,7 @@ def test_trigger_service(self):
)

client.close()
self.assertEqual(
dom0.recv_message(),
(
qrexec.MSG_CONNECTION_TERMINATED,
struct.pack("<LL", self.target_domain, self.target_port),
),
)
self.check_dom0(dom0)

def test_trigger_service_refused(self):
self.start_agent()
Expand Down Expand Up @@ -310,15 +299,6 @@ def execute_qubesrpc(self, service: str, src_domain_name: str):
target.handshake()
return target, dom0

def check_dom0(self, dom0):
self.assertEqual(
dom0.recv_message(),
(
qrexec.MSG_CONNECTION_TERMINATED,
struct.pack("<LL", self.target_domain, self.target_port),
),
)

def make_executable_service(self, *args):
util.make_executable_service(self.tempdir, *args)

Expand All @@ -332,7 +312,7 @@ def test_exec_service(self):
echo "arg: $1, remote domain: $QREXEC_REMOTE_DOMAIN"
""",
)
target, _ = self.execute_qubesrpc("qubes.Service+arg", "domX")
target, dom0 = self.execute_qubesrpc("qubes.Service+arg", "domX")
target.send_message(qrexec.MSG_DATA_STDIN, b"")
messages = target.recv_all_messages()
self.assertListEqual(
Expand All @@ -344,6 +324,7 @@ def test_exec_service(self):
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.check_dom0(dom0)

def test_exec_service_keyword(self):
util.make_executable_service(
Expand Down Expand Up @@ -793,10 +774,10 @@ def execute(self, cmd: str):

target = self.connect_target()
target.handshake()
return target
return target, dom0

def test_stdin_stderr(self):
target = self.execute('echo "stdout"; echo "stderr" >&2')
target, dom0 = self.execute('echo "stdout"; echo "stderr" >&2')
target.send_message(qrexec.MSG_DATA_STDIN, b"")

messages = target.recv_all_messages()
Expand All @@ -812,7 +793,7 @@ def test_stdin_stderr(self):
)

def test_pass_stdin(self):
target = self.execute("cat")
target, dom0 = self.execute("cat")

target.send_message(qrexec.MSG_DATA_STDIN, b"data 1")
self.assertEqual(
Expand All @@ -834,15 +815,16 @@ def test_pass_stdin(self):
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.check_dom0(dom0)

def test_close_stdin_early(self):
# Make sure that we cover the error on writing stdin into living
# process.
target = self.execute(
target, dom0 = self.execute(
"""
read
exec <&-
echo closed stdin
echo "closed stdin"
sleep 1
"""
)
Expand All @@ -862,6 +844,7 @@ def test_close_stdin_early(self):
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.check_dom0(dom0)

def test_buffer_stdin(self):
# Test to trigger WRITE_STDIN_BUFFERED.
Expand All @@ -877,7 +860,7 @@ def test_buffer_stdin(self):

fifo = os.path.join(self.tempdir, "fifo")
os.mkfifo(fifo)
target = self.execute("read <{}; cat".format(fifo))
target, dom0 = self.execute("read <{}; cat".format(fifo))

for i in range(0, data_size, msg_size):
msg = data[i : i + msg_size]
Expand Down Expand Up @@ -911,9 +894,10 @@ def test_buffer_stdin(self):
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.check_dom0(dom0)

def test_close_stdout_stderr_early(self):
target = self.execute(
target, dom0 = self.execute(
"""\
read
echo closing stdout
Expand Down Expand Up @@ -946,9 +930,10 @@ def test_close_stdout_stderr_early(self):
target.recv_message(),
(qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", 42)),
)
self.check_dom0(dom0)

def test_stdio_socket(self):
target = self.execute(
target, dom0 = self.execute(
"""\
kill -USR1 $QREXEC_AGENT_PID
echo hello world >&0
Expand All @@ -973,11 +958,12 @@ def test_stdio_socket(self):
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
],
)
self.check_dom0(dom0)

def test_exit_before_closing_streams(self):
fifo = os.path.join(self.tempdir, "fifo")
os.mkfifo(fifo)
target = self.execute(
target, dom0 = self.execute(
"""\
# duplicate original stdin to fd 3, because bash will
# close original stdin in child process
Expand Down Expand Up @@ -1021,6 +1007,7 @@ def test_exit_before_closing_streams(self):
(qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", 42)),
],
)
self.check_dom0(dom0)


@unittest.skipIf(os.environ.get("SKIP_SOCKET_TESTS"), "socket tests not set up")
Expand Down

0 comments on commit edc80de

Please sign in to comment.