diff --git a/examples/common/asyncio_server.py b/examples/common/asyncio_server.py index ad2c7c5cf1..c9ee304194 100755 --- a/examples/common/asyncio_server.py +++ b/examples/common/asyncio_server.py @@ -131,7 +131,13 @@ async def run_server(): # Tls: # await StartTlsServer(context, identity=identity, address=("localhost", 8020), - # certfile="server.crt", keyfile="server.key", + # certfile="server.crt", keyfile="server.key", password="pwd", + # allow_reuse_address=True, allow_reuse_port=True, + # defer_start=False) + + # Tls and force require client's certificate for TLS full handshake: + # await StartTlsServer(context, identity=identity, address=("localhost", 8020), + # certfile="server.crt", keyfile="server.key", password="pwd", reqclicert=True, # allow_reuse_address=True, allow_reuse_port=True, # defer_start=False) diff --git a/examples/common/synchronous_server.py b/examples/common/synchronous_server.py index 4266fac233..a9a1f3a1ac 100755 --- a/examples/common/synchronous_server.py +++ b/examples/common/synchronous_server.py @@ -120,8 +120,14 @@ def run_server(): # framer=ModbusRtuFramer, address=("0.0.0.0", 5020)) # TLS - # StartTlsServer(context, identity=identity, certfile="server.crt", - # keyfile="server.key", address=("0.0.0.0", 8020)) + # StartTlsServer(context, identity=identity, + # certfile="server.crt", keyfile="server.key", password="pwd", + # address=("0.0.0.0", 8020)) + + # Tls and force require client's certificate for TLS full handshake: + # StartTlsServer(context, identity=identity, + # certfile="server.crt", keyfile="server.key", password="pwd", reqclicert=True, + # address=("0.0.0.0", 8020)) # Udp: # StartUdpServer(context, identity=identity, address=("0.0.0.0", 5020)) diff --git a/examples/contrib/asynchronous_asyncio_modbus_tls_client.py b/examples/contrib/asynchronous_asyncio_modbus_tls_client.py index d5a973d84b..ec70620dbc 100755 --- a/examples/contrib/asynchronous_asyncio_modbus_tls_client.py +++ b/examples/contrib/asynchronous_asyncio_modbus_tls_client.py @@ -17,11 +17,17 @@ # -------------------------------------------------------------------------- # # the TLS detail security can be set in SSLContext which is the context here # -------------------------------------------------------------------------- # -context = ssl.create_default_context() -context.options |= ssl.OP_NO_SSLv2 -context.options |= ssl.OP_NO_SSLv3 -context.options |= ssl.OP_NO_TLSv1 -context.options |= ssl.OP_NO_TLSv1_1 +sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS) +sslctx.options |= ssl.OP_NO_SSLv2 +sslctx.options |= ssl.OP_NO_SSLv3 +sslctx.options |= ssl.OP_NO_TLSv1 +sslctx.options |= ssl.OP_NO_TLSv1_1 +sslctx.verify_mode = ssl.CERT_REQUIRED +sslctx.check_hostname = True + +# Prepare client's certificate which the server requires for TLS full handshake +# sslctx.load_cert_chain(certfile="client.crt", keyfile="client.key", +# password="pwd") async def start_async_test(client): result = await client.read_coils(1, 8) @@ -35,6 +41,6 @@ async def start_async_test(client): # pass SSLContext which is the context here to ModbusTcpClient() # -------------------------------------------------------------------------- # loop, client = AsyncModbusTLSClient(ASYNC_IO, 'test.host.com', 8020, - sslctx=context) + sslctx=sslctx) loop.run_until_complete(start_async_test(client.protocol)) loop.close() diff --git a/examples/contrib/modbus_tls_client.py b/examples/contrib/modbus_tls_client.py index 98ad02a12a..5fa9ac8cee 100755 --- a/examples/contrib/modbus_tls_client.py +++ b/examples/contrib/modbus_tls_client.py @@ -16,16 +16,22 @@ # -------------------------------------------------------------------------- # # the TLS detail security can be set in SSLContext which is the context here # -------------------------------------------------------------------------- # -context = ssl.create_default_context() -context.options |= ssl.OP_NO_SSLv2 -context.options |= ssl.OP_NO_SSLv3 -context.options |= ssl.OP_NO_TLSv1 -context.options |= ssl.OP_NO_TLSv1_1 +sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS) +sslctx.options |= ssl.OP_NO_SSLv2 +sslctx.options |= ssl.OP_NO_SSLv3 +sslctx.options |= ssl.OP_NO_TLSv1 +sslctx.options |= ssl.OP_NO_TLSv1_1 +sslctx.verify_mode = ssl.CERT_REQUIRED +sslctx.check_hostname = True + +# Prepare client's certificate which the server requires for TLS full handshake +# sslctx.load_cert_chain(certfile="client.crt", keyfile="client.key", +# password="pwd") # -------------------------------------------------------------------------- # # pass SSLContext which is the context here to ModbusTcpClient() # -------------------------------------------------------------------------- # -client = ModbusTlsClient('test.host.com', 8020, sslctx=context) +client = ModbusTlsClient('test.host.com', 8020, sslctx=sslctx) client.connect() result = client.read_coils(1, 8) diff --git a/pymodbus/client/asynchronous/async_io/__init__.py b/pymodbus/client/asynchronous/async_io/__init__.py index 552202a3fa..decd86c4d0 100644 --- a/pymodbus/client/asynchronous/async_io/__init__.py +++ b/pymodbus/client/asynchronous/async_io/__init__.py @@ -7,6 +7,7 @@ import ssl from pymodbus.exceptions import ConnectionException from pymodbus.client.asynchronous.mixins import AsyncModbusClientMixin +from pymodbus.client.tls_helper import sslctx_provider from pymodbus.utilities import hexlify_packets from pymodbus.transaction import FifoTransactionManager import logging @@ -449,25 +450,18 @@ def __init__(self, protocol_class=None, loop=None, framer=None): ReconnectingAsyncioModbusTcpClient.__init__(self, protocol_class, loop) @asyncio.coroutine - def start(self, host, port=802, sslctx=None, server_hostname=None): + def start(self, host='localhost', port=802, sslctx=None, + certfile=None, keyfile=None, password=None): """ Initiates connection to start client - :param host: - :param port: - :param sslctx: - :param server_hostname: - :return: + :param host: The host to connect to (default localhost) + :param port: Port to connect + :param sslctx:The SSLContext to use for TLS (default None and auto create) + :param certfile: The optional client's cert file path for TLS server request + :param keyfile: The optional client's key file path for TLS server request + :param password: The password for for decrypting client's private key file """ - self.sslctx = sslctx - if self.sslctx is None: - self.sslctx = ssl.create_default_context() - # According to MODBUS/TCP Security Protocol Specification, it is - # TLSv2 at least - self.sslctx.options |= ssl.OP_NO_TLSv1_1 - self.sslctx.options |= ssl.OP_NO_TLSv1 - self.sslctx.options |= ssl.OP_NO_SSLv3 - self.sslctx.options |= ssl.OP_NO_SSLv2 - self.server_hostname = server_hostname + self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password) yield from ReconnectingAsyncioModbusTcpClient.start(self, host, port) @asyncio.coroutine @@ -478,7 +472,7 @@ def _connect(self): self.host, self.port, ssl=self.sslctx, - server_hostname=self.server_hostname) + server_hostname=self.host) except Exception as ex: _logger.warning('Failed to connect: %s' % ex) asyncio.ensure_future(self._reconnect(), loop=self.loop) @@ -849,7 +843,8 @@ def init_tcp_client(proto_cls, loop, host, port, **kwargs): @asyncio.coroutine def init_tls_client(proto_cls, loop, host, port, sslctx=None, - server_hostname=None, framer=None, **kwargs): + certfile=None, keyfile=None, password=None, + framer=None,**kwargs): """ Helper function to initialize tcp client :param proto_cls: @@ -857,14 +852,16 @@ def init_tls_client(proto_cls, loop, host, port, sslctx=None, :param host: :param port: :param sslctx: - :param server_hostname: + :param certfile: The optional client's cert file path for TLS server request + :param keyfile: The optional client's key file path for TLS server request + :param password: The password for for decrypting client's private key file :param framer: :param kwargs: :return: """ client = ReconnectingAsyncioModbusTlsClient(protocol_class=proto_cls, loop=loop, framer=framer) - yield from client.start(host, port, sslctx, server_hostname) + yield from client.start(host, port, sslctx, certfile, keyfile, password) return client diff --git a/pymodbus/client/asynchronous/factory/tls.py b/pymodbus/client/asynchronous/factory/tls.py index 3b11ebf5aa..86356268b5 100644 --- a/pymodbus/client/asynchronous/factory/tls.py +++ b/pymodbus/client/asynchronous/factory/tls.py @@ -13,14 +13,17 @@ LOGGER = logging.getLogger(__name__) def async_io_factory(host="127.0.0.1", port=Defaults.TLSPort, sslctx=None, - server_hostname=None, framer=None, source_address=None, + certfile=None, keyfile=None, password=None, + framer=None, source_address=None, timeout=None, **kwargs): """ Factory to create asyncio based asynchronous tls clients - :param host: Host IP address + :param host: Target server's name, also matched for certificate :param port: Port :param sslctx: The SSLContext to use for TLS (default None and auto create) - :param server_hostname: Target server's name matched for certificate + :param certfile: The optional client's cert file path for TLS server request + :param keyfile: The optional client's key file path for TLS server request + :param password: The password for for decrypting client's private key file :param framer: Modbus Framer :param source_address: Bind address :param timeout: Timeout in seconds @@ -33,12 +36,12 @@ def async_io_factory(host="127.0.0.1", port=Defaults.TLSPort, sslctx=None, proto_cls = kwargs.get("proto_cls", None) if not loop.is_running(): asyncio.set_event_loop(loop) - cor = init_tls_client(proto_cls, loop, host, port, sslctx, server_hostname, - framer) + cor = init_tls_client(proto_cls, loop, host, port, + sslctx, certfile, keyfile, password, framer) client = loop.run_until_complete(asyncio.gather(cor))[0] else: - cor = init_tls_client(proto_cls, loop, host, port, sslctx, server_hostname, - framer) + cor = init_tls_client(proto_cls, loop, host, port, + sslctx, certfile, keyfile, password, framer) future = asyncio.run_coroutine_threadsafe(cor, loop=loop) client = future.result() diff --git a/pymodbus/client/asynchronous/tls.py b/pymodbus/client/asynchronous/tls.py index d2412ff546..83eab9f79c 100644 --- a/pymodbus/client/asynchronous/tls.py +++ b/pymodbus/client/asynchronous/tls.py @@ -21,17 +21,19 @@ class AsyncModbusTLSClient(object): from pymodbus.client.asynchronous.tls import AsyncModbusTLSClient """ def __new__(cls, scheduler, host="127.0.0.1", port=Defaults.TLSPort, - framer=None, sslctx=None, server_hostname=None, - source_address=None, timeout=None, **kwargs): + framer=None, sslctx=None, certfile=None, keyfile=None, + password=None, source_address=None, timeout=None, **kwargs): """ Scheduler to use: - async_io (asyncio) :param scheduler: Backend to use - :param host: Host IP address + :param host: Target server's name, also matched for certificate :param port: Port :param framer: Modbus Framer to use :param sslctx: The SSLContext to use for TLS (default None and auto create) - :param server_hostname: Target server's name matched for certificate + :param certfile: The optional client's cert file path for TLS server request + :param keyfile: The optional client's key file path for TLS server request + :param password: The password for for decrypting client's private key file :param source_address: source address specific to underlying backend :param timeout: Time out in seconds :param kwargs: Other extra args specific to Backend being used @@ -45,8 +47,9 @@ def __new__(cls, scheduler, host="127.0.0.1", port=Defaults.TLSPort, framer = framer or ModbusTlsFramer(ClientDecoder()) factory_class = get_factory(scheduler) yieldable = factory_class(host=host, port=port, sslctx=sslctx, - server_hostname=server_hostname, - framer=framer, source_address=source_address, + certfile=certfile, keyfile=keyfile, + password=password, framer=framer, + source_address=source_address, timeout=timeout, **kwargs) return yieldable diff --git a/pymodbus/client/sync.py b/pymodbus/client/sync.py index 3d1fb6d6f3..c6a1491b9b 100644 --- a/pymodbus/client/sync.py +++ b/pymodbus/client/sync.py @@ -16,6 +16,7 @@ from pymodbus.transaction import ModbusAsciiFramer, ModbusRtuFramer from pymodbus.transaction import ModbusTlsFramer from pymodbus.client.common import ModbusClientMixin +from pymodbus.client.tls_helper import sslctx_provider # --------------------------------------------------------------------------- # # Logging @@ -368,27 +369,23 @@ class ModbusTlsClient(ModbusTcpClient): """ def __init__(self, host='localhost', port=Defaults.TLSPort, sslctx=None, - framer=ModbusTlsFramer, **kwargs): + certfile=None, keyfile=None, password=None, framer=ModbusTlsFramer, + **kwargs): """ Initialize a client instance :param host: The host to connect to (default localhost) :param port: The modbus port to connect to (default 802) :param sslctx: The SSLContext to use for TLS (default None and auto create) + :param certfile: The optional client's cert file path for TLS server request + :param keyfile: The optional client's key file path for TLS server request + :param password: The password for for decrypting client's private key file :param source_address: The source address tuple to bind to (default ('', 0)) :param timeout: The timeout to use for this socket (default Defaults.Timeout) :param framer: The modbus framer to use (default ModbusSocketFramer) .. note:: The host argument will accept ipv4 and ipv6 hosts """ - self.sslctx = sslctx - if self.sslctx is None: - self.sslctx = ssl.create_default_context() - # According to MODBUS/TCP Security Protocol Specification, it is - # TLSv2 at least - self.sslctx.options |= ssl.OP_NO_TLSv1_1 - self.sslctx.options |= ssl.OP_NO_TLSv1 - self.sslctx.options |= ssl.OP_NO_SSLv3 - self.sslctx.options |= ssl.OP_NO_SSLv2 + self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password) ModbusTcpClient.__init__(self, host, port, framer, **kwargs) def connect(self): diff --git a/pymodbus/client/tls_helper.py b/pymodbus/client/tls_helper.py new file mode 100644 index 0000000000..ce64613509 --- /dev/null +++ b/pymodbus/client/tls_helper.py @@ -0,0 +1,35 @@ +""" +TLS helper for Modbus TLS Client +------------------------------------------ + +""" +import ssl + +def sslctx_provider(sslctx=None, certfile=None, keyfile=None, password=None): + """ Provide the SSLContext for ModbusTlsClient + + If the user defined SSLContext is not passed in, sslctx_provider will + produce a default one. + + :param sslctx: The user defined SSLContext to use for TLS (default None and + auto create) + :param certfile: The optional client's cert file path for TLS server request + :param keyfile: The optional client's key file path for TLS server request + :param password: The password for for decrypting client's private key file + """ + if sslctx is None: + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS) + # According to MODBUS/TCP Security Protocol Specification, it is + # TLSv2 at least + sslctx.options |= ssl.OP_NO_TLSv1_1 + sslctx.options |= ssl.OP_NO_TLSv1 + sslctx.options |= ssl.OP_NO_SSLv3 + sslctx.options |= ssl.OP_NO_SSLv2 + sslctx.verify_mode = ssl.CERT_REQUIRED + sslctx.check_hostname = True + + if certfile and keyfile: + sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile, + password=password) + + return sslctx diff --git a/pymodbus/server/async_io.py b/pymodbus/server/async_io.py index a85b37a785..68b9dca158 100755 --- a/pymodbus/server/async_io.py +++ b/pymodbus/server/async_io.py @@ -21,6 +21,7 @@ from pymodbus.exceptions import NotImplementedException, NoSuchSlaveException from pymodbus.pdu import ModbusExceptions as merror from pymodbus.compat import socketserver, byte2int +from pymodbus.server.tls_helper import sslctx_provider # --------------------------------------------------------------------------- # # Logging @@ -524,6 +525,8 @@ def __init__(self, sslctx=None, certfile=None, keyfile=None, + password=None, + reqclicert=False, handler=None, allow_reuse_address=False, allow_reuse_port=False, @@ -544,6 +547,8 @@ def __init__(self, create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param password: The password for for decrypting the private key file + :param reqclicert: Force the sever request client's certificate :param handler: A handler for each client session; default is ModbusConnectedRequestHandler. The handler class receives connection create/teardown events @@ -582,18 +587,9 @@ def __init__(self, if isinstance(identity, ModbusDeviceIdentification): self.control.Identity.update(identity) - self.sslctx = sslctx - if self.sslctx is None: - self.sslctx = ssl.create_default_context() - self.sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile) - # According to MODBUS/TCP Security Protocol Specification, it is - # TLSv2 at least - self.sslctx.options |= ssl.OP_NO_TLSv1_1 - self.sslctx.options |= ssl.OP_NO_TLSv1 - self.sslctx.options |= ssl.OP_NO_SSLv3 - self.sslctx.options |= ssl.OP_NO_SSLv2 - self.sslctx.verify_mode = ssl.CERT_OPTIONAL - self.sslctx.check_hostname = False + self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password, + reqclicert) + # asyncio future that will be done once server has started self.serving = self.loop.create_future() # constructors cannot be declared async, so we have to @@ -845,7 +841,8 @@ async def StartTcpServer(context=None, identity=None, address=None, async def StartTlsServer(context=None, identity=None, address=None, sslctx=None, - certfile=None, keyfile=None, + certfile=None, keyfile=None, password=None, + reqclicert=False, allow_reuse_address=False, allow_reuse_port=False, custom_functions=[], @@ -858,6 +855,8 @@ async def StartTlsServer(context=None, identity=None, address=None, :param sslctx: The SSLContext to use for TLS (default None and auto create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param password: The password for for decrypting the private key file + :param reqclicert: Force the sever request client's certificate :param allow_reuse_address: Whether the server will allow the reuse of an address. :param allow_reuse_port: Whether the server will allow the reuse of a port. @@ -872,7 +871,7 @@ async def StartTlsServer(context=None, identity=None, address=None, """ framer = kwargs.pop("framer", ModbusTlsFramer) server = ModbusTlsServer(context, framer, identity, address, sslctx, - certfile, keyfile, + certfile, keyfile, password, reqclicert, allow_reuse_address=allow_reuse_address, allow_reuse_port=allow_reuse_port, **kwargs) diff --git a/pymodbus/server/sync.py b/pymodbus/server/sync.py index 041069aabd..5945a51852 100644 --- a/pymodbus/server/sync.py +++ b/pymodbus/server/sync.py @@ -19,6 +19,7 @@ from pymodbus.exceptions import NotImplementedException, NoSuchSlaveException from pymodbus.pdu import ModbusExceptions as merror from pymodbus.compat import socketserver, byte2int +from pymodbus.server.tls_helper import sslctx_provider # --------------------------------------------------------------------------- # # Logging @@ -375,9 +376,10 @@ class ModbusTlsServer(ModbusTcpServer): server context instance. """ - def __init__(self, context, framer=None, identity=None, - address=None, handler=None, allow_reuse_address=False, - sslctx=None, certfile=None, keyfile=None, **kwargs): + def __init__(self, context, framer=None, identity=None, address=None, + sslctx=None, certfile=None, keyfile=None, password=None, + reqclicert=False, handler=None, allow_reuse_address=False, + **kwargs): """ Overloaded initializer for the ModbusTcpServer If the identify structure is not passed in, the ModbusControlBlock @@ -387,31 +389,23 @@ def __init__(self, context, framer=None, identity=None, :param framer: The framer strategy to use :param identity: An optional identify structure :param address: An optional (interface, port) to bind to. - :param handler: A handler for each client session; default is - ModbusConnectedRequestHandler - :param allow_reuse_address: Whether the server will allow the - reuse of an address. :param sslctx: The SSLContext to use for TLS (default None and auto create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param password: The password for for decrypting the private key file + :param reqclicert: Force the sever request client's certificate + :param handler: A handler for each client session; default is + ModbusConnectedRequestHandler + :param allow_reuse_address: Whether the server will allow the + reuse of an address. :param ignore_missing_slaves: True to not send errors on a request to a missing slave :param broadcast_enable: True to treat unit_id 0 as broadcast address, False to treat 0 as any other unit_id """ - self.sslctx = sslctx - if self.sslctx is None: - self.sslctx = ssl.create_default_context() - self.sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile) - # According to MODBUS/TCP Security Protocol Specification, it is - # TLSv2 at least - self.sslctx.options |= ssl.OP_NO_TLSv1_1 - self.sslctx.options |= ssl.OP_NO_TLSv1 - self.sslctx.options |= ssl.OP_NO_SSLv3 - self.sslctx.options |= ssl.OP_NO_SSLv2 - self.sslctx.verify_mode = ssl.CERT_OPTIONAL - self.sslctx.check_hostname = False + self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password, + reqclicert) ModbusTcpServer.__init__(self, context, framer, identity, address, handler, allow_reuse_address, **kwargs) @@ -625,7 +619,8 @@ def StartTcpServer(context=None, identity=None, address=None, def StartTlsServer(context=None, identity=None, address=None, sslctx=None, - certfile=None, keyfile=None, custom_functions=[], **kwargs): + certfile=None, keyfile=None, password=None, reqclicert=False, + custom_functions=[], **kwargs): """ A factory to start and run a tls modbus server :param context: The ModbusServerContext datastore @@ -634,14 +629,16 @@ def StartTlsServer(context=None, identity=None, address=None, sslctx=None, :param sslctx: The SSLContext to use for TLS (default None and auto create) :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) + :param password: The password for for decrypting the private key file + :param reqclicert: Force the sever request client's certificate :param custom_functions: An optional list of custom function classes supported by server instance. :param ignore_missing_slaves: True to not send errors on a request to a missing slave """ framer = kwargs.pop("framer", ModbusTlsFramer) - server = ModbusTlsServer(context, framer, identity, address, sslctx=sslctx, - certfile=certfile, keyfile=keyfile, **kwargs) + server = ModbusTlsServer(context, framer, identity, address, sslctx, + certfile, keyfile, password, reqclicert, **kwargs) for f in custom_functions: server.decoder.register(f) diff --git a/pymodbus/server/tls_helper.py b/pymodbus/server/tls_helper.py new file mode 100644 index 0000000000..2d7da71149 --- /dev/null +++ b/pymodbus/server/tls_helper.py @@ -0,0 +1,35 @@ +""" +TLS helper for Modbus TLS Server +------------------------------------------ + +""" +import ssl + +def sslctx_provider(sslctx=None, certfile=None, keyfile=None, password=None, + reqclicert=False): + """ Provide the SSLContext for ModbusTlsServer + + If the user defined SSLContext is not passed in, sslctx_provider will + produce a default one. + + :param sslctx: The user defined SSLContext to use for TLS (default None and + auto create) + :param certfile: The cert file path for TLS (used if sslctx is None) + :param keyfile: The key file path for TLS (used if sslctx is None) + :param password: The password for for decrypting the private key file + :param reqclicert: Force the sever request client's certificate + """ + if sslctx is None: + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS) + sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile, + password=password) + # According to MODBUS/TCP Security Protocol Specification, it is + # TLSv2 at least + sslctx.options |= ssl.OP_NO_TLSv1_1 + sslctx.options |= ssl.OP_NO_TLSv1 + sslctx.options |= ssl.OP_NO_SSLv3 + sslctx.options |= ssl.OP_NO_SSLv2 + if reqclicert: + sslctx.verify_mode = ssl.CERT_REQUIRED + + return sslctx diff --git a/test/test_client_sync.py b/test/test_client_sync.py index 30456a58de..dcb344ff36 100755 --- a/test/test_client_sync.py +++ b/test/test_client_sync.py @@ -18,6 +18,7 @@ from pymodbus.client.sync import ModbusTcpClient, ModbusUdpClient from pymodbus.client.sync import ModbusSerialClient, BaseModbusClient from pymodbus.client.sync import ModbusTlsClient +from pymodbus.client.tls_helper import sslctx_provider from pymodbus.exceptions import ConnectionException, NotImplementedException from pymodbus.exceptions import ParameterException from pymodbus.transaction import ModbusAsciiFramer, ModbusRtuFramer @@ -311,18 +312,34 @@ class CustomeRequest: # Test TLS Client # -----------------------------------------------------------------------# + def testTlsSSLCTX_Provider(self): + ''' test that sslctx_provider() produce SSLContext correctly ''' + with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: + sslctx1 = sslctx_provider(certfile="cert.pem") + self.assertIsNotNone(sslctx1) + self.assertEqual(type(sslctx1), ssl.SSLContext) + assert not mock_method.called + + sslctx2 = sslctx_provider(keyfile="key.pem") + self.assertIsNotNone(sslctx2) + self.assertEqual(type(sslctx2), ssl.SSLContext) + assert not mock_method.called + + sslctx3 = sslctx_provider(certfile="cert.pem", keyfile="key.pem") + self.assertIsNotNone(sslctx3) + self.assertEqual(type(sslctx3), ssl.SSLContext) + assert mock_method.called + + sslctx_old = ssl.SSLContext(ssl.PROTOCOL_TLS) + sslctx_new = sslctx_provider(sslctx=sslctx_old) + self.assertEqual(sslctx_new, sslctx_old) + def testSyncTlsClientInstantiation(self): # default SSLContext client = ModbusTlsClient() self.assertNotEqual(client, None) self.assertTrue(client.sslctx) - # user defined SSLContext - context = ssl.create_default_context() - client = ModbusTlsClient(sslctx=context) - self.assertNotEqual(client, None) - self.assertEqual(client.sslctx, context) - def testBasicSyncTlsClient(self): ''' Test the basic methods for the tls sync client''' diff --git a/test/test_server_sync.py b/test/test_server_sync.py index 8b37b9ac22..34a104cf3d 100644 --- a/test/test_server_sync.py +++ b/test/test_server_sync.py @@ -16,6 +16,7 @@ from pymodbus.server.sync import ModbusDisconnectedRequestHandler from pymodbus.server.sync import ModbusTcpServer, ModbusTlsServer, ModbusUdpServer, ModbusSerialServer from pymodbus.server.sync import StartTcpServer, StartTlsServer, StartUdpServer, StartSerialServer +from pymodbus.server.tls_helper import sslctx_provider from pymodbus.exceptions import NotImplementedException from pymodbus.bit_read_message import ReadCoilsRequest, ReadCoilsResponse from pymodbus.datastore import ModbusServerContext @@ -277,22 +278,29 @@ def testTcpServerProcess(self): #-----------------------------------------------------------------------# # Test TLS Server #-----------------------------------------------------------------------# + def testTlsSSLCTX_Provider(self): + ''' test that sslctx_provider() produce SSLContext correctly ''' + with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: + sslctx = sslctx_provider(reqclicert=True) + self.assertIsNotNone(sslctx) + self.assertEqual(type(sslctx), ssl.SSLContext) + self.assertEqual(sslctx.verify_mode, ssl.CERT_REQUIRED) + + sslctx_old = ssl.SSLContext(ssl.PROTOCOL_TLS) + sslctx_new = sslctx_provider(sslctx=sslctx_old) + self.assertEqual(sslctx_new, sslctx_old) + def testTlsServerInit(self): ''' test that the synchronous TLS server intial correctly ''' with patch.object(socketserver.TCPServer, 'server_activate'): with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: identity = ModbusDeviceIdentification(info={0x00: 'VendorName'}) server = ModbusTlsServer(context=None, identity=identity, + reqclicert=True, bind_and_activate=False) server.server_activate() self.assertIsNotNone(server.sslctx) - self.assertEqual(type(server.socket), ssl.SSLSocket) - server.server_close() - sslctx = ssl.create_default_context() - server = ModbusTlsServer(context=None, identity=identity, - sslctx=sslctx, bind_and_activate=False) - server.server_activate() - self.assertEqual(server.sslctx, sslctx) + self.assertEqual(server.sslctx.verify_mode, ssl.CERT_REQUIRED) self.assertEqual(type(server.socket), ssl.SSLSocket) server.server_close()