From 96458c825009d927a1a555998e9165374817edf6 Mon Sep 17 00:00:00 2001 From: Jian-Hong Pan Date: Sun, 7 Mar 2021 22:23:28 +0800 Subject: [PATCH] Add server requiring client's cert in TLS handshake feature This patch adds server requiring client's certificate feature which is mentioned in the 6th step CertificateRequest to 9th step VerifyClientCertSig in Table 5 TLS Full Handshake Protocol of MODBUS/TCP Security Protocol Specification [1], This feature is implemented with an optional argument "reqclicert" of StartTlsServer() in both sync and async_io. So, users can force server require client's certificate, or according to the SSL Context's original behavior [2]. This fixes part of https://github.com/riptideio/pymodbus/issues/606 [1]: http://modbus.org/docs/MB-TCP-Security-v21_2018-07-24.pdf [2]: https://docs.python.org/3/library/ssl.html#ssl.SSLContext.verify_mode --- pymodbus/server/async_io.py | 11 +++++++---- pymodbus/server/sync.py | 15 ++++++++++----- test/test_server_sync.py | 5 ++++- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/pymodbus/server/async_io.py b/pymodbus/server/async_io.py index c4a7e28369..04b4f5afa0 100755 --- a/pymodbus/server/async_io.py +++ b/pymodbus/server/async_io.py @@ -517,6 +517,7 @@ def __init__(self, sslctx=None, certfile=None, keyfile=None, + reqclicert=False, handler=None, allow_reuse_address=False, allow_reuse_port=False, @@ -537,6 +538,7 @@ def __init__(self, create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param reqclicert: Force the sever request client's certificate :param handler: A handler for each client session; default is ModbusConnectedRequestHandler. The handler class receives connection create/teardown events @@ -585,8 +587,8 @@ def __init__(self, self.sslctx.options |= ssl.OP_NO_TLSv1 self.sslctx.options |= ssl.OP_NO_SSLv3 self.sslctx.options |= ssl.OP_NO_SSLv2 - self.sslctx.verify_mode = ssl.CERT_OPTIONAL - self.sslctx.check_hostname = False + if reqclicert: + self.sslctx.verify_mode = ssl.CERT_REQUIRED # asyncio future that will be done once server has started self.serving = self.loop.create_future() # constructors cannot be declared async, so we have to @@ -838,7 +840,7 @@ async def StartTcpServer(context=None, identity=None, address=None, async def StartTlsServer(context=None, identity=None, address=None, sslctx=None, - certfile=None, keyfile=None, + certfile=None, keyfile=None, reqclicert=False, allow_reuse_address=False, allow_reuse_port=False, custom_functions=[], @@ -851,6 +853,7 @@ async def StartTlsServer(context=None, identity=None, address=None, :param sslctx: The SSLContext to use for TLS (default None and auto create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param reqclicert: Force the sever request client's certificate :param allow_reuse_address: Whether the server will allow the reuse of an address. :param allow_reuse_port: Whether the server will allow the reuse of a port. @@ -865,7 +868,7 @@ async def StartTlsServer(context=None, identity=None, address=None, """ framer = kwargs.pop("framer", ModbusTlsFramer) server = ModbusTlsServer(context, framer, identity, address, sslctx, - certfile, keyfile, + certfile, keyfile, reqclicert=reqclicert, allow_reuse_address=allow_reuse_address, allow_reuse_port=allow_reuse_port, **kwargs) diff --git a/pymodbus/server/sync.py b/pymodbus/server/sync.py index c504ff0135..1ec4a5b881 100644 --- a/pymodbus/server/sync.py +++ b/pymodbus/server/sync.py @@ -376,7 +376,8 @@ class ModbusTlsServer(ModbusTcpServer): def __init__(self, context, framer=None, identity=None, address=None, handler=None, allow_reuse_address=False, - sslctx=None, certfile=None, keyfile=None, **kwargs): + sslctx=None, certfile=None, keyfile=None, reqclicert=False, + **kwargs): """ Overloaded initializer for the ModbusTcpServer If the identify structure is not passed in, the ModbusControlBlock @@ -394,6 +395,7 @@ def __init__(self, context, framer=None, identity=None, create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param reqclicert: Force the sever request client's certificate :param ignore_missing_slaves: True to not send errors on a request to a missing slave :param broadcast_enable: True to treat unit_id 0 as broadcast address, @@ -409,8 +411,8 @@ def __init__(self, context, framer=None, identity=None, self.sslctx.options |= ssl.OP_NO_TLSv1 self.sslctx.options |= ssl.OP_NO_SSLv3 self.sslctx.options |= ssl.OP_NO_SSLv2 - self.sslctx.verify_mode = ssl.CERT_OPTIONAL - self.sslctx.check_hostname = False + if reqclicert: + self.sslctx.verify_mode = ssl.CERT_REQUIRED ModbusTcpServer.__init__(self, context, framer, identity, address, handler, allow_reuse_address, **kwargs) @@ -621,7 +623,8 @@ def StartTcpServer(context=None, identity=None, address=None, def StartTlsServer(context=None, identity=None, address=None, sslctx=None, - certfile=None, keyfile=None, custom_functions=[], **kwargs): + certfile=None, keyfile=None, reqclicert=False, + custom_functions=[], **kwargs): """ A factory to start and run a tls modbus server :param context: The ModbusServerContext datastore @@ -630,6 +633,7 @@ def StartTlsServer(context=None, identity=None, address=None, sslctx=None, :param sslctx: The SSLContext to use for TLS (default None and auto create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param reqclicert: Force the sever request client's certificate :param custom_functions: An optional list of custom function classes supported by server instance. :param ignore_missing_slaves: True to not send errors on a request to a @@ -637,7 +641,8 @@ def StartTlsServer(context=None, identity=None, address=None, sslctx=None, """ framer = kwargs.pop("framer", ModbusTlsFramer) server = ModbusTlsServer(context, framer, identity, address, sslctx=sslctx, - certfile=certfile, keyfile=keyfile, **kwargs) + certfile=certfile, keyfile=keyfile, + reqclicert=reqclicert, **kwargs) for f in custom_functions: server.decoder.register(f) diff --git a/test/test_server_sync.py b/test/test_server_sync.py index 526da1a5f4..9a0701894f 100644 --- a/test/test_server_sync.py +++ b/test/test_server_sync.py @@ -283,8 +283,10 @@ def testTlsServerInit(self): with patch.object(socket.socket, 'bind') as mock_socket: with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: identity = ModbusDeviceIdentification(info={0x00: 'VendorName'}) - server = ModbusTlsServer(context=None, identity=identity) + server = ModbusTlsServer(context=None, identity=identity, + reqclicert=True) self.assertIsNotNone(server.sslctx) + self.assertEqual(server.sslctx.verify_mode, ssl.CERT_REQUIRED) self.assertEqual(type(server.socket), ssl.SSLSocket) server.server_close() sslctx = ssl.create_default_context() @@ -386,6 +388,7 @@ def testStartTcpServer(self): with patch.object(socketserver.TCPServer, 'server_bind') as mock_binder: StartTcpServer() + @patch.dict(ssl.SSLContext, {'server_hostname': 'test.com'}) def testStartTlsServer(self): ''' Test the tls server starting factory ''' with patch.object(ModbusTlsServer, 'serve_forever') as mock_server: