Skip to content

Commit

Permalink
Use an executor to prevent GSSAPI calls from blocking the event loop
Browse files Browse the repository at this point in the history
Some operations such as GSSAPI calls can sometimes block the event
loop if not run in an executor. However, doing that requires packet
handlers to be asynchronous. This commit adds support for async
packet handlers for key exchange and auth, and changes the GSSAPI
handlers to run the step() call in an executor.
  • Loading branch information
ronf committed Sep 10, 2024
1 parent 358c175 commit cb87de9
Show file tree
Hide file tree
Showing 10 changed files with 210 additions and 155 deletions.
33 changes: 17 additions & 16 deletions asyncssh/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2013-2022 by Ron Frederick <[email protected]> and others.
# Copyright (c) 2013-2024 by Ron Frederick <[email protected]> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
Expand Down Expand Up @@ -27,6 +27,7 @@
from .gss import GSSBase, GSSError
from .logging import SSHLogger
from .misc import ProtocolError, PasswordChangeRequired, get_symbol_names
from .misc import run_in_executor
from .packet import Boolean, String, UInt32, SSHPacket, SSHPacketHandler
from .public_key import SigningKey
from .saslprep import saslprep, SASLPrepError
Expand Down Expand Up @@ -199,8 +200,8 @@ def _finish(self) -> None:
else:
self.send_packet(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE)

def _process_response(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
async def _process_response(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS response from the server"""

mech = packet.get_string()
Expand All @@ -212,7 +213,7 @@ def _process_response(self, _pkttype: int, _pktid: int,
raise ProtocolError('Mechanism mismatch')

try:
token = self._gss.step()
token = await run_in_executor(self._gss.step)
assert token is not None

self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token))
Expand All @@ -225,8 +226,8 @@ def _process_response(self, _pkttype: int, _pktid: int,

self._conn.try_next_auth()

def _process_token(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
async def _process_token(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS token from the server"""

token: Optional[bytes] = packet.get_string()
Expand All @@ -235,7 +236,7 @@ def _process_token(self, _pkttype: int, _pktid: int,
assert self._gss is not None

try:
token = self._gss.step(token)
token = await run_in_executor(self._gss.step, token)

if token:
self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token))
Expand All @@ -261,8 +262,8 @@ def _process_error(self, _pkttype: int, _pktid: int,
self.logger.debug1('GSS error from server: %s', msg)
self._got_error = True

def _process_error_token(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
async def _process_error_token(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS error token from the server"""

token = packet.get_string()
Expand All @@ -271,7 +272,7 @@ def _process_error_token(self, _pkttype: int, _pktid: int,
assert self._gss is not None

try:
self._gss.step(token)
await run_in_executor(self._gss.step, token)
except GSSError as exc:
if not self._got_error: # pragma: no cover
self.logger.debug1('GSS error from server: %s', str(exc))
Expand Down Expand Up @@ -649,15 +650,15 @@ async def _finish(self) -> None:
else:
self.send_failure()

def _process_token(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
async def _process_token(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS token from the client"""

token: Optional[bytes] = packet.get_string()
packet.check_end()

try:
token = self._gss.step(token)
token = await run_in_executor(self._gss.step, token)

if token:
self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token))
Expand All @@ -682,15 +683,15 @@ def _process_exchange_complete(self, _pkttype: int, _pktid: int,
else:
self.send_failure()

def _process_error_token(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
async def _process_error_token(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS error token from the client"""

token = packet.get_string()
packet.check_end()

try:
self._gss.step(token)
await run_in_executor(self._gss.step, token)
except GSSError as exc:
self.logger.debug1('GSS error from client: %s', str(exc))

Expand Down
57 changes: 40 additions & 17 deletions asyncssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,17 +1326,7 @@ def data_received(self, data: bytes, datatype: DataType = None) -> None:

self._inpbuf += data

self._reset_keepalive_timer()

# pylint: disable=broad-except
try:
while self._inpbuf and self._recv_handler():
pass
except DisconnectError as exc:
self._send_disconnect(exc.code, exc.reason, exc.lang)
self._force_close(exc)
except Exception:
self.internal_error()
self._recv_data()
# pylint: enable=arguments-differ

def eof_received(self) -> None:
Expand Down Expand Up @@ -1442,6 +1432,21 @@ def _send_version(self) -> None:

self._send(version + b'\r\n')

def _recv_data(self) -> None:
"""Parse received data"""

self._reset_keepalive_timer()

# pylint: disable=broad-except
try:
while self._inpbuf and self._recv_handler():
pass
except DisconnectError as exc:
self._send_disconnect(exc.code, exc.reason, exc.lang)
self._force_close(exc)
except Exception:
self.internal_error()

def _recv_version(self) -> bool:
"""Receive and parse the remote SSH version"""

Expand Down Expand Up @@ -1595,11 +1600,20 @@ def _recv_packet(self) -> bool:

if not skip_reason:
try:
processed = handler.process_packet(pkttype, seq, packet)
result = handler.process_packet(pkttype, seq, packet)
except PacketDecodeError as exc:
raise ProtocolError(str(exc)) from None

if not processed:
if inspect.isawaitable(result):
# Buffer received data until current packet is processed
self._recv_handler = lambda: False

task = self.create_task(result)
task.add_done_callback(functools.partial(
self._finish_recv_packet, pkttype, seq, is_async=True))

return False
elif not result:
if self._strict_kex and not self._recv_encryption:
exc_reason = 'Strict key exchange violation: ' \
'unexpected packet type %d received' % pkttype
Expand All @@ -1611,6 +1625,14 @@ def _recv_packet(self) -> bool:
if exc_reason:
raise ProtocolError(exc_reason)

self._finish_recv_packet(pkttype, seq)
return True

def _finish_recv_packet(self, pkttype: int, seq: int,
_task: Optional[asyncio.Task] = None,
is_async: bool = False) -> None:
"""Finish processing a packet"""

if pkttype > MSG_USERAUTH_LAST:
self._auth_final = True

Expand All @@ -1625,7 +1647,8 @@ def _recv_packet(self) -> bool:
else:
self._recv_seq = (seq + 1) & 0xffffffff

return True
if is_async and self._inpbuf:
self._recv_data()

def send_packet(self, pkttype: int, *args: bytes,
handler: Optional[SSHPacketLogger] = None) -> None:
Expand Down Expand Up @@ -2218,8 +2241,8 @@ def _process_ext_info(self, _pkttype: int, _pktid: int,
self._server_sig_algs = \
set(extensions.get(b'server-sig-algs', b'').split(b','))

def _process_kexinit(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
async def _process_kexinit(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a key exchange request"""

if self._kex:
Expand Down Expand Up @@ -2323,7 +2346,7 @@ def _process_kexinit(self, _pkttype: int, _pktid: int,
self.logger.debug1('Beginning key exchange')
self.logger.debug2(' Key exchange alg: %s', self._kex.algorithm)

self._kex.start()
await self._kex.start()

def _process_newkeys(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
Expand Down
4 changes: 2 additions & 2 deletions asyncssh/kex.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2013-2022 by Ron Frederick <[email protected]> and others.
# Copyright (c) 2013-2024 by Ron Frederick <[email protected]> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType):
self._hash_alg = hash_alg


def start(self) -> None:
async def start(self) -> None:
"""Start key exchange"""

raise NotImplementedError
Expand Down
44 changes: 22 additions & 22 deletions asyncssh/kex_dh.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2013-2022 by Ron Frederick <[email protected]> and others.
# Copyright (c) 2013-2024 by Ron Frederick <[email protected]> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
Expand Down Expand Up @@ -33,7 +33,7 @@
from .gss import GSSError
from .kex import Kex, register_kex_alg, register_gss_kex_alg
from .misc import HashType, KeyExchangeFailed, ProtocolError
from .misc import get_symbol_names
from .misc import get_symbol_names, run_in_executor
from .packet import Boolean, MPInt, String, UInt32, SSHPacket
from .public_key import SigningKey, VerifyingKey

Expand Down Expand Up @@ -274,7 +274,7 @@ def _process_reply(self, _pkttype: int, _pktid: int,
host_key = client_conn.validate_server_host_key(host_key_data)
self._verify_reply(host_key, host_key_data, sig)

def start(self) -> None:
async def start(self) -> None:
"""Start DH key exchange"""

if self._conn.is_client():
Expand Down Expand Up @@ -384,7 +384,7 @@ def _process_group(self, _pkttype: int, _pktid: int,
self._gex_data += MPInt(p) + MPInt(g)
self._perform_init()

def start(self) -> None:
async def start(self) -> None:
"""Start DH group exchange"""

if self._conn.is_client():
Expand Down Expand Up @@ -455,7 +455,7 @@ def _compute_server_shared(self) -> bytes:
except ValueError:
raise ProtocolError('Invalid ECDH client public key') from None

def start(self) -> None:
async def start(self) -> None:
"""Start ECDH key exchange"""

if self._conn.is_client():
Expand Down Expand Up @@ -567,11 +567,11 @@ def _send_continue(self) -> None:

self.send_packet(MSG_KEXGSS_CONTINUE, String(self._token))

def _process_token(self, token: Optional[bytes] = None) -> None:
async def _process_token(self, token: Optional[bytes] = None) -> None:
"""Process a GSS token"""

try:
self._token = self._gss.step(token)
self._token = await run_in_executor(self._gss.step, token)
except GSSError as exc:
if self._conn.is_server():
self.send_packet(MSG_KEXGSS_ERROR, UInt32(exc.maj_code),
Expand All @@ -583,8 +583,8 @@ def _process_token(self, token: Optional[bytes] = None) -> None:

raise KeyExchangeFailed(str(exc)) from None

def _process_init(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
async def _process_gss_init(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS init message"""

if self._conn.is_client():
Expand All @@ -603,7 +603,7 @@ def _process_init(self, _pkttype: int, _pktid: int,
else:
self._host_key_data = b''

self._process_token(token)
await self._process_token(token)

if self._gss.complete:
self._check_secure()
Expand All @@ -612,8 +612,8 @@ def _process_init(self, _pkttype: int, _pktid: int,
else:
self._send_continue()

def _process_continue(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
async def _process_continue(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS continue message"""

token = packet.get_string()
Expand All @@ -622,16 +622,16 @@ def _process_continue(self, _pkttype: int, _pktid: int,
if self._conn.is_client() and self._gss.complete:
raise ProtocolError('Unexpected kexgss continue msg')

self._process_token(token)
await self._process_token(token)

if self._conn.is_server() and self._gss.complete:
self._check_secure()
self._perform_reply(self._gss, self._host_key_data)
else:
self._send_continue()

def _process_complete(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
async def _process_complete(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS complete message"""

if self._conn.is_server():
Expand All @@ -647,7 +647,7 @@ def _process_complete(self, _pkttype: int, _pktid: int,
if self._gss.complete:
raise ProtocolError('Non-empty token after complete')

self._process_token(token)
await self._process_token(token)

if self._token:
raise ProtocolError('Non-empty token after complete')
Expand Down Expand Up @@ -682,12 +682,12 @@ def _process_error(self, _pkttype: int, _pktid: int,
self._conn.logger.debug1('GSS error: %s',
msg.decode('utf-8', errors='ignore'))

def start(self) -> None:
async def start(self) -> None:
"""Start GSS key exchange"""

if self._conn.is_client():
self._process_token()
super().start()
await self._process_token()
await super().start()


class _KexGSS(_KexGSSBase, _KexDH):
Expand All @@ -696,7 +696,7 @@ class _KexGSS(_KexGSSBase, _KexDH):
_handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_')

_packet_handlers = {
MSG_KEXGSS_INIT: _KexGSSBase._process_init,
MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init,
MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue,
MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete,
MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey,
Expand All @@ -713,7 +713,7 @@ class _KexGSSGex(_KexGSSBase, _KexDHGex):
_group_type = MSG_KEXGSS_GROUP

_packet_handlers = {
MSG_KEXGSS_INIT: _KexGSSBase._process_init,
MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init,
MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue,
MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete,
MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey,
Expand All @@ -729,7 +729,7 @@ class _KexGSSECDH(_KexGSSBase, _KexECDH):
_handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_')

_packet_handlers = {
MSG_KEXGSS_INIT: _KexGSSBase._process_init,
MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init,
MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue,
MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete,
MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey,
Expand Down
Loading

0 comments on commit cb87de9

Please sign in to comment.