From efd483351e7eb86f94f84478d3eddcf666bd6fc1 Mon Sep 17 00:00:00 2001 From: Jian-Hong Pan Date: Sun, 7 Mar 2021 22:23:28 +0800 Subject: [PATCH] Add server requiring client's cert in TLS handshake feature This patch adds server requiring client's certificate feature which is mentioned in the 6th step CertificateRequest to 9th step VerifyClientCertSig in Table 5 TLS Full Handshake Protocol of MODBUS/TCP Security Protocol Specification [1]. This patch implements the feature within both sync and async_io version. * Server side: Add an optional argument "reqclicert" of StartTlsServer(). So, users can force server require client's certificate for TLS full handshake, or according to the SSL Context's original behavior [2]. * Client side: Add optional arguments "certfile" and "keyfile" for replying, if the server requires client's certificate for TLS full handshake. Besides, also add an optional argument "password" on both server and client side for decrypting the private keyfile. This fixes part of https://github.com/riptideio/pymodbus/issues/606 [1]: http://modbus.org/docs/MB-TCP-Security-v21_2018-07-24.pdf [2]: https://docs.python.org/3/library/ssl.html#ssl.SSLContext.verify_mode --- examples/common/asyncio_server.py | 8 +++- examples/common/synchronous_server.py | 10 ++++- .../asynchronous_asyncio_modbus_tls_client.py | 14 ++++--- examples/contrib/modbus_tls_client.py | 14 ++++--- .../client/asynchronous/async_io/__init__.py | 37 ++++++++--------- pymodbus/client/asynchronous/factory/tls.py | 17 ++++---- pymodbus/client/asynchronous/tls.py | 15 ++++--- pymodbus/client/sync.py | 17 ++++---- pymodbus/client/tls_helper.py | 31 ++++++++++++++ pymodbus/server/async_io.py | 27 ++++++------ pymodbus/server/sync.py | 41 +++++++++---------- pymodbus/server/tls_helper.py | 32 +++++++++++++++ test/test_client_sync.py | 29 ++++++++++--- test/test_server_sync.py | 22 ++++++---- 14 files changed, 207 insertions(+), 107 deletions(-) create mode 100644 pymodbus/client/tls_helper.py create mode 100644 pymodbus/server/tls_helper.py diff --git a/examples/common/asyncio_server.py b/examples/common/asyncio_server.py index ad2c7c5cf..c9ee30419 100755 --- a/examples/common/asyncio_server.py +++ b/examples/common/asyncio_server.py @@ -131,7 +131,13 @@ async def run_server(): # Tls: # await StartTlsServer(context, identity=identity, address=("localhost", 8020), - # certfile="server.crt", keyfile="server.key", + # certfile="server.crt", keyfile="server.key", password="pwd", + # allow_reuse_address=True, allow_reuse_port=True, + # defer_start=False) + + # Tls and force require client's certificate for TLS full handshake: + # await StartTlsServer(context, identity=identity, address=("localhost", 8020), + # certfile="server.crt", keyfile="server.key", password="pwd", reqclicert=True, # allow_reuse_address=True, allow_reuse_port=True, # defer_start=False) diff --git a/examples/common/synchronous_server.py b/examples/common/synchronous_server.py index 4266fac23..a9a1f3a1a 100755 --- a/examples/common/synchronous_server.py +++ b/examples/common/synchronous_server.py @@ -120,8 +120,14 @@ def run_server(): # framer=ModbusRtuFramer, address=("0.0.0.0", 5020)) # TLS - # StartTlsServer(context, identity=identity, certfile="server.crt", - # keyfile="server.key", address=("0.0.0.0", 8020)) + # StartTlsServer(context, identity=identity, + # certfile="server.crt", keyfile="server.key", password="pwd", + # address=("0.0.0.0", 8020)) + + # Tls and force require client's certificate for TLS full handshake: + # StartTlsServer(context, identity=identity, + # certfile="server.crt", keyfile="server.key", password="pwd", reqclicert=True, + # address=("0.0.0.0", 8020)) # Udp: # StartUdpServer(context, identity=identity, address=("0.0.0.0", 5020)) diff --git a/examples/contrib/asynchronous_asyncio_modbus_tls_client.py b/examples/contrib/asynchronous_asyncio_modbus_tls_client.py index d5a973d84..9f474ac72 100755 --- a/examples/contrib/asynchronous_asyncio_modbus_tls_client.py +++ b/examples/contrib/asynchronous_asyncio_modbus_tls_client.py @@ -17,11 +17,13 @@ # -------------------------------------------------------------------------- # # the TLS detail security can be set in SSLContext which is the context here # -------------------------------------------------------------------------- # -context = ssl.create_default_context() -context.options |= ssl.OP_NO_SSLv2 -context.options |= ssl.OP_NO_SSLv3 -context.options |= ssl.OP_NO_TLSv1 -context.options |= ssl.OP_NO_TLSv1_1 +sslctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) +sslctx.verify_mode = ssl.CERT_REQUIRED +sslctx.check_hostname = True + +# Prepare client's certificate which the server requires for TLS full handshake +# sslctx.load_cert_chain(certfile="client.crt", keyfile="client.key", +# password="pwd") async def start_async_test(client): result = await client.read_coils(1, 8) @@ -35,6 +37,6 @@ async def start_async_test(client): # pass SSLContext which is the context here to ModbusTcpClient() # -------------------------------------------------------------------------- # loop, client = AsyncModbusTLSClient(ASYNC_IO, 'test.host.com', 8020, - sslctx=context) + sslctx=sslctx) loop.run_until_complete(start_async_test(client.protocol)) loop.close() diff --git a/examples/contrib/modbus_tls_client.py b/examples/contrib/modbus_tls_client.py index 98ad02a12..d49299932 100755 --- a/examples/contrib/modbus_tls_client.py +++ b/examples/contrib/modbus_tls_client.py @@ -16,16 +16,18 @@ # -------------------------------------------------------------------------- # # the TLS detail security can be set in SSLContext which is the context here # -------------------------------------------------------------------------- # -context = ssl.create_default_context() -context.options |= ssl.OP_NO_SSLv2 -context.options |= ssl.OP_NO_SSLv3 -context.options |= ssl.OP_NO_TLSv1 -context.options |= ssl.OP_NO_TLSv1_1 +sslctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) +sslctx.verify_mode = ssl.CERT_REQUIRED +sslctx.check_hostname = True + +# Prepare client's certificate which the server requires for TLS full handshake +# sslctx.load_cert_chain(certfile="client.crt", keyfile="client.key", +# password="pwd") # -------------------------------------------------------------------------- # # pass SSLContext which is the context here to ModbusTcpClient() # -------------------------------------------------------------------------- # -client = ModbusTlsClient('test.host.com', 8020, sslctx=context) +client = ModbusTlsClient('test.host.com', 8020, sslctx=sslctx) client.connect() result = client.read_coils(1, 8) diff --git a/pymodbus/client/asynchronous/async_io/__init__.py b/pymodbus/client/asynchronous/async_io/__init__.py index 552202a3f..decd86c4d 100644 --- a/pymodbus/client/asynchronous/async_io/__init__.py +++ b/pymodbus/client/asynchronous/async_io/__init__.py @@ -7,6 +7,7 @@ import ssl from pymodbus.exceptions import ConnectionException from pymodbus.client.asynchronous.mixins import AsyncModbusClientMixin +from pymodbus.client.tls_helper import sslctx_provider from pymodbus.utilities import hexlify_packets from pymodbus.transaction import FifoTransactionManager import logging @@ -449,25 +450,18 @@ def __init__(self, protocol_class=None, loop=None, framer=None): ReconnectingAsyncioModbusTcpClient.__init__(self, protocol_class, loop) @asyncio.coroutine - def start(self, host, port=802, sslctx=None, server_hostname=None): + def start(self, host='localhost', port=802, sslctx=None, + certfile=None, keyfile=None, password=None): """ Initiates connection to start client - :param host: - :param port: - :param sslctx: - :param server_hostname: - :return: + :param host: The host to connect to (default localhost) + :param port: Port to connect + :param sslctx:The SSLContext to use for TLS (default None and auto create) + :param certfile: The optional client's cert file path for TLS server request + :param keyfile: The optional client's key file path for TLS server request + :param password: The password for for decrypting client's private key file """ - self.sslctx = sslctx - if self.sslctx is None: - self.sslctx = ssl.create_default_context() - # According to MODBUS/TCP Security Protocol Specification, it is - # TLSv2 at least - self.sslctx.options |= ssl.OP_NO_TLSv1_1 - self.sslctx.options |= ssl.OP_NO_TLSv1 - self.sslctx.options |= ssl.OP_NO_SSLv3 - self.sslctx.options |= ssl.OP_NO_SSLv2 - self.server_hostname = server_hostname + self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password) yield from ReconnectingAsyncioModbusTcpClient.start(self, host, port) @asyncio.coroutine @@ -478,7 +472,7 @@ def _connect(self): self.host, self.port, ssl=self.sslctx, - server_hostname=self.server_hostname) + server_hostname=self.host) except Exception as ex: _logger.warning('Failed to connect: %s' % ex) asyncio.ensure_future(self._reconnect(), loop=self.loop) @@ -849,7 +843,8 @@ def init_tcp_client(proto_cls, loop, host, port, **kwargs): @asyncio.coroutine def init_tls_client(proto_cls, loop, host, port, sslctx=None, - server_hostname=None, framer=None, **kwargs): + certfile=None, keyfile=None, password=None, + framer=None,**kwargs): """ Helper function to initialize tcp client :param proto_cls: @@ -857,14 +852,16 @@ def init_tls_client(proto_cls, loop, host, port, sslctx=None, :param host: :param port: :param sslctx: - :param server_hostname: + :param certfile: The optional client's cert file path for TLS server request + :param keyfile: The optional client's key file path for TLS server request + :param password: The password for for decrypting client's private key file :param framer: :param kwargs: :return: """ client = ReconnectingAsyncioModbusTlsClient(protocol_class=proto_cls, loop=loop, framer=framer) - yield from client.start(host, port, sslctx, server_hostname) + yield from client.start(host, port, sslctx, certfile, keyfile, password) return client diff --git a/pymodbus/client/asynchronous/factory/tls.py b/pymodbus/client/asynchronous/factory/tls.py index 3b11ebf5a..86356268b 100644 --- a/pymodbus/client/asynchronous/factory/tls.py +++ b/pymodbus/client/asynchronous/factory/tls.py @@ -13,14 +13,17 @@ LOGGER = logging.getLogger(__name__) def async_io_factory(host="127.0.0.1", port=Defaults.TLSPort, sslctx=None, - server_hostname=None, framer=None, source_address=None, + certfile=None, keyfile=None, password=None, + framer=None, source_address=None, timeout=None, **kwargs): """ Factory to create asyncio based asynchronous tls clients - :param host: Host IP address + :param host: Target server's name, also matched for certificate :param port: Port :param sslctx: The SSLContext to use for TLS (default None and auto create) - :param server_hostname: Target server's name matched for certificate + :param certfile: The optional client's cert file path for TLS server request + :param keyfile: The optional client's key file path for TLS server request + :param password: The password for for decrypting client's private key file :param framer: Modbus Framer :param source_address: Bind address :param timeout: Timeout in seconds @@ -33,12 +36,12 @@ def async_io_factory(host="127.0.0.1", port=Defaults.TLSPort, sslctx=None, proto_cls = kwargs.get("proto_cls", None) if not loop.is_running(): asyncio.set_event_loop(loop) - cor = init_tls_client(proto_cls, loop, host, port, sslctx, server_hostname, - framer) + cor = init_tls_client(proto_cls, loop, host, port, + sslctx, certfile, keyfile, password, framer) client = loop.run_until_complete(asyncio.gather(cor))[0] else: - cor = init_tls_client(proto_cls, loop, host, port, sslctx, server_hostname, - framer) + cor = init_tls_client(proto_cls, loop, host, port, + sslctx, certfile, keyfile, password, framer) future = asyncio.run_coroutine_threadsafe(cor, loop=loop) client = future.result() diff --git a/pymodbus/client/asynchronous/tls.py b/pymodbus/client/asynchronous/tls.py index d2412ff54..83eab9f79 100644 --- a/pymodbus/client/asynchronous/tls.py +++ b/pymodbus/client/asynchronous/tls.py @@ -21,17 +21,19 @@ class AsyncModbusTLSClient(object): from pymodbus.client.asynchronous.tls import AsyncModbusTLSClient """ def __new__(cls, scheduler, host="127.0.0.1", port=Defaults.TLSPort, - framer=None, sslctx=None, server_hostname=None, - source_address=None, timeout=None, **kwargs): + framer=None, sslctx=None, certfile=None, keyfile=None, + password=None, source_address=None, timeout=None, **kwargs): """ Scheduler to use: - async_io (asyncio) :param scheduler: Backend to use - :param host: Host IP address + :param host: Target server's name, also matched for certificate :param port: Port :param framer: Modbus Framer to use :param sslctx: The SSLContext to use for TLS (default None and auto create) - :param server_hostname: Target server's name matched for certificate + :param certfile: The optional client's cert file path for TLS server request + :param keyfile: The optional client's key file path for TLS server request + :param password: The password for for decrypting client's private key file :param source_address: source address specific to underlying backend :param timeout: Time out in seconds :param kwargs: Other extra args specific to Backend being used @@ -45,8 +47,9 @@ def __new__(cls, scheduler, host="127.0.0.1", port=Defaults.TLSPort, framer = framer or ModbusTlsFramer(ClientDecoder()) factory_class = get_factory(scheduler) yieldable = factory_class(host=host, port=port, sslctx=sslctx, - server_hostname=server_hostname, - framer=framer, source_address=source_address, + certfile=certfile, keyfile=keyfile, + password=password, framer=framer, + source_address=source_address, timeout=timeout, **kwargs) return yieldable diff --git a/pymodbus/client/sync.py b/pymodbus/client/sync.py index 3d1fb6d6f..c6a1491b9 100644 --- a/pymodbus/client/sync.py +++ b/pymodbus/client/sync.py @@ -16,6 +16,7 @@ from pymodbus.transaction import ModbusAsciiFramer, ModbusRtuFramer from pymodbus.transaction import ModbusTlsFramer from pymodbus.client.common import ModbusClientMixin +from pymodbus.client.tls_helper import sslctx_provider # --------------------------------------------------------------------------- # # Logging @@ -368,27 +369,23 @@ class ModbusTlsClient(ModbusTcpClient): """ def __init__(self, host='localhost', port=Defaults.TLSPort, sslctx=None, - framer=ModbusTlsFramer, **kwargs): + certfile=None, keyfile=None, password=None, framer=ModbusTlsFramer, + **kwargs): """ Initialize a client instance :param host: The host to connect to (default localhost) :param port: The modbus port to connect to (default 802) :param sslctx: The SSLContext to use for TLS (default None and auto create) + :param certfile: The optional client's cert file path for TLS server request + :param keyfile: The optional client's key file path for TLS server request + :param password: The password for for decrypting client's private key file :param source_address: The source address tuple to bind to (default ('', 0)) :param timeout: The timeout to use for this socket (default Defaults.Timeout) :param framer: The modbus framer to use (default ModbusSocketFramer) .. note:: The host argument will accept ipv4 and ipv6 hosts """ - self.sslctx = sslctx - if self.sslctx is None: - self.sslctx = ssl.create_default_context() - # According to MODBUS/TCP Security Protocol Specification, it is - # TLSv2 at least - self.sslctx.options |= ssl.OP_NO_TLSv1_1 - self.sslctx.options |= ssl.OP_NO_TLSv1 - self.sslctx.options |= ssl.OP_NO_SSLv3 - self.sslctx.options |= ssl.OP_NO_SSLv2 + self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password) ModbusTcpClient.__init__(self, host, port, framer, **kwargs) def connect(self): diff --git a/pymodbus/client/tls_helper.py b/pymodbus/client/tls_helper.py new file mode 100644 index 000000000..71c06e884 --- /dev/null +++ b/pymodbus/client/tls_helper.py @@ -0,0 +1,31 @@ +""" +TLS helper for Modbus TLS Client +------------------------------------------ + +""" +import ssl + +def sslctx_provider(sslctx=None, certfile=None, keyfile=None, password=None): + """ Provide the SSLContext for ModbusTlsClient + + If the user defined SSLContext is not passed in, sslctx_provider will + produce a default one. + + :param sslctx: The user defined SSLContext to use for TLS (default None and + auto create) + :param certfile: The optional client's cert file path for TLS server request + :param keyfile: The optional client's key file path for TLS server request + :param password: The password for for decrypting client's private key file + """ + if sslctx is None: + # According to MODBUS/TCP Security Protocol Specification, it is + # TLSv2 at least + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + sslctx.verify_mode = ssl.CERT_REQUIRED + sslctx.check_hostname = True + + if certfile and keyfile: + sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile, + password=password) + + return sslctx diff --git a/pymodbus/server/async_io.py b/pymodbus/server/async_io.py index a85b37a78..68b9dca15 100755 --- a/pymodbus/server/async_io.py +++ b/pymodbus/server/async_io.py @@ -21,6 +21,7 @@ from pymodbus.exceptions import NotImplementedException, NoSuchSlaveException from pymodbus.pdu import ModbusExceptions as merror from pymodbus.compat import socketserver, byte2int +from pymodbus.server.tls_helper import sslctx_provider # --------------------------------------------------------------------------- # # Logging @@ -524,6 +525,8 @@ def __init__(self, sslctx=None, certfile=None, keyfile=None, + password=None, + reqclicert=False, handler=None, allow_reuse_address=False, allow_reuse_port=False, @@ -544,6 +547,8 @@ def __init__(self, create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param password: The password for for decrypting the private key file + :param reqclicert: Force the sever request client's certificate :param handler: A handler for each client session; default is ModbusConnectedRequestHandler. The handler class receives connection create/teardown events @@ -582,18 +587,9 @@ def __init__(self, if isinstance(identity, ModbusDeviceIdentification): self.control.Identity.update(identity) - self.sslctx = sslctx - if self.sslctx is None: - self.sslctx = ssl.create_default_context() - self.sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile) - # According to MODBUS/TCP Security Protocol Specification, it is - # TLSv2 at least - self.sslctx.options |= ssl.OP_NO_TLSv1_1 - self.sslctx.options |= ssl.OP_NO_TLSv1 - self.sslctx.options |= ssl.OP_NO_SSLv3 - self.sslctx.options |= ssl.OP_NO_SSLv2 - self.sslctx.verify_mode = ssl.CERT_OPTIONAL - self.sslctx.check_hostname = False + self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password, + reqclicert) + # asyncio future that will be done once server has started self.serving = self.loop.create_future() # constructors cannot be declared async, so we have to @@ -845,7 +841,8 @@ async def StartTcpServer(context=None, identity=None, address=None, async def StartTlsServer(context=None, identity=None, address=None, sslctx=None, - certfile=None, keyfile=None, + certfile=None, keyfile=None, password=None, + reqclicert=False, allow_reuse_address=False, allow_reuse_port=False, custom_functions=[], @@ -858,6 +855,8 @@ async def StartTlsServer(context=None, identity=None, address=None, :param sslctx: The SSLContext to use for TLS (default None and auto create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param password: The password for for decrypting the private key file + :param reqclicert: Force the sever request client's certificate :param allow_reuse_address: Whether the server will allow the reuse of an address. :param allow_reuse_port: Whether the server will allow the reuse of a port. @@ -872,7 +871,7 @@ async def StartTlsServer(context=None, identity=None, address=None, """ framer = kwargs.pop("framer", ModbusTlsFramer) server = ModbusTlsServer(context, framer, identity, address, sslctx, - certfile, keyfile, + certfile, keyfile, password, reqclicert, allow_reuse_address=allow_reuse_address, allow_reuse_port=allow_reuse_port, **kwargs) diff --git a/pymodbus/server/sync.py b/pymodbus/server/sync.py index 041069aab..5945a5185 100644 --- a/pymodbus/server/sync.py +++ b/pymodbus/server/sync.py @@ -19,6 +19,7 @@ from pymodbus.exceptions import NotImplementedException, NoSuchSlaveException from pymodbus.pdu import ModbusExceptions as merror from pymodbus.compat import socketserver, byte2int +from pymodbus.server.tls_helper import sslctx_provider # --------------------------------------------------------------------------- # # Logging @@ -375,9 +376,10 @@ class ModbusTlsServer(ModbusTcpServer): server context instance. """ - def __init__(self, context, framer=None, identity=None, - address=None, handler=None, allow_reuse_address=False, - sslctx=None, certfile=None, keyfile=None, **kwargs): + def __init__(self, context, framer=None, identity=None, address=None, + sslctx=None, certfile=None, keyfile=None, password=None, + reqclicert=False, handler=None, allow_reuse_address=False, + **kwargs): """ Overloaded initializer for the ModbusTcpServer If the identify structure is not passed in, the ModbusControlBlock @@ -387,31 +389,23 @@ def __init__(self, context, framer=None, identity=None, :param framer: The framer strategy to use :param identity: An optional identify structure :param address: An optional (interface, port) to bind to. - :param handler: A handler for each client session; default is - ModbusConnectedRequestHandler - :param allow_reuse_address: Whether the server will allow the - reuse of an address. :param sslctx: The SSLContext to use for TLS (default None and auto create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param password: The password for for decrypting the private key file + :param reqclicert: Force the sever request client's certificate + :param handler: A handler for each client session; default is + ModbusConnectedRequestHandler + :param allow_reuse_address: Whether the server will allow the + reuse of an address. :param ignore_missing_slaves: True to not send errors on a request to a missing slave :param broadcast_enable: True to treat unit_id 0 as broadcast address, False to treat 0 as any other unit_id """ - self.sslctx = sslctx - if self.sslctx is None: - self.sslctx = ssl.create_default_context() - self.sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile) - # According to MODBUS/TCP Security Protocol Specification, it is - # TLSv2 at least - self.sslctx.options |= ssl.OP_NO_TLSv1_1 - self.sslctx.options |= ssl.OP_NO_TLSv1 - self.sslctx.options |= ssl.OP_NO_SSLv3 - self.sslctx.options |= ssl.OP_NO_SSLv2 - self.sslctx.verify_mode = ssl.CERT_OPTIONAL - self.sslctx.check_hostname = False + self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password, + reqclicert) ModbusTcpServer.__init__(self, context, framer, identity, address, handler, allow_reuse_address, **kwargs) @@ -625,7 +619,8 @@ def StartTcpServer(context=None, identity=None, address=None, def StartTlsServer(context=None, identity=None, address=None, sslctx=None, - certfile=None, keyfile=None, custom_functions=[], **kwargs): + certfile=None, keyfile=None, password=None, reqclicert=False, + custom_functions=[], **kwargs): """ A factory to start and run a tls modbus server :param context: The ModbusServerContext datastore @@ -634,14 +629,16 @@ def StartTlsServer(context=None, identity=None, address=None, sslctx=None, :param sslctx: The SSLContext to use for TLS (default None and auto create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param password: The password for for decrypting the private key file + :param reqclicert: Force the sever request client's certificate :param custom_functions: An optional list of custom function classes supported by server instance. :param ignore_missing_slaves: True to not send errors on a request to a missing slave """ framer = kwargs.pop("framer", ModbusTlsFramer) - server = ModbusTlsServer(context, framer, identity, address, sslctx=sslctx, - certfile=certfile, keyfile=keyfile, **kwargs) + server = ModbusTlsServer(context, framer, identity, address, sslctx, + certfile, keyfile, password, reqclicert, **kwargs) for f in custom_functions: server.decoder.register(f) diff --git a/pymodbus/server/tls_helper.py b/pymodbus/server/tls_helper.py new file mode 100644 index 000000000..a806fbbd5 --- /dev/null +++ b/pymodbus/server/tls_helper.py @@ -0,0 +1,32 @@ +""" +TLS helper for Modbus TLS Server +------------------------------------------ + +""" +import ssl + +def sslctx_provider(sslctx=None, certfile=None, keyfile=None, password=None, + reqclicert=False): + """ Provide the SSLContext for ModbusTlsServer + + If the user defined SSLContext is not passed in, sslctx_provider will + produce a default one. + + :param sslctx: The user defined SSLContext to use for TLS (default None and + auto create) + :param certfile: The cert file path for TLS (used if sslctx is None) + :param keyfile: The key file path for TLS (used if sslctx is None) + :param password: The password for for decrypting the private key file + :param reqclicert: Force the sever request client's certificate + """ + if sslctx is None: + # According to MODBUS/TCP Security Protocol Specification, it is + # TLSv2 at least + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile, + password=password) + + if reqclicert: + sslctx.verify_mode = ssl.CERT_REQUIRED + + return sslctx diff --git a/test/test_client_sync.py b/test/test_client_sync.py index 30456a58d..f9cc4426d 100755 --- a/test/test_client_sync.py +++ b/test/test_client_sync.py @@ -18,6 +18,7 @@ from pymodbus.client.sync import ModbusTcpClient, ModbusUdpClient from pymodbus.client.sync import ModbusSerialClient, BaseModbusClient from pymodbus.client.sync import ModbusTlsClient +from pymodbus.client.tls_helper import sslctx_provider from pymodbus.exceptions import ConnectionException, NotImplementedException from pymodbus.exceptions import ParameterException from pymodbus.transaction import ModbusAsciiFramer, ModbusRtuFramer @@ -311,18 +312,34 @@ class CustomeRequest: # Test TLS Client # -----------------------------------------------------------------------# + def testTlsSSLCTX_Provider(self): + ''' test that sslctx_provider() produce SSLContext correctly ''' + with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: + sslctx1 = sslctx_provider(certfile="cert.pem") + self.assertIsNotNone(sslctx1) + self.assertEqual(type(sslctx1), ssl.SSLContext) + self.assertEqual(mock_method.called, False) + + sslctx2 = sslctx_provider(keyfile="key.pem") + self.assertIsNotNone(sslctx2) + self.assertEqual(type(sslctx2), ssl.SSLContext) + self.assertEqual(mock_method.called, False) + + sslctx3 = sslctx_provider(certfile="cert.pem", keyfile="key.pem") + self.assertIsNotNone(sslctx3) + self.assertEqual(type(sslctx3), ssl.SSLContext) + self.assertEqual(mock_method.called, True) + + sslctx_old = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + sslctx_new = sslctx_provider(sslctx=sslctx_old) + self.assertEqual(sslctx_new, sslctx_old) + def testSyncTlsClientInstantiation(self): # default SSLContext client = ModbusTlsClient() self.assertNotEqual(client, None) self.assertTrue(client.sslctx) - # user defined SSLContext - context = ssl.create_default_context() - client = ModbusTlsClient(sslctx=context) - self.assertNotEqual(client, None) - self.assertEqual(client.sslctx, context) - def testBasicSyncTlsClient(self): ''' Test the basic methods for the tls sync client''' diff --git a/test/test_server_sync.py b/test/test_server_sync.py index 8b37b9ac2..a0dc97636 100644 --- a/test/test_server_sync.py +++ b/test/test_server_sync.py @@ -16,6 +16,7 @@ from pymodbus.server.sync import ModbusDisconnectedRequestHandler from pymodbus.server.sync import ModbusTcpServer, ModbusTlsServer, ModbusUdpServer, ModbusSerialServer from pymodbus.server.sync import StartTcpServer, StartTlsServer, StartUdpServer, StartSerialServer +from pymodbus.server.tls_helper import sslctx_provider from pymodbus.exceptions import NotImplementedException from pymodbus.bit_read_message import ReadCoilsRequest, ReadCoilsResponse from pymodbus.datastore import ModbusServerContext @@ -277,22 +278,29 @@ def testTcpServerProcess(self): #-----------------------------------------------------------------------# # Test TLS Server #-----------------------------------------------------------------------# + def testTlsSSLCTX_Provider(self): + ''' test that sslctx_provider() produce SSLContext correctly ''' + with patch.object(ssl.SSLContext, 'load_cert_chain'): + sslctx = sslctx_provider(reqclicert=True) + self.assertIsNotNone(sslctx) + self.assertEqual(type(sslctx), ssl.SSLContext) + self.assertEqual(sslctx.verify_mode, ssl.CERT_REQUIRED) + + sslctx_old = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + sslctx_new = sslctx_provider(sslctx=sslctx_old) + self.assertEqual(sslctx_new, sslctx_old) + def testTlsServerInit(self): ''' test that the synchronous TLS server intial correctly ''' with patch.object(socketserver.TCPServer, 'server_activate'): with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: identity = ModbusDeviceIdentification(info={0x00: 'VendorName'}) server = ModbusTlsServer(context=None, identity=identity, + reqclicert=True, bind_and_activate=False) server.server_activate() self.assertIsNotNone(server.sslctx) - self.assertEqual(type(server.socket), ssl.SSLSocket) - server.server_close() - sslctx = ssl.create_default_context() - server = ModbusTlsServer(context=None, identity=identity, - sslctx=sslctx, bind_and_activate=False) - server.server_activate() - self.assertEqual(server.sslctx, sslctx) + self.assertEqual(server.sslctx.verify_mode, ssl.CERT_REQUIRED) self.assertEqual(type(server.socket), ssl.SSLSocket) server.server_close()