Skip to content

Commit

Permalink
Address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
WillChilds-Klein committed Jun 28, 2024
1 parent 57b4d6d commit 082e7d6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
8 changes: 4 additions & 4 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4364,14 +4364,14 @@ def test_session_handling(self):
def test_psk(self):
psk = bytes.fromhex('deadbeef')

client_context, server_context, _ = testing_context()

client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
client_context.check_hostname = False
client_context.verify_mode = ssl.CERT_NONE
client_context.maximum_version = ssl.TLSVersion.TLSv1_2
client_context.set_ciphers('PSK')
client_context.set_psk_client_callback(lambda hint: (None, psk))

server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.maximum_version = ssl.TLSVersion.TLSv1_2
server_context.set_ciphers('PSK')
server_context.set_psk_server_callback(lambda identity: psk)
Expand Down Expand Up @@ -4443,14 +4443,14 @@ def server_callback(identity):
self.assertEqual(identity, client_identity)
return psk

client_context, server_context, _ = testing_context()

client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
client_context.check_hostname = False
client_context.verify_mode = ssl.CERT_NONE
client_context.minimum_version = ssl.TLSVersion.TLSv1_3
client_context.set_ciphers('PSK')
client_context.set_psk_client_callback(client_callback)

server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.minimum_version = ssl.TLSVersion.TLSv1_3
server_context.set_ciphers('PSK')
server_context.set_psk_server_callback(server_callback, identity_hint)
Expand Down
18 changes: 9 additions & 9 deletions Modules/_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ extern const SSL_METHOD *TLSv1_2_method(void);
#endif


#if !defined(SSL_VERIFY_POST_HANDSHAKE) || !defined(TLS1_3_VERSION) || defined(OPENSSL_NO_TLS1_3)
#define PY_SSL_NO_POST_HS_AUTH
#if defined(SSL_VERIFY_POST_HANDSHAKE) && defined(TLS1_3_VERSION) && !defined(OPENSSL_NO_TLS1_3)
#define PySSL_HAVE_POST_HS_AUTH
#endif


Expand Down Expand Up @@ -298,7 +298,7 @@ typedef struct {
*/
unsigned int hostflags;
int protocol;
#if !defined(PY_SSL_NO_POST_HS_AUTH)
#if defined(PySSL_HAVE_POST_HS_AUTH)
int post_handshake_auth;
#endif
PyObject *msg_cb;
Expand Down Expand Up @@ -878,7 +878,7 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
SSL_set_mode(self->ssl,
SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_AUTO_RETRY);

#if !defined(PY_SSL_NO_POST_HS_AUTH)
#if defined(PySSL_HAVE_POST_HS_AUTH)
if (sslctx->post_handshake_auth == 1) {
if (socket_type == PY_SSL_SERVER) {
/* bpo-37428: OpenSSL does not ignore SSL_VERIFY_POST_HANDSHAKE.
Expand Down Expand Up @@ -2781,7 +2781,7 @@ static PyObject *
_ssl__SSLSocket_verify_client_post_handshake_impl(PySSLSocket *self)
/*[clinic end generated code: output=532147f3b1341425 input=6bfa874810a3d889]*/
{
#if !defined(PY_SSL_NO_POST_HS_AUTH)
#if defined(PySSL_HAVE_POST_HS_AUTH)
int err = SSL_verify_client_post_handshake(self->ssl);
if (err == 0)
return _setSSLError(get_state_sock(self), NULL, 0, __FILE__, __LINE__);
Expand Down Expand Up @@ -3204,7 +3204,7 @@ _ssl__SSLContext_impl(PyTypeObject *type, int proto_version)
X509_VERIFY_PARAM_set_flags(params, X509_V_FLAG_TRUSTED_FIRST);
X509_VERIFY_PARAM_set_hostflags(params, self->hostflags);

#if !defined(PY_SSL_NO_POST_HS_AUTH)
#if defined(PySSL_HAVE_POST_HS_AUTH)
self->post_handshake_auth = 0;
SSL_CTX_set_post_handshake_auth(self->ctx, self->post_handshake_auth);
#endif
Expand Down Expand Up @@ -3716,14 +3716,14 @@ set_check_hostname(PySSLContext *self, PyObject *arg, void *c)

static PyObject *
get_post_handshake_auth(PySSLContext *self, void *c) {
#if !defined(PY_SSL_NO_POST_HS_AUTH)
#if defined(PySSL_HAVE_POST_HS_AUTH)
return PyBool_FromLong(self->post_handshake_auth);
#else
Py_RETURN_NONE;
#endif
}

#if !defined(PY_SSL_NO_POST_HS_AUTH)
#if defined(PySSL_HAVE_POST_HS_AUTH)
static int
set_post_handshake_auth(PySSLContext *self, PyObject *arg, void *c) {
if (arg == NULL) {
Expand Down Expand Up @@ -4972,7 +4972,7 @@ static PyGetSetDef context_getsetlist[] = {
{"options", (getter) get_options,
(setter) set_options, NULL},
{"post_handshake_auth", (getter) get_post_handshake_auth,
#if !defined(PY_SSL_NO_POST_HS_AUTH)
#if defined(PySSL_HAVE_POST_HS_AUTH)
(setter) set_post_handshake_auth,
#else
NULL,
Expand Down

0 comments on commit 082e7d6

Please sign in to comment.