From ce197055061998b77ea75099e77a17fd5b32b259 Mon Sep 17 00:00:00 2001 From: Jian-Hong Pan Date: Wed, 29 Jan 2020 11:21:24 +0800 Subject: [PATCH] Add TLS feature for Modbus asynchronous (#470) * Add TLS feature for Modbus asynchronous client Since we have Modbus TLS client in synchronous mode, we can also implement Modbus TLS client in asynchronous mode with ASYNC_IO. * Add TLS feature for Modbus asynchronous server Since we have Modbus TLS server in synchronous mode, we can also implement Modbus TLS server in asynchronous mode with ASYNC_IO. --- examples/common/asyncio_server.py | 7 + .../asynchronous_asyncio_modbus_tls_client.py | 40 +++++ .../client/asynchronous/asyncio/__init__.py | 84 +++++++++- pymodbus/client/asynchronous/factory/tls.py | 60 +++++++ pymodbus/client/asynchronous/tls.py | 52 +++++++ pymodbus/server/asyncio.py | 147 +++++++++++++++++- test/test_client_async.py | 27 +++- test/test_server_asyncio.py | 51 +++++- 8 files changed, 464 insertions(+), 4 deletions(-) create mode 100755 examples/contrib/asynchronous_asyncio_modbus_tls_client.py create mode 100644 pymodbus/client/asynchronous/factory/tls.py create mode 100644 pymodbus/client/asynchronous/tls.py diff --git a/examples/common/asyncio_server.py b/examples/common/asyncio_server.py index be0189b8a..153b91c19 100755 --- a/examples/common/asyncio_server.py +++ b/examples/common/asyncio_server.py @@ -13,6 +13,7 @@ # --------------------------------------------------------------------------- # import asyncio from pymodbus.server.asyncio import StartTcpServer +from pymodbus.server.asyncio import StartTlsServer from pymodbus.server.asyncio import StartUdpServer from pymodbus.server.asyncio import StartSerialServer @@ -127,6 +128,12 @@ async def run_server(): # StartTcpServer(context, identity=identity, # framer=ModbusRtuFramer, address=("0.0.0.0", 5020)) + # Tls: + # await StartTlsServer(context, identity=identity, address=("localhost", 8020), + # certfile="server.crt", keyfile="server.key", + # allow_reuse_address=True, allow_reuse_port=True, + # defer_start=False) + # Udp: # server = await StartUdpServer(context, identity=identity, address=("0.0.0.0", 5020), # allow_reuse_address=True, defer_start=True) diff --git a/examples/contrib/asynchronous_asyncio_modbus_tls_client.py b/examples/contrib/asynchronous_asyncio_modbus_tls_client.py new file mode 100755 index 000000000..d5a973d84 --- /dev/null +++ b/examples/contrib/asynchronous_asyncio_modbus_tls_client.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +""" +Simple Asynchronous Modbus TCP over TLS client +--------------------------------------------------------------------------- + +This is a simple example of writing a asynchronous modbus TCP over TLS client +that uses Python builtin module ssl - TLS/SSL wrapper for socket objects for +the TLS feature and asyncio. +""" +# -------------------------------------------------------------------------- # +# import neccessary libraries +# -------------------------------------------------------------------------- # +import ssl +from pymodbus.client.asynchronous.tls import AsyncModbusTLSClient +from pymodbus.client.asynchronous.schedulers import ASYNC_IO + +# -------------------------------------------------------------------------- # +# the TLS detail security can be set in SSLContext which is the context here +# -------------------------------------------------------------------------- # +context = ssl.create_default_context() +context.options |= ssl.OP_NO_SSLv2 +context.options |= ssl.OP_NO_SSLv3 +context.options |= ssl.OP_NO_TLSv1 +context.options |= ssl.OP_NO_TLSv1_1 + +async def start_async_test(client): + result = await client.read_coils(1, 8) + print(result.bits) + await client.write_coils(1, [False]*3) + result = await client.read_coils(1, 8) + print(result.bits) + +if __name__ == '__main__': +# -------------------------------------------------------------------------- # +# pass SSLContext which is the context here to ModbusTcpClient() +# -------------------------------------------------------------------------- # + loop, client = AsyncModbusTLSClient(ASYNC_IO, 'test.host.com', 8020, + sslctx=context) + loop.run_until_complete(start_async_test(client.protocol)) + loop.close() diff --git a/pymodbus/client/asynchronous/asyncio/__init__.py b/pymodbus/client/asynchronous/asyncio/__init__.py index d83f6eeee..836ef9671 100644 --- a/pymodbus/client/asynchronous/asyncio/__init__.py +++ b/pymodbus/client/asynchronous/asyncio/__init__.py @@ -4,16 +4,17 @@ import socket import asyncio import functools +import ssl from pymodbus.exceptions import ConnectionException from pymodbus.client.asynchronous.mixins import AsyncModbusClientMixin from pymodbus.compat import byte2int +from pymodbus.transaction import FifoTransactionManager import logging _logger = logging.getLogger(__name__) DGRAM_TYPE = socket.SocketKind.SOCK_DGRAM - class BaseModbusAsyncClientProtocol(AsyncModbusClientMixin): """ Asyncio specific implementation of asynchronous modbus client protocol. @@ -423,6 +424,66 @@ def protocol_lost_connection(self, protocol): ' callback called while not connected.') +class ReconnectingAsyncioModbusTlsClient(ReconnectingAsyncioModbusTcpClient): + """ + Client to connect to modbus device repeatedly over TLS." + """ + def __init__(self, protocol_class=None, loop=None, framer=None): + """ + Initialize ReconnectingAsyncioModbusTcpClient + :param protocol_class: Protocol used to talk to modbus device. + :param loop: Event loop to use + """ + self.framer = framer + ReconnectingAsyncioModbusTcpClient.__init__(self, protocol_class, loop) + + @asyncio.coroutine + def start(self, host, port=802, sslctx=None, server_hostname=None): + """ + Initiates connection to start client + :param host: + :param port: + :param sslctx: + :param server_hostname: + :return: + """ + self.sslctx = sslctx + if self.sslctx is None: + self.sslctx = ssl.create_default_context() + # According to MODBUS/TCP Security Protocol Specification, it is + # TLSv2 at least + self.sslctx.options |= ssl.OP_NO_TLSv1_1 + self.sslctx.options |= ssl.OP_NO_TLSv1 + self.sslctx.options |= ssl.OP_NO_SSLv3 + self.sslctx.options |= ssl.OP_NO_SSLv2 + self.server_hostname = server_hostname + yield from ReconnectingAsyncioModbusTcpClient.start(self, host, port) + + @asyncio.coroutine + def _connect(self): + _logger.debug('Connecting.') + try: + yield from self.loop.create_connection(self._create_protocol, + self.host, + self.port, + ssl=self.sslctx, + server_hostname=self.server_hostname) + except Exception as ex: + _logger.warning('Failed to connect: %s' % ex) + asyncio.ensure_future(self._reconnect(), loop=self.loop) + else: + _logger.info('Connected to %s:%s.' % (self.host, self.port)) + self.reset_delay() + + def _create_protocol(self): + """ + Factory function to create initialized protocol instance. + """ + protocol = self.protocol_class(framer=self.framer) + protocol.transaction = FifoTransactionManager(self) + protocol.factory = self + return protocol + class ReconnectingAsyncioModbusUdpClient(object): """ Client to connect to modbus device repeatedly over UDP. @@ -774,6 +835,27 @@ def init_tcp_client(proto_cls, loop, host, port, **kwargs): return client +@asyncio.coroutine +def init_tls_client(proto_cls, loop, host, port, sslctx=None, + server_hostname=None, framer=None, **kwargs): + """ + Helper function to initialize tcp client + :param proto_cls: + :param loop: + :param host: + :param port: + :param sslctx: + :param server_hostname: + :param framer: + :param kwargs: + :return: + """ + client = ReconnectingAsyncioModbusTlsClient(protocol_class=proto_cls, + loop=loop, framer=framer) + yield from client.start(host, port, sslctx, server_hostname) + return client + + @asyncio.coroutine def init_udp_client(proto_cls, loop, host, port, **kwargs): """ diff --git a/pymodbus/client/asynchronous/factory/tls.py b/pymodbus/client/asynchronous/factory/tls.py new file mode 100644 index 000000000..0dfa81c07 --- /dev/null +++ b/pymodbus/client/asynchronous/factory/tls.py @@ -0,0 +1,60 @@ +""" +Factory to create asynchronous tls clients based on asyncio +""" +from __future__ import unicode_literals +from __future__ import absolute_import + +import logging + +from pymodbus.client.asynchronous import schedulers +from pymodbus.client.asynchronous.thread import EventLoopThread +from pymodbus.constants import Defaults + +LOGGER = logging.getLogger(__name__) + +def async_io_factory(host="127.0.0.1", port=Defaults.TLSPort, sslctx=None, + server_hostname=None, framer=None, source_address=None, + timeout=None, **kwargs): + """ + Factory to create asyncio based asynchronous tls clients + :param host: Host IP address + :param port: Port + :param sslctx: The SSLContext to use for TLS (default None and auto create) + :param server_hostname: Target server's name matched for certificate + :param framer: Modbus Framer + :param source_address: Bind address + :param timeout: Timeout in seconds + :param kwargs: + :return: asyncio event loop and tcp client + """ + import asyncio + from pymodbus.client.asynchronous.asyncio import init_tls_client + loop = kwargs.get("loop") or asyncio.new_event_loop() + proto_cls = kwargs.get("proto_cls", None) + if not loop.is_running(): + asyncio.set_event_loop(loop) + cor = init_tls_client(proto_cls, loop, host, port, sslctx, server_hostname, + framer) + client = loop.run_until_complete(asyncio.gather(cor))[0] + else: + cor = init_tls_client(proto_cls, loop, host, port, sslctx, server_hostname, + framer) + future = asyncio.run_coroutine_threadsafe(cor, loop=loop) + client = future.result() + + return loop, client + + +def get_factory(scheduler): + """ + Gets protocol factory based on the backend scheduler being used + :param scheduler: ASYNC_IO + :return + """ + if scheduler == schedulers.ASYNC_IO: + return async_io_factory + else: + LOGGER.warning("Allowed Schedulers: {}".format( + schedulers.ASYNC_IO + )) + raise Exception("Invalid Scheduler '{}'".format(scheduler)) diff --git a/pymodbus/client/asynchronous/tls.py b/pymodbus/client/asynchronous/tls.py new file mode 100644 index 000000000..d2412ff54 --- /dev/null +++ b/pymodbus/client/asynchronous/tls.py @@ -0,0 +1,52 @@ +from __future__ import unicode_literals +from __future__ import absolute_import + +import logging +from pymodbus.client.asynchronous.factory.tls import get_factory +from pymodbus.constants import Defaults +from pymodbus.compat import IS_PYTHON3, PYTHON_VERSION +from pymodbus.client.asynchronous.schedulers import ASYNC_IO +from pymodbus.factory import ClientDecoder +from pymodbus.transaction import ModbusTlsFramer + +logger = logging.getLogger(__name__) + + +class AsyncModbusTLSClient(object): + """ + Actual Async TLS Client to be used. + + To use do:: + + from pymodbus.client.asynchronous.tls import AsyncModbusTLSClient + """ + def __new__(cls, scheduler, host="127.0.0.1", port=Defaults.TLSPort, + framer=None, sslctx=None, server_hostname=None, + source_address=None, timeout=None, **kwargs): + """ + Scheduler to use: + - async_io (asyncio) + :param scheduler: Backend to use + :param host: Host IP address + :param port: Port + :param framer: Modbus Framer to use + :param sslctx: The SSLContext to use for TLS (default None and auto create) + :param server_hostname: Target server's name matched for certificate + :param source_address: source address specific to underlying backend + :param timeout: Time out in seconds + :param kwargs: Other extra args specific to Backend being used + :return: + """ + if (not (IS_PYTHON3 and PYTHON_VERSION >= (3, 4)) + and scheduler == ASYNC_IO): + logger.critical("ASYNCIO is supported only on python3") + import sys + sys.exit(1) + framer = framer or ModbusTlsFramer(ClientDecoder()) + factory_class = get_factory(scheduler) + yieldable = factory_class(host=host, port=port, sslctx=sslctx, + server_hostname=server_hostname, + framer=framer, source_address=source_address, + timeout=timeout, **kwargs) + return yieldable + diff --git a/pymodbus/server/asyncio.py b/pymodbus/server/asyncio.py index 50ccf97d1..c8bc8d01f 100755 --- a/pymodbus/server/asyncio.py +++ b/pymodbus/server/asyncio.py @@ -5,6 +5,7 @@ """ from binascii import b2a_hex import socket +import ssl import traceback import asyncio @@ -427,6 +428,111 @@ def server_close(self): self.server.close() +class ModbusTlsServer(ModbusTcpServer): + """ + A modbus threaded tls socket server + + We inherit and overload the socket server so that we + can control the client threads as well as have a single + server context instance. + """ + + def __init__(self, + context, + framer=None, + identity=None, + address=None, + sslctx=None, + certfile=None, + keyfile=None, + handler=None, + allow_reuse_address=False, + allow_reuse_port=False, + defer_start=False, + backlog=20, + loop=None, + **kwargs): + """ Overloaded initializer for the socket server + + If the identify structure is not passed in, the ModbusControlBlock + uses its own empty structure. + + :param context: The ModbusServerContext datastore + :param framer: The framer strategy to use + :param identity: An optional identify structure + :param address: An optional (interface, port) to bind to. + :param sslctx: The SSLContext to use for TLS (default None and auto + create) + :param certfile: The cert file path for TLS (used if sslctx is None) + :param keyfile: The key file path for TLS (used if sslctx is None) + :param handler: A handler for each client session; default is + ModbusConnectedRequestHandler. The handler class + receives connection create/teardown events + :param allow_reuse_address: Whether the server will allow the + reuse of an address. + :param allow_reuse_port: Whether the server will allow the + reuse of a port. + :param backlog: is the maximum number of queued connections + passed to listen(). Defaults to 20, increase if many + connections are being made and broken to your Modbus slave + :param loop: optional asyncio event loop to run in. Will default to + asyncio.get_event_loop() supplied value if None. + :param ignore_missing_slaves: True to not send errors on a request + to a missing slave + :param broadcast_enable: True to treat unit_id 0 as broadcast address, + False to treat 0 as any other unit_id + """ + self.active_connections = {} + self.loop = loop or asyncio.get_event_loop() + self.allow_reuse_address = allow_reuse_address + self.decoder = ServerDecoder() + self.framer = framer or ModbusTlsFramer + self.context = context or ModbusServerContext() + self.control = ModbusControlBlock() + self.address = address or ("", Defaults.Port) + self.handler = handler or ModbusConnectedRequestHandler + self.handler.server = self + self.ignore_missing_slaves = kwargs.get('ignore_missing_slaves', + Defaults.IgnoreMissingSlaves) + self.broadcast_enable = kwargs.get('broadcast_enable', + Defaults.broadcast_enable) + + if isinstance(identity, ModbusDeviceIdentification): + self.control.Identity.update(identity) + + self.sslctx = sslctx + if self.sslctx is None: + self.sslctx = ssl.create_default_context() + self.sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile) + # According to MODBUS/TCP Security Protocol Specification, it is + # TLSv2 at least + self.sslctx.options |= ssl.OP_NO_TLSv1_1 + self.sslctx.options |= ssl.OP_NO_TLSv1 + self.sslctx.options |= ssl.OP_NO_SSLv3 + self.sslctx.options |= ssl.OP_NO_SSLv2 + self.sslctx.verify_mode = ssl.CERT_OPTIONAL + self.sslctx.check_hostname = False + + self.serving = self.loop.create_future() # asyncio future that will be done once server has started + self.server = None # constructors cannot be declared async, so we have to defer the initialization of the server + if PYTHON_VERSION >= (3, 7): + # start_serving is new in version 3.7 + self.server_factory = self.loop.create_server(lambda : self.handler(self), + *self.address, + ssl=self.sslctx, + reuse_address=allow_reuse_address, + reuse_port=allow_reuse_port, + backlog=backlog, + start_serving=not defer_start) + else: + self.server_factory = self.loop.create_server(lambda : self.handler(self), + *self.address, + ssl=self.sslctx, + reuse_address=allow_reuse_address, + reuse_port=allow_reuse_port, + backlog=backlog) + + class ModbusUdpServer: """ A modbus threaded udp socket server @@ -568,6 +674,45 @@ async def StartTcpServer(context=None, identity=None, address=None, +async def StartTlsServer(context=None, identity=None, address=None, sslctx=None, + certfile=None, keyfile=None, allow_reuse_address=False, + allow_reuse_port=False, custom_functions=[], + defer_start=True, **kwargs): + """ A factory to start and run a tls modbus server + + :param context: The ModbusServerContext datastore + :param identity: An optional identify structure + :param address: An optional (interface, port) to bind to. + :param sslctx: The SSLContext to use for TLS (default None and auto create) + :param certfile: The cert file path for TLS (used if sslctx is None) + :param keyfile: The key file path for TLS (used if sslctx is None) + :param allow_reuse_address: Whether the server will allow the reuse of an + address. + :param allow_reuse_port: Whether the server will allow the reuse of a port. + :param custom_functions: An optional list of custom function classes + supported by server instance. + :param defer_start: if set, a coroutine which can be started and stopped + will be returned. Otherwise, the server will be immediately spun + up without the ability to shut it off from within the asyncio loop + :param ignore_missing_slaves: True to not send errors on a request to a + missing slave + :return: an initialized but inactive server object coroutine + """ + framer = kwargs.pop("framer", ModbusTlsFramer) + server = ModbusTlsServer(context, framer, identity, address, sslctx, + certfile, keyfile, + allow_reuse_address=allow_reuse_address, + allow_reuse_port=allow_reuse_port, **kwargs) + + for f in custom_functions: + server.decoder.register(f) # pragma: no cover + + if not defer_start: + await server.serve_forever() + + return server + + async def StartUdpServer(context=None, identity=None, address=None, custom_functions=[], defer_start=True, **kwargs): @@ -637,6 +782,6 @@ def StopServer(): __all__ = [ - "StartTcpServer", "StartUdpServer", "StartSerialServer" + "StartTcpServer", "StartTlsServer", "StartUdpServer", "StartSerialServer" ] diff --git a/test/test_client_async.py b/test/test_client_async.py index 337b42593..07e32ca58 100644 --- a/test/test_client_async.py +++ b/test/test_client_async.py @@ -5,6 +5,7 @@ if IS_PYTHON3 and PYTHON_VERSION >= (3, 4): from unittest.mock import patch, Mock, MagicMock import asyncio + from pymodbus.client.asynchronous.asyncio import ReconnectingAsyncioModbusTlsClient from pymodbus.client.asynchronous.asyncio import AsyncioModbusSerialClient from serial_asyncio import SerialTransport else: @@ -14,6 +15,7 @@ from pymodbus.client.asynchronous.serial import AsyncModbusSerialClient from pymodbus.client.asynchronous.tcp import AsyncModbusTCPClient +from pymodbus.client.asynchronous.tls import AsyncModbusTLSClient from pymodbus.client.asynchronous.udp import AsyncModbusUDPClient from pymodbus.client.asynchronous.tornado import AsyncModbusSerialClient as AsyncTornadoModbusSerialClient @@ -22,9 +24,11 @@ from pymodbus.client.asynchronous import schedulers from pymodbus.factory import ClientDecoder from pymodbus.exceptions import ConnectionException -from pymodbus.transaction import ModbusSocketFramer, ModbusRtuFramer, ModbusAsciiFramer, ModbusBinaryFramer +from pymodbus.transaction import ModbusSocketFramer, ModbusTlsFramer, ModbusRtuFramer, ModbusAsciiFramer, ModbusBinaryFramer from pymodbus.client.asynchronous.twisted import ModbusSerClientProtocol +import ssl + IS_DARWIN = platform.system().lower() == "darwin" OSX_SIERRA = LooseVersion("10.12") if IS_DARWIN: @@ -104,6 +108,27 @@ def testTcpAsyncioClient(self, mock_gather, mock_loop): """ pytest.skip("TBD") + # -----------------------------------------------------------------------# + # Test TLS Client client + # -----------------------------------------------------------------------# + @pytest.mark.skipif(not IS_PYTHON3 or PYTHON_VERSION < (3, 4), + reason="requires python3.4 or above") + def testTlsAsyncioClient(self): + """ + Test the TLS AsyncIO client + """ + loop, client = AsyncModbusTLSClient(schedulers.ASYNC_IO) + assert(isinstance(client, ReconnectingAsyncioModbusTlsClient)) + assert(isinstance(client.framer, ModbusTlsFramer)) + assert(isinstance(client.sslctx, ssl.SSLContext)) + assert(client.port == 802) + + def handle_failure(failure): + assert(isinstance(failure.exception(), ConnectionException)) + + client.stop() + assert(client.host is None) + # -----------------------------------------------------------------------# # Test UDP client # -----------------------------------------------------------------------# diff --git a/test/test_server_asyncio.py b/test/test_server_asyncio.py index 372c96479..4bcbe5ca7 100755 --- a/test/test_server_asyncio.py +++ b/test/test_server_asyncio.py @@ -11,7 +11,7 @@ from pymodbus.device import ModbusDeviceIdentification from pymodbus.factory import ServerDecoder from pymodbus.server.asynchronous import ModbusTcpProtocol, ModbusUdpProtocol -from pymodbus.server.asyncio import StartTcpServer, StartUdpServer, StartSerialServer, StopServer, ModbusServerFactory +from pymodbus.server.asyncio import StartTcpServer, StartTlsServer, StartUdpServer, StartSerialServer, StopServer, ModbusServerFactory from pymodbus.server.asyncio import ModbusConnectedRequestHandler, ModbusBaseRequestHandler from pymodbus.datastore import ModbusSequentialDataBlock from pymodbus.datastore import ModbusSlaveContext, ModbusServerContext @@ -20,6 +20,7 @@ from pymodbus.exceptions import NoSuchSlaveException, ModbusIOException import sys +import ssl #---------------------------------------------------------------------------# # Fixture #---------------------------------------------------------------------------# @@ -399,6 +400,54 @@ def eof_received(self): server.server_close() + #-----------------------------------------------------------------------# + # Test ModbusTlsProtocol + #-----------------------------------------------------------------------# + @asyncio.coroutine + def testStartTlsServer(self): + ''' Test that the modbus tls asyncio server starts correctly ''' + with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: + identity = ModbusDeviceIdentification(info={0x00: 'VendorName'}) + self.loop = asynctest.Mock(self.loop) + server = yield from StartTlsServer(context=self.context,loop=self.loop,identity=identity) + self.assertEqual(server.control.Identity.VendorName, 'VendorName') + self.assertIsNotNone(server.sslctx) + if PYTHON_VERSION >= (3, 6): + self.loop.create_server.assert_called_once() + + @pytest.mark.skipif(PYTHON_VERSION < (3, 7), reason="requires python3.7 or above") + @asyncio.coroutine + def testTlsServerServeNoDefer(self): + ''' Test StartTcpServer without deferred start (immediate execution of server) ''' + with patch('asyncio.base_events.Server.serve_forever', new_callable=asynctest.CoroutineMock) as serve: + with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: + server = yield from StartTlsServer(context=self.context,address=("127.0.0.1", 0), loop=self.loop, defer_start=False) + serve.assert_awaited() + + @pytest.mark.skipif(PYTHON_VERSION < (3, 7), reason="requires python3.7 or above") + @asyncio.coroutine + def testTlsServerServeForever(self): + ''' Test StartTcpServer serve_forever() method ''' + with patch('asyncio.base_events.Server.serve_forever', new_callable=asynctest.CoroutineMock) as serve: + with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: + server = yield from StartTlsServer(context=self.context,address=("127.0.0.1", 0), loop=self.loop) + yield from server.serve_forever() + serve.assert_awaited() + + @asyncio.coroutine + def testTlsServerServeForeverTwice(self): + ''' Call on serve_forever() twice should result in a runtime error ''' + with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: + server = yield from StartTlsServer(context=self.context,address=("127.0.0.1", 0), loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + with self.assertRaises(RuntimeError): + yield from server.serve_forever() + server.server_close() + #-----------------------------------------------------------------------# # Test ModbusUdpProtocol