diff --git a/daemon/qrexec-daemon-common.c b/daemon/qrexec-daemon-common.c index fa4a89ce..3d42ab5e 100644 --- a/daemon/qrexec-daemon-common.c +++ b/daemon/qrexec-daemon-common.c @@ -386,9 +386,15 @@ int prepare_local_fds(struct qrexec_parsed_command *command, struct buffer *stdi // See also qrexec-agent/qrexec-agent-data.c static void handle_failed_exec(libvchan_t *data_vchan, bool is_service, int exit_code) { - struct msg_header hdr = { - .type = MSG_DATA_STDOUT, - .len = 0, + const struct msg_header hdr[2] = { + { + .type = MSG_DATA_STDERR, + .len = 0, + }, + { + .type = MSG_DATA_STDOUT, + .len = 0, + }, }; LOG(ERROR, "failed to spawn process, exiting"); @@ -404,7 +410,7 @@ static void handle_failed_exec(libvchan_t *data_vchan, bool is_service, int exit * when we support sockets as a local process. */ if (is_service) { - libvchan_send(data_vchan, &hdr, sizeof(hdr)); + libvchan_send(data_vchan, hdr, sizeof(hdr)); send_exit_code(data_vchan, exit_code); } } diff --git a/libqrexec/process_io.c b/libqrexec/process_io.c index c65b48b6..aba6b78d 100644 --- a/libqrexec/process_io.c +++ b/libqrexec/process_io.c @@ -124,6 +124,11 @@ int qrexec_process_io(const struct process_io_request *req, struct timespec normal_timeout = { 10, 0 }; struct prefix_data empty = { 0, 0 }, prefix = req->prefix_data; + if (is_service && stderr_fd == -1) { + struct msg_header hdr = { .type = MSG_DATA_STDERR, .len = 0 }; + libvchan_send(vchan, &hdr, (int)sizeof(hdr)); + } + struct buffer remote_buffer = { .data = malloc(max_chunk_size), .buflen = max_chunk_size, diff --git a/qrexec/tests/socket/agent.py b/qrexec/tests/socket/agent.py index 7b2f873a..6efc20fc 100644 --- a/qrexec/tests/socket/agent.py +++ b/qrexec/tests/socket/agent.py @@ -696,16 +696,7 @@ def test_socket_null_argument_finds_service_for_empty_argument(self): good_server.sendall(b"stdout data") good_server.close() - messages = target.recv_all_messages() - # No stderr - self.assertListEqual( - util.sort_messages(messages), - [ - (qrexec.MSG_DATA_STDOUT, b"stdout data"), - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], - ) + self.assertExpectedStdout(target, b"stdout data") self.check_dom0(dom0) def _test_connect_socket_bad_config(self, forbidden_key): @@ -725,7 +716,6 @@ def _test_connect_socket_bad_config(self, forbidden_key): target, dom0 = self.execute_qubesrpc("qubes.SocketService+arg2", "domX") messages = target.recv_all_messages() - # No stderr self.assertListEqual( util.sort_messages(messages), [ @@ -765,14 +755,11 @@ def test_connect_socket_exit_on_stdin_eof(self): target.send_message(qrexec.MSG_DATA_STDIN, b"") # Check for EOF on stdin self.assertEqual(server.recvall(len(message) + 1), message) - messages = target.recv_all_messages() - # No stderr - self.assertListEqual( - util.sort_messages(messages), + self.assertEqual(target.recv_all_messages(), [ + (qrexec.MSG_DATA_STDERR, b""), (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], - ) + ]) self.check_dom0(dom0) server.close() @@ -800,15 +787,7 @@ def test_connect_socket_exit_on_stdout_eof(self): # Trigger EOF on stdout server.shutdown(socket.SHUT_WR) # Server should exit - messages = target.recv_all_messages() - # No stderr - self.assertListEqual( - util.sort_messages(messages), - [ - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], - ) + self.assertExpectedStdout(target, b"") self.check_dom0(dom0) server.close() @@ -836,16 +815,7 @@ def test_connect_socket_no_metadata(self): server.sendall(b"stdout data") server.close() - messages = target.recv_all_messages() - # No stderr - self.assertListEqual( - util.sort_messages(messages), - [ - (qrexec.MSG_DATA_STDOUT, b"stdout data"), - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], - ) + self.assertExpectedStdout(target, b"stdout data") self.check_dom0(dom0) def test_connect_socket_tcp(self): @@ -880,20 +850,13 @@ def _test_tcp_raw(self, family: int, service: str, host: str, port: int, accept= self.assertEqual(server.recvall(len(message)), message) server.sendall(b"stdout data") server.close() - messages = target.recv_all_messages() self.check_dom0(dom0) - return util.sort_messages(messages) + return target def _test_tcp(self, family: int, service: str, host: str, port: int) -> None: - # No stderr - self.assertListEqual( + self.assertExpectedStdout( self._test_tcp_raw(family, service, host, port), - [ - (qrexec.MSG_DATA_STDOUT, b"stdout data"), - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], - ) + b"stdout data") def test_connect_socket_tcp_port_from_arg(self): socket_path = os.path.join( @@ -930,13 +893,9 @@ def test_connect_socket_tcp_ipv6_service_arg(self): host = "::1" os.symlink(f"/dev/tcp", socket_path) service = f"qubes.SocketService+{host.replace(':', '+')}+{port}" - self.assertListEqual( + self.assertExpectedStdout( self._test_tcp_raw(socket.AF_INET6, service, host, port, skip=False), - [ - (qrexec.MSG_DATA_STDOUT, b"stdout data"), - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], + b"stdout data", ) def _test_connect_socket_tcp_unexpected_host(self, host): @@ -946,16 +905,9 @@ def _test_connect_socket_tcp_unexpected_host(self, host): port = 65535 path = f"/dev/tcp/{host}" os.symlink(path, socket_path) - messages = self._test_tcp_raw(socket.AF_INET, f"qubes.SocketService+{host}+{port}", + target = self._test_tcp_raw(socket.AF_INET, f"qubes.SocketService+{host}+{port}", host, port, accept=False) - self.assertListEqual( - messages, - [ - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_STDERR, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\175\0\0\0"), - ], - ) + self.assertExpectedStdout(target, b"", exit_code=125) def test_connect_socket_tcp_missing_host(self): """ @@ -1063,16 +1015,7 @@ def test_connect_socket(self): server.sendall(b"stdout data") server.close() - messages = target.recv_all_messages() - # No stderr - self.assertListEqual( - util.sort_messages(messages), - [ - (qrexec.MSG_DATA_STDOUT, b"stdout data"), - (qrexec.MSG_DATA_STDOUT, b""), - (qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"), - ], - ) + self.assertExpectedStdout(target, b"stdout data") self.check_dom0(dom0) def test_service_close_stdout_stderr_early(self): diff --git a/qrexec/tests/socket/daemon.py b/qrexec/tests/socket/daemon.py index c3d6d78a..f605d70c 100644 --- a/qrexec/tests/socket/daemon.py +++ b/qrexec/tests/socket/daemon.py @@ -768,6 +768,7 @@ def connect_service_request(self, cmd, timeout=None): source.accept() source.handshake() + self.assertEqual(source.recv_message(), (qrexec.MSG_DATA_STDERR, b"")) return source def test_run_dom0_command_and_connect_vm(self):