Skip to content

Commit

Permalink
bpo-24334: Cleanup SSLSocket (#5252)
Browse files Browse the repository at this point in the history
* The SSLSocket is no longer implemented on top of SSLObject to
  avoid an extra level of indirection.
* Owner and session are now handled in the internal constructor.
* _ssl._SSLSocket now uses the same method names as SSLSocket and
  SSLObject.
* Channel binding type check is now handled in C code. Channel binding
  is always available.

The patch also changes the signature of SSLObject.__init__(). In my
opinion it's fine. A SSLObject is not a user-constructable object.
SSLContext.wrap_bio() is the only valid factory.
  • Loading branch information
tiran authored Feb 24, 2018
1 parent b18f8bc commit 141c5e8
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 117 deletions.
116 changes: 62 additions & 54 deletions Lib/ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,7 @@

socket_error = OSError # keep that public name in module namespace

if _ssl.HAS_TLS_UNIQUE:
CHANNEL_BINDING_TYPES = ['tls-unique']
else:
CHANNEL_BINDING_TYPES = []
CHANNEL_BINDING_TYPES = ['tls-unique']

HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT')

Expand Down Expand Up @@ -407,11 +404,11 @@ def wrap_bio(self, incoming, outgoing, server_side=False,
server_hostname=None, session=None):
# Need to encode server_hostname here because _wrap_bio() can only
# handle ASCII str.
sslobj = self._wrap_bio(
return self.sslobject_class(
incoming, outgoing, server_side=server_side,
server_hostname=self._encode_hostname(server_hostname)
server_hostname=self._encode_hostname(server_hostname),
session=session, _context=self,
)
return self.sslobject_class(sslobj, session=session)

def set_npn_protocols(self, npn_protocols):
protos = bytearray()
Expand Down Expand Up @@ -616,12 +613,13 @@ class SSLObject:
* The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
"""

def __init__(self, sslobj, owner=None, session=None):
self._sslobj = sslobj
# Note: _sslobj takes a weak reference to owner
self._sslobj.owner = owner or self
if session is not None:
self._sslobj.session = session
def __init__(self, incoming, outgoing, server_side=False,
server_hostname=None, session=None, _context=None):
self._sslobj = _context._wrap_bio(
incoming, outgoing, server_side=server_side,
server_hostname=server_hostname,
owner=self, session=session
)

@property
def context(self):
Expand Down Expand Up @@ -684,7 +682,7 @@ def getpeercert(self, binary_form=False):
Return None if no certificate was provided, {} if a certificate was
provided, but not validated.
"""
return self._sslobj.peer_certificate(binary_form)
return self._sslobj.getpeercert(binary_form)

def selected_npn_protocol(self):
"""Return the currently selected NPN protocol as a string, or ``None``
Expand Down Expand Up @@ -732,13 +730,7 @@ def get_channel_binding(self, cb_type="tls-unique"):
"""Get channel binding data for current connection. Raise ValueError
if the requested `cb_type` is not supported. Return bytes of the data
or None if the data is not available (e.g. before the handshake)."""
if cb_type not in CHANNEL_BINDING_TYPES:
raise ValueError("Unsupported channel binding type")
if cb_type != "tls-unique":
raise NotImplementedError(
"{0} channel binding type not implemented"
.format(cb_type))
return self._sslobj.tls_unique_cb()
return self._sslobj.get_channel_binding(cb_type)

def version(self):
"""Return a string identifying the protocol version used by the
Expand Down Expand Up @@ -832,10 +824,10 @@ def __init__(self, sock=None, keyfile=None, certfile=None,
if connected:
# create the SSL object
try:
sslobj = self._context._wrap_socket(self, server_side,
self.server_hostname)
self._sslobj = SSLObject(sslobj, owner=self,
session=self._session)
self._sslobj = self._context._wrap_socket(
self, server_side, self.server_hostname,
owner=self, session=self._session,
)
if do_handshake_on_connect:
timeout = self.gettimeout()
if timeout == 0.0:
Expand Down Expand Up @@ -895,10 +887,13 @@ def read(self, len=1024, buffer=None):
Return zero-length string on EOF."""

self._checkClosed()
if not self._sslobj:
if self._sslobj is None:
raise ValueError("Read on closed or unwrapped SSL socket.")
try:
return self._sslobj.read(len, buffer)
if buffer is not None:
return self._sslobj.read(len, buffer)
else:
return self._sslobj.read(len)
except SSLError as x:
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
if buffer is not None:
Expand All @@ -913,7 +908,7 @@ def write(self, data):
number of bytes of DATA actually transmitted."""

self._checkClosed()
if not self._sslobj:
if self._sslobj is None:
raise ValueError("Write on closed or unwrapped SSL socket.")
return self._sslobj.write(data)

Expand All @@ -929,41 +924,42 @@ def getpeercert(self, binary_form=False):

def selected_npn_protocol(self):
self._checkClosed()
if not self._sslobj or not _ssl.HAS_NPN:
if self._sslobj is None or not _ssl.HAS_NPN:
return None
else:
return self._sslobj.selected_npn_protocol()

def selected_alpn_protocol(self):
self._checkClosed()
if not self._sslobj or not _ssl.HAS_ALPN:
if self._sslobj is None or not _ssl.HAS_ALPN:
return None
else:
return self._sslobj.selected_alpn_protocol()

def cipher(self):
self._checkClosed()
if not self._sslobj:
if self._sslobj is None:
return None
else:
return self._sslobj.cipher()

def shared_ciphers(self):
self._checkClosed()
if not self._sslobj:
if self._sslobj is None:
return None
return self._sslobj.shared_ciphers()
else:
return self._sslobj.shared_ciphers()

def compression(self):
self._checkClosed()
if not self._sslobj:
if self._sslobj is None:
return None
else:
return self._sslobj.compression()

def send(self, data, flags=0):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
Expand All @@ -974,7 +970,7 @@ def send(self, data, flags=0):

def sendto(self, data, flags_or_addr, addr=None):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
raise ValueError("sendto not allowed on instances of %s" %
self.__class__)
elif addr is None:
Expand All @@ -990,7 +986,7 @@ def sendmsg(self, *args, **kwargs):

def sendall(self, data, flags=0):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to sendall() on %s" %
Expand All @@ -1008,15 +1004,15 @@ def sendfile(self, file, offset=0, count=None):
"""Send a file, possibly by using os.sendfile() if this is a
clear-text socket. Return the total number of bytes sent.
"""
if self._sslobj is None:
if self._sslobj is not None:
return self._sendfile_use_send(file, offset, count)
else:
# os.sendfile() works with plain sockets only
return super().sendfile(file, offset, count)
else:
return self._sendfile_use_send(file, offset, count)

def recv(self, buflen=1024, flags=0):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv() on %s" %
Expand All @@ -1031,7 +1027,7 @@ def recv_into(self, buffer, nbytes=None, flags=0):
nbytes = len(buffer)
elif nbytes is None:
nbytes = 1024
if self._sslobj:
if self._sslobj is not None:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv_into() on %s" %
Expand All @@ -1042,15 +1038,15 @@ def recv_into(self, buffer, nbytes=None, flags=0):

def recvfrom(self, buflen=1024, flags=0):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
raise ValueError("recvfrom not allowed on instances of %s" %
self.__class__)
else:
return super().recvfrom(buflen, flags)

def recvfrom_into(self, buffer, nbytes=None, flags=0):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
raise ValueError("recvfrom_into not allowed on instances of %s" %
self.__class__)
else:
Expand All @@ -1066,7 +1062,7 @@ def recvmsg_into(self, *args, **kwargs):

def pending(self):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
return self._sslobj.pending()
else:
return 0
Expand All @@ -1078,7 +1074,7 @@ def shutdown(self, how):

def unwrap(self):
if self._sslobj:
s = self._sslobj.unwrap()
s = self._sslobj.shutdown()
self._sslobj = None
return s
else:
Expand All @@ -1096,6 +1092,11 @@ def do_handshake(self, block=False):
if timeout == 0.0 and block:
self.settimeout(None)
self._sslobj.do_handshake()
if self.context.check_hostname:
if not self.server_hostname:
raise ValueError("check_hostname needs server_hostname "
"argument")
match_hostname(self.getpeercert(), self.server_hostname)
finally:
self.settimeout(timeout)

Expand All @@ -1104,11 +1105,12 @@ def _real_connect(self, addr, connect_ex):
raise ValueError("can't connect in server-side mode")
# Here we assume that the socket is client-side, and not
# connected at the time of the call. We connect it, then wrap it.
if self._connected:
if self._connected or self._sslobj is not None:
raise ValueError("attempt to connect already-connected SSLSocket!")
sslobj = self.context._wrap_socket(self, False, self.server_hostname)
self._sslobj = SSLObject(sslobj, owner=self,
session=self._session)
self._sslobj = self.context._wrap_socket(
self, False, self.server_hostname,
owner=self, session=self._session
)
try:
if connect_ex:
rc = super().connect_ex(addr)
Expand Down Expand Up @@ -1151,18 +1153,24 @@ def get_channel_binding(self, cb_type="tls-unique"):
if the requested `cb_type` is not supported. Return bytes of the data
or None if the data is not available (e.g. before the handshake).
"""
if self._sslobj is None:
if self._sslobj is not None:
return self._sslobj.get_channel_binding(cb_type)
else:
if cb_type not in CHANNEL_BINDING_TYPES:
raise ValueError(
"{0} channel binding type not implemented".format(cb_type)
)
return None
return self._sslobj.get_channel_binding(cb_type)

def version(self):
"""
Return a string identifying the protocol version used by the
current SSL channel, or None if there is no established channel.
"""
if self._sslobj is None:
if self._sslobj is not None:
return self._sslobj.version()
else:
return None
return self._sslobj.version()


# Python does not support forward declaration of types.
Expand Down
4 changes: 4 additions & 0 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,8 @@ def test_wrapped_unconnected(self):
self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
self.assertRaises(OSError, ss.send, b'x')
self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
self.assertRaises(NotImplementedError, ss.sendmsg,
[b'x'], (), 0, ('0.0.0.0', 0))

def test_timeout(self):
# Issue #8524: when creating an SSL socket, the timeout of the
Expand Down Expand Up @@ -3381,11 +3383,13 @@ def test_version_basic(self):
chatty=False) as server:
with context.wrap_socket(socket.socket()) as s:
self.assertIs(s.version(), None)
self.assertIs(s._sslobj, None)
s.connect((HOST, server.port))
if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
self.assertEqual(s.version(), 'TLSv1.2')
else: # 0.9.8 to 1.0.1
self.assertIn(s.version(), ('TLSv1', 'TLSv1.2'))
self.assertIs(s._sslobj, None)
self.assertIs(s.version(), None)

@unittest.skipUnless(ssl.HAS_TLSv1_3,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Internal implementation details of ssl module were cleaned up. The SSLSocket
has one less layer of indirection. Owner and session information are now
handled by the SSLSocket and SSLObject constructor. Channel binding
implementation has been simplified.
Loading

0 comments on commit 141c5e8

Please sign in to comment.