diff --git a/pymodbus/server/async_io.py b/pymodbus/server/async_io.py index c4a7e28369..04b4f5afa0 100755 --- a/pymodbus/server/async_io.py +++ b/pymodbus/server/async_io.py @@ -517,6 +517,7 @@ def __init__(self, sslctx=None, certfile=None, keyfile=None, + reqclicert=False, handler=None, allow_reuse_address=False, allow_reuse_port=False, @@ -537,6 +538,7 @@ def __init__(self, create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param reqclicert: Force the sever request client's certificate :param handler: A handler for each client session; default is ModbusConnectedRequestHandler. The handler class receives connection create/teardown events @@ -585,8 +587,8 @@ def __init__(self, self.sslctx.options |= ssl.OP_NO_TLSv1 self.sslctx.options |= ssl.OP_NO_SSLv3 self.sslctx.options |= ssl.OP_NO_SSLv2 - self.sslctx.verify_mode = ssl.CERT_OPTIONAL - self.sslctx.check_hostname = False + if reqclicert: + self.sslctx.verify_mode = ssl.CERT_REQUIRED # asyncio future that will be done once server has started self.serving = self.loop.create_future() # constructors cannot be declared async, so we have to @@ -838,7 +840,7 @@ async def StartTcpServer(context=None, identity=None, address=None, async def StartTlsServer(context=None, identity=None, address=None, sslctx=None, - certfile=None, keyfile=None, + certfile=None, keyfile=None, reqclicert=False, allow_reuse_address=False, allow_reuse_port=False, custom_functions=[], @@ -851,6 +853,7 @@ async def StartTlsServer(context=None, identity=None, address=None, :param sslctx: The SSLContext to use for TLS (default None and auto create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param reqclicert: Force the sever request client's certificate :param allow_reuse_address: Whether the server will allow the reuse of an address. :param allow_reuse_port: Whether the server will allow the reuse of a port. @@ -865,7 +868,7 @@ async def StartTlsServer(context=None, identity=None, address=None, """ framer = kwargs.pop("framer", ModbusTlsFramer) server = ModbusTlsServer(context, framer, identity, address, sslctx, - certfile, keyfile, + certfile, keyfile, reqclicert=reqclicert, allow_reuse_address=allow_reuse_address, allow_reuse_port=allow_reuse_port, **kwargs) diff --git a/pymodbus/server/sync.py b/pymodbus/server/sync.py index c504ff0135..1ec4a5b881 100644 --- a/pymodbus/server/sync.py +++ b/pymodbus/server/sync.py @@ -376,7 +376,8 @@ class ModbusTlsServer(ModbusTcpServer): def __init__(self, context, framer=None, identity=None, address=None, handler=None, allow_reuse_address=False, - sslctx=None, certfile=None, keyfile=None, **kwargs): + sslctx=None, certfile=None, keyfile=None, reqclicert=False, + **kwargs): """ Overloaded initializer for the ModbusTcpServer If the identify structure is not passed in, the ModbusControlBlock @@ -394,6 +395,7 @@ def __init__(self, context, framer=None, identity=None, create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param reqclicert: Force the sever request client's certificate :param ignore_missing_slaves: True to not send errors on a request to a missing slave :param broadcast_enable: True to treat unit_id 0 as broadcast address, @@ -409,8 +411,8 @@ def __init__(self, context, framer=None, identity=None, self.sslctx.options |= ssl.OP_NO_TLSv1 self.sslctx.options |= ssl.OP_NO_SSLv3 self.sslctx.options |= ssl.OP_NO_SSLv2 - self.sslctx.verify_mode = ssl.CERT_OPTIONAL - self.sslctx.check_hostname = False + if reqclicert: + self.sslctx.verify_mode = ssl.CERT_REQUIRED ModbusTcpServer.__init__(self, context, framer, identity, address, handler, allow_reuse_address, **kwargs) @@ -621,7 +623,8 @@ def StartTcpServer(context=None, identity=None, address=None, def StartTlsServer(context=None, identity=None, address=None, sslctx=None, - certfile=None, keyfile=None, custom_functions=[], **kwargs): + certfile=None, keyfile=None, reqclicert=False, + custom_functions=[], **kwargs): """ A factory to start and run a tls modbus server :param context: The ModbusServerContext datastore @@ -630,6 +633,7 @@ def StartTlsServer(context=None, identity=None, address=None, sslctx=None, :param sslctx: The SSLContext to use for TLS (default None and auto create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param reqclicert: Force the sever request client's certificate :param custom_functions: An optional list of custom function classes supported by server instance. :param ignore_missing_slaves: True to not send errors on a request to a @@ -637,7 +641,8 @@ def StartTlsServer(context=None, identity=None, address=None, sslctx=None, """ framer = kwargs.pop("framer", ModbusTlsFramer) server = ModbusTlsServer(context, framer, identity, address, sslctx=sslctx, - certfile=certfile, keyfile=keyfile, **kwargs) + certfile=certfile, keyfile=keyfile, + reqclicert=reqclicert, **kwargs) for f in custom_functions: server.decoder.register(f) diff --git a/test/test_server_sync.py b/test/test_server_sync.py index 526da1a5f4..9a0701894f 100644 --- a/test/test_server_sync.py +++ b/test/test_server_sync.py @@ -283,8 +283,10 @@ def testTlsServerInit(self): with patch.object(socket.socket, 'bind') as mock_socket: with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: identity = ModbusDeviceIdentification(info={0x00: 'VendorName'}) - server = ModbusTlsServer(context=None, identity=identity) + server = ModbusTlsServer(context=None, identity=identity, + reqclicert=True) self.assertIsNotNone(server.sslctx) + self.assertEqual(server.sslctx.verify_mode, ssl.CERT_REQUIRED) self.assertEqual(type(server.socket), ssl.SSLSocket) server.server_close() sslctx = ssl.create_default_context() @@ -386,6 +388,7 @@ def testStartTcpServer(self): with patch.object(socketserver.TCPServer, 'server_bind') as mock_binder: StartTcpServer() + @patch.dict(ssl.SSLContext, {'server_hostname': 'test.com'}) def testStartTlsServer(self): ''' Test the tls server starting factory ''' with patch.object(ModbusTlsServer, 'serve_forever') as mock_server: