diff --git a/src/engine_posix.cpp b/src/engine_posix.cpp index b6b5ffa..184ef21 100644 --- a/src/engine_posix.cpp +++ b/src/engine_posix.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include "ujrpc/ujrpc.h" @@ -43,9 +44,7 @@ struct ujrpc_ssl_context_t { mbedtls_pk_free(&pkey); mbedtls_ssl_free(&ssl); mbedtls_ssl_config_free(&conf); - // #if defined(MBEDTLS_SSL_CACHE_C) - // mbedtls_ssl_cache_free(&cache); - // #endif + mbedtls_ssl_cache_free(&cache); mbedtls_ctr_drbg_free(&ctr_drbg); mbedtls_entropy_free(&entropy); } @@ -53,9 +52,7 @@ struct ujrpc_ssl_context_t { int init(const char* pk_path, const char** crts_path, size_t crts_cnt) { mbedtls_ssl_init(&ssl); mbedtls_ssl_config_init(&conf); - // #if defined(MBEDTLS_SSL_CACHE_C) - // mbedtls_ssl_cache_init(&cache); - // #endif + mbedtls_ssl_cache_init(&cache); mbedtls_x509_crt_init(&srvcert); mbedtls_pk_init(&pkey); mbedtls_entropy_init(&entropy); @@ -84,9 +81,7 @@ struct ujrpc_ssl_context_t { mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ctr_drbg); - // #if defined(MBEDTLS_SSL_CACHE_C) - // mbedtls_ssl_conf_session_cache(&conf, &cache, mbedtls_ssl_cache_get, mbedtls_ssl_cache_set); - // #endif + mbedtls_ssl_conf_session_cache(&conf, &cache, mbedtls_ssl_cache_get, mbedtls_ssl_cache_set); mbedtls_ssl_conf_renegotiation(&conf, MBEDTLS_SSL_RENEGOTIATION_DISABLED); mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, NULL); @@ -104,6 +99,7 @@ struct ujrpc_ssl_context_t { mbedtls_pk_context pkey{}; mbedtls_x509_crt srvcert{}; mbedtls_entropy_context entropy{}; + mbedtls_ssl_cache_context cache{}; mbedtls_ctr_drbg_context ctr_drbg{}; }; diff --git a/src/ujrpc/client.py b/src/ujrpc/client.py index 121c1c7..f6604ee 100644 --- a/src/ujrpc/client.py +++ b/src/ujrpc/client.py @@ -185,7 +185,7 @@ def __call__(self, jsonrpc: object) -> Response: class ClientTLS(Client): def __init__(self, uri: str = '127.0.0.1', port: int = 8545, - ssl_context: ssl.SSLContext = None, allow_self_signed=False) -> None: + ssl_context: ssl.SSLContext = None, allow_self_signed=False, enable_session_resumption=True) -> None: super().__init__(uri, port, use_http=True) if ssl_context is None: @@ -195,14 +195,17 @@ def __init__(self, uri: str = '127.0.0.1', port: int = 8545, ssl_context.verify_mode = ssl.CERT_NONE self.ssl_context = ssl_context + self.session = None + self.session_resumption = enable_session_resumption def _make_socket(self): if not self._socket_is_closed(): return self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock = self.ssl_context.wrap_socket(self.sock, server_hostname=self.uri, session=self.session) self.sock.connect((self.uri, self.port)) - self.sock = self.ssl_context.wrap_socket( - self.sock, server_hostname=self.uri) + if self.session_resumption: + self.session = self.sock.session def _socket_is_closed(self) -> bool: if self.sock is None: