Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-24334: Cleanup SSLSocket #5252

Merged
merged 1 commit into from
Feb 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 62 additions & 54 deletions Lib/ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,7 @@

socket_error = OSError # keep that public name in module namespace

if _ssl.HAS_TLS_UNIQUE:
CHANNEL_BINDING_TYPES = ['tls-unique']
else:
CHANNEL_BINDING_TYPES = []
CHANNEL_BINDING_TYPES = ['tls-unique']

HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT')

Expand Down Expand Up @@ -407,11 +404,11 @@ def wrap_bio(self, incoming, outgoing, server_side=False,
server_hostname=None, session=None):
# Need to encode server_hostname here because _wrap_bio() can only
# handle ASCII str.
sslobj = self._wrap_bio(
return self.sslobject_class(
incoming, outgoing, server_side=server_side,
server_hostname=self._encode_hostname(server_hostname)
server_hostname=self._encode_hostname(server_hostname),
session=session, _context=self,
)
return self.sslobject_class(sslobj, session=session)

def set_npn_protocols(self, npn_protocols):
protos = bytearray()
Expand Down Expand Up @@ -616,12 +613,13 @@ class SSLObject:
* The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
"""

def __init__(self, sslobj, owner=None, session=None):
self._sslobj = sslobj
# Note: _sslobj takes a weak reference to owner
self._sslobj.owner = owner or self
if session is not None:
self._sslobj.session = session
def __init__(self, incoming, outgoing, server_side=False,
server_hostname=None, session=None, _context=None):
self._sslobj = _context._wrap_bio(
incoming, outgoing, server_side=server_side,
server_hostname=server_hostname,
owner=self, session=session
)

@property
def context(self):
Expand Down Expand Up @@ -684,7 +682,7 @@ def getpeercert(self, binary_form=False):
Return None if no certificate was provided, {} if a certificate was
provided, but not validated.
"""
return self._sslobj.peer_certificate(binary_form)
return self._sslobj.getpeercert(binary_form)

def selected_npn_protocol(self):
"""Return the currently selected NPN protocol as a string, or ``None``
Expand Down Expand Up @@ -732,13 +730,7 @@ def get_channel_binding(self, cb_type="tls-unique"):
"""Get channel binding data for current connection. Raise ValueError
if the requested `cb_type` is not supported. Return bytes of the data
or None if the data is not available (e.g. before the handshake)."""
if cb_type not in CHANNEL_BINDING_TYPES:
raise ValueError("Unsupported channel binding type")
if cb_type != "tls-unique":
raise NotImplementedError(
"{0} channel binding type not implemented"
.format(cb_type))
return self._sslobj.tls_unique_cb()
return self._sslobj.get_channel_binding(cb_type)

def version(self):
"""Return a string identifying the protocol version used by the
Expand Down Expand Up @@ -832,10 +824,10 @@ def __init__(self, sock=None, keyfile=None, certfile=None,
if connected:
# create the SSL object
try:
sslobj = self._context._wrap_socket(self, server_side,
self.server_hostname)
self._sslobj = SSLObject(sslobj, owner=self,
session=self._session)
self._sslobj = self._context._wrap_socket(
self, server_side, self.server_hostname,
owner=self, session=self._session,
)
if do_handshake_on_connect:
timeout = self.gettimeout()
if timeout == 0.0:
Expand Down Expand Up @@ -895,10 +887,13 @@ def read(self, len=1024, buffer=None):
Return zero-length string on EOF."""

self._checkClosed()
if not self._sslobj:
if self._sslobj is None:
raise ValueError("Read on closed or unwrapped SSL socket.")
try:
return self._sslobj.read(len, buffer)
if buffer is not None:
return self._sslobj.read(len, buffer)
else:
return self._sslobj.read(len)
except SSLError as x:
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
if buffer is not None:
Expand All @@ -913,7 +908,7 @@ def write(self, data):
number of bytes of DATA actually transmitted."""

self._checkClosed()
if not self._sslobj:
if self._sslobj is None:
raise ValueError("Write on closed or unwrapped SSL socket.")
return self._sslobj.write(data)

Expand All @@ -929,41 +924,42 @@ def getpeercert(self, binary_form=False):

def selected_npn_protocol(self):
self._checkClosed()
if not self._sslobj or not _ssl.HAS_NPN:
if self._sslobj is None or not _ssl.HAS_NPN:
return None
else:
return self._sslobj.selected_npn_protocol()

def selected_alpn_protocol(self):
self._checkClosed()
if not self._sslobj or not _ssl.HAS_ALPN:
if self._sslobj is None or not _ssl.HAS_ALPN:
return None
else:
return self._sslobj.selected_alpn_protocol()

def cipher(self):
self._checkClosed()
if not self._sslobj:
if self._sslobj is None:
return None
else:
return self._sslobj.cipher()

def shared_ciphers(self):
self._checkClosed()
if not self._sslobj:
if self._sslobj is None:
return None
return self._sslobj.shared_ciphers()
else:
return self._sslobj.shared_ciphers()

def compression(self):
self._checkClosed()
if not self._sslobj:
if self._sslobj is None:
return None
else:
return self._sslobj.compression()

def send(self, data, flags=0):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
Expand All @@ -974,7 +970,7 @@ def send(self, data, flags=0):

def sendto(self, data, flags_or_addr, addr=None):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
raise ValueError("sendto not allowed on instances of %s" %
self.__class__)
elif addr is None:
Expand All @@ -990,7 +986,7 @@ def sendmsg(self, *args, **kwargs):

def sendall(self, data, flags=0):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to sendall() on %s" %
Expand All @@ -1008,15 +1004,15 @@ def sendfile(self, file, offset=0, count=None):
"""Send a file, possibly by using os.sendfile() if this is a
clear-text socket. Return the total number of bytes sent.
"""
if self._sslobj is None:
if self._sslobj is not None:
return self._sendfile_use_send(file, offset, count)
else:
# os.sendfile() works with plain sockets only
return super().sendfile(file, offset, count)
else:
return self._sendfile_use_send(file, offset, count)

def recv(self, buflen=1024, flags=0):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv() on %s" %
Expand All @@ -1031,7 +1027,7 @@ def recv_into(self, buffer, nbytes=None, flags=0):
nbytes = len(buffer)
elif nbytes is None:
nbytes = 1024
if self._sslobj:
if self._sslobj is not None:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv_into() on %s" %
Expand All @@ -1042,15 +1038,15 @@ def recv_into(self, buffer, nbytes=None, flags=0):

def recvfrom(self, buflen=1024, flags=0):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
raise ValueError("recvfrom not allowed on instances of %s" %
self.__class__)
else:
return super().recvfrom(buflen, flags)

def recvfrom_into(self, buffer, nbytes=None, flags=0):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
raise ValueError("recvfrom_into not allowed on instances of %s" %
self.__class__)
else:
Expand All @@ -1066,7 +1062,7 @@ def recvmsg_into(self, *args, **kwargs):

def pending(self):
self._checkClosed()
if self._sslobj:
if self._sslobj is not None:
return self._sslobj.pending()
else:
return 0
Expand All @@ -1078,7 +1074,7 @@ def shutdown(self, how):

def unwrap(self):
if self._sslobj:
s = self._sslobj.unwrap()
s = self._sslobj.shutdown()
self._sslobj = None
return s
else:
Expand All @@ -1096,6 +1092,11 @@ def do_handshake(self, block=False):
if timeout == 0.0 and block:
self.settimeout(None)
self._sslobj.do_handshake()
if self.context.check_hostname:
if not self.server_hostname:
raise ValueError("check_hostname needs server_hostname "
"argument")
match_hostname(self.getpeercert(), self.server_hostname)
finally:
self.settimeout(timeout)

Expand All @@ -1104,11 +1105,12 @@ def _real_connect(self, addr, connect_ex):
raise ValueError("can't connect in server-side mode")
# Here we assume that the socket is client-side, and not
# connected at the time of the call. We connect it, then wrap it.
if self._connected:
if self._connected or self._sslobj is not None:
raise ValueError("attempt to connect already-connected SSLSocket!")
sslobj = self.context._wrap_socket(self, False, self.server_hostname)
self._sslobj = SSLObject(sslobj, owner=self,
session=self._session)
self._sslobj = self.context._wrap_socket(
self, False, self.server_hostname,
owner=self, session=self._session
)
try:
if connect_ex:
rc = super().connect_ex(addr)
Expand Down Expand Up @@ -1151,18 +1153,24 @@ def get_channel_binding(self, cb_type="tls-unique"):
if the requested `cb_type` is not supported. Return bytes of the data
or None if the data is not available (e.g. before the handshake).
"""
if self._sslobj is None:
if self._sslobj is not None:
return self._sslobj.get_channel_binding(cb_type)
else:
if cb_type not in CHANNEL_BINDING_TYPES:
raise ValueError(
"{0} channel binding type not implemented".format(cb_type)
)
return None
return self._sslobj.get_channel_binding(cb_type)

def version(self):
"""
Return a string identifying the protocol version used by the
current SSL channel, or None if there is no established channel.
"""
if self._sslobj is None:
if self._sslobj is not None:
return self._sslobj.version()
else:
return None
return self._sslobj.version()


# Python does not support forward declaration of types.
Expand Down
4 changes: 4 additions & 0 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,8 @@ def test_wrapped_unconnected(self):
self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
self.assertRaises(OSError, ss.send, b'x')
self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
self.assertRaises(NotImplementedError, ss.sendmsg,
[b'x'], (), 0, ('0.0.0.0', 0))

def test_timeout(self):
# Issue #8524: when creating an SSL socket, the timeout of the
Expand Down Expand Up @@ -3381,11 +3383,13 @@ def test_version_basic(self):
chatty=False) as server:
with context.wrap_socket(socket.socket()) as s:
self.assertIs(s.version(), None)
self.assertIs(s._sslobj, None)
s.connect((HOST, server.port))
if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
self.assertEqual(s.version(), 'TLSv1.2')
else: # 0.9.8 to 1.0.1
self.assertIn(s.version(), ('TLSv1', 'TLSv1.2'))
self.assertIs(s._sslobj, None)
self.assertIs(s.version(), None)

@unittest.skipUnless(ssl.HAS_TLSv1_3,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Internal implementation details of ssl module were cleaned up. The SSLSocket
has one less layer of indirection. Owner and session information are now
handled by the SSLSocket and SSLObject constructor. Channel binding
implementation has been simplified.
Loading