diff --git a/.coveragerc b/.coveragerc index dbdb75230..472afaae6 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,4 +1,5 @@ [run] omit = pymodbus/repl/* - pymodbus/internal/* \ No newline at end of file + pymodbus/internal/* + pymodbus/server/asyncio.py \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 06a815834..b00c82637 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,14 +4,12 @@ matrix: include: - os: linux python: "2.7" - - os: linux - python: "3.4" - os: linux python: "3.5" - os: linux python: "3.6" -# - os: linux -# python: "3.7" + - os: linux + python: "3.7" - os: osx osx_image: xcode8.3 language: generic diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0e9a7ddf8..ac308fcae 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,20 @@ +Version 2.3.0 +----------------------------------------------------------- +* Support Modbus TLS (client / server) +* Distribute license with source +* BinaryPayloadDecoder/Encoder now supports float16 on python3.6 and above +* Fix asyncio UDP client/server +* Minor cosmetic updates + +Version 2.3.0rc1 +----------------------------------------------------------- +* Asyncio Server implementation (Python 3.7 and above only) +* Bug fix for DiagnosticStatusResponse when odd sized response is received +* Remove Pycrypto from dependencies and include cryptodome instead +* Remove `SIX` requirement pinned to exact version. +* Minor bug-fixes in documentations. + + Version 2.2.0 ----------------------------------------------------------- **NOTE: Supports python 3.7, async client is now moved to pymodbus/client/asychronous** diff --git a/doc/LICENSE b/LICENSE similarity index 100% rename from doc/LICENSE rename to LICENSE diff --git a/MANIFEST.in b/MANIFEST.in index 7b0fe70cc..9e8f9ae0b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ include requirements.txt include README.rst -include CHANGELOG.rst \ No newline at end of file +include CHANGELOG.rst +include LICENSE \ No newline at end of file diff --git a/README.rst b/README.rst index 2641414ef..95303f82b 100644 --- a/README.rst +++ b/README.rst @@ -99,7 +99,7 @@ get lost in the noise: http://groups.google.com/group/pymodbus or at gitter: https://gitter.im/pymodbus_dev/Lobby ------------------------------------------------------------ -Pymodbus REPL (Read Evaluate Procee Loop) +Pymodbus REPL (Read Evaluate Print Loop) ------------------------------------------------------------ Starting with Pymodbus 2.x, pymodbus library comes with handy Pymodbus REPL to quickly run the modbus clients in tcp/rtu modes. @@ -205,4 +205,4 @@ Pymodbus is built on top of code developed from/by: * Hynek Petrak, https://github.com/HynekPetrak * Twisted Matrix -Released under the BSD License +Released under the `BSD License `_ diff --git a/doc/source/library/REPL.md b/doc/source/library/REPL.md index 1e7969af6..48a426993 100644 --- a/doc/source/library/REPL.md +++ b/doc/source/library/REPL.md @@ -6,7 +6,7 @@ Depends on [prompt_toolkit](https://python-prompt-toolkit.readthedocs.io/en/stab Install dependencies ``` -$ pip install click prompt_toolkit --upgarde +$ pip install click prompt_toolkit --upgrade ``` Or diff --git a/examples/common/async_asyncio_client.py b/examples/common/async_asyncio_client.py index 5aefd6bd0..ab0844505 100644 --- a/examples/common/async_asyncio_client.py +++ b/examples/common/async_asyncio_client.py @@ -16,8 +16,8 @@ # Import the required asynchronous client # ----------------------------------------------------------------------- # from pymodbus.client.asynchronous.tcp import AsyncModbusTCPClient as ModbusClient - # from pymodbus.client.asynchronous.udp import ( - # AsyncModbusUDPClient as ModbusClient) + from pymodbus.client.asynchronous.udp import ( + AsyncModbusUDPClient as ModbusClient) from pymodbus.client.asynchronous import schedulers else: @@ -141,6 +141,7 @@ def run_with_not_running_loop(): log.debug("------------------------------------------------------") loop = asyncio.new_event_loop() assert not loop.is_running() + asyncio.set_event_loop(loop) new_loop, client = ModbusClient(schedulers.ASYNC_IO, port=5020, loop=loop) loop.run_until_complete(start_async_test(client.protocol)) loop.close() @@ -191,9 +192,12 @@ def run_with_no_loop(): ModbusClient Factory creates a loop. :return: """ + log.debug("---------------------RUN_WITH_NO_LOOP-----------------") loop, client = ModbusClient(schedulers.ASYNC_IO, port=5020) loop.run_until_complete(start_async_test(client.protocol)) loop.close() + log.debug("--------DONE RUN_WITH_NO_LOOP-------------") + log.debug("") if __name__ == '__main__': @@ -207,5 +211,5 @@ def run_with_no_loop(): # Run with already running loop run_with_already_running_loop() - log.debug("---------------------RUN_WITH_NO_LOOP-----------------") + log.debug("") diff --git a/examples/common/asynchronous_server.py b/examples/common/asynchronous_server.py index be42f3996..15e9b70c2 100755 --- a/examples/common/asynchronous_server.py +++ b/examples/common/asynchronous_server.py @@ -108,7 +108,7 @@ def run_async_server(): identity.VendorUrl = 'http://github.com/bashwork/pymodbus/' identity.ProductName = 'Pymodbus Server' identity.ModelName = 'Pymodbus Server' - identity.MajorMinorRevision = '2.2.0' + identity.MajorMinorRevision = '2.3.0' # ----------------------------------------------------------------------- # # run the server you want diff --git a/examples/common/asyncio_server.py b/examples/common/asyncio_server.py new file mode 100755 index 000000000..be0189b8a --- /dev/null +++ b/examples/common/asyncio_server.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python +""" +Pymodbus Asyncio Server Example +-------------------------------------------------------------------------- + +The asyncio server is implemented in pure python without any third +party libraries (unless you need to use the serial protocols which require +asyncio-pyserial). This is helpful in constrained or old environments where using +twisted is just not feasible. What follows is an example of its use: +""" +# --------------------------------------------------------------------------- # +# import the various server implementations +# --------------------------------------------------------------------------- # +import asyncio +from pymodbus.server.asyncio import StartTcpServer +from pymodbus.server.asyncio import StartUdpServer +from pymodbus.server.asyncio import StartSerialServer + +from pymodbus.device import ModbusDeviceIdentification +from pymodbus.datastore import ModbusSequentialDataBlock, ModbusSparseDataBlock +from pymodbus.datastore import ModbusSlaveContext, ModbusServerContext + +from pymodbus.transaction import ModbusRtuFramer, ModbusBinaryFramer +# --------------------------------------------------------------------------- # +# configure the service logging +# --------------------------------------------------------------------------- # +import logging +FORMAT = ('%(asctime)-15s %(threadName)-15s' + ' %(levelname)-8s %(module)-15s:%(lineno)-8s %(message)s') +logging.basicConfig(format=FORMAT) +log = logging.getLogger() +log.setLevel(logging.DEBUG) + + +async def run_server(): + # ----------------------------------------------------------------------- # + # initialize your data store + # ----------------------------------------------------------------------- # + # The datastores only respond to the addresses that they are initialized to + # Therefore, if you initialize a DataBlock to addresses of 0x00 to 0xFF, a + # request to 0x100 will respond with an invalid address exception. This is + # because many devices exhibit this kind of behavior (but not all):: + # + # block = ModbusSequentialDataBlock(0x00, [0]*0xff) + # + # Continuing, you can choose to use a sequential or a sparse DataBlock in + # your data context. The difference is that the sequential has no gaps in + # the data while the sparse can. Once again, there are devices that exhibit + # both forms of behavior:: + # + # block = ModbusSparseDataBlock({0x00: 0, 0x05: 1}) + # block = ModbusSequentialDataBlock(0x00, [0]*5) + # + # Alternately, you can use the factory methods to initialize the DataBlocks + # or simply do not pass them to have them initialized to 0x00 on the full + # address range:: + # + # store = ModbusSlaveContext(di = ModbusSequentialDataBlock.create()) + # store = ModbusSlaveContext() + # + # Finally, you are allowed to use the same DataBlock reference for every + # table or you may use a separate DataBlock for each table. + # This depends if you would like functions to be able to access and modify + # the same data or not:: + # + # block = ModbusSequentialDataBlock(0x00, [0]*0xff) + # store = ModbusSlaveContext(di=block, co=block, hr=block, ir=block) + # + # The server then makes use of a server context that allows the server to + # respond with different slave contexts for different unit ids. By default + # it will return the same context for every unit id supplied (broadcast + # mode). + # However, this can be overloaded by setting the single flag to False and + # then supplying a dictionary of unit id to context mapping:: + # + # slaves = { + # 0x01: ModbusSlaveContext(...), + # 0x02: ModbusSlaveContext(...), + # 0x03: ModbusSlaveContext(...), + # } + # context = ModbusServerContext(slaves=slaves, single=False) + # + # The slave context can also be initialized in zero_mode which means that a + # request to address(0-7) will map to the address (0-7). The default is + # False which is based on section 4.4 of the specification, so address(0-7) + # will map to (1-8):: + # + # store = ModbusSlaveContext(..., zero_mode=True) + # ----------------------------------------------------------------------- # + store = ModbusSlaveContext( + di=ModbusSequentialDataBlock(0, [17]*100), + co=ModbusSequentialDataBlock(0, [17]*100), + hr=ModbusSequentialDataBlock(0, [17]*100), + ir=ModbusSequentialDataBlock(0, [17]*100)) + + context = ModbusServerContext(slaves=store, single=True) + + # ----------------------------------------------------------------------- # + # initialize the server information + # ----------------------------------------------------------------------- # + # If you don't set this or any fields, they are defaulted to empty strings. + # ----------------------------------------------------------------------- # + identity = ModbusDeviceIdentification() + identity.VendorName = 'Pymodbus' + identity.ProductCode = 'PM' + identity.VendorUrl = 'http://github.com/riptideio/pymodbus/' + identity.ProductName = 'Pymodbus Server' + identity.ModelName = 'Pymodbus Server' + identity.MajorMinorRevision = '2.3.0' + + # ----------------------------------------------------------------------- # + # run the server you want + # ----------------------------------------------------------------------- # + # Tcp: + # immediately start serving: + await StartTcpServer(context, identity=identity, address=("0.0.0.0", 5020), allow_reuse_address=True, + defer_start=False) + + # deferred start: + # server = await StartTcpServer(context, identity=identity, address=("0.0.0.0", 5020), + # allow_reuse_address=True, defer_start=True) + # + # asyncio.get_event_loop().call_later(20, lambda : server.serve_forever) + # await server.serve_forever() + + # TCP with different framer + # StartTcpServer(context, identity=identity, + # framer=ModbusRtuFramer, address=("0.0.0.0", 5020)) + + # Udp: + # server = await StartUdpServer(context, identity=identity, address=("0.0.0.0", 5020), + # allow_reuse_address=True, defer_start=True) + # # + # await server.serve_forever() + + # !!! SERIAL SERVER NOT IMPLEMENTED !!! + # Ascii: + # StartSerialServer(context, identity=identity, + # port='/dev/ttyp0', timeout=1) + + # RTU: + # StartSerialServer(context, framer=ModbusRtuFramer, identity=identity, + # port='/dev/ttyp0', timeout=.005, baudrate=9600) + + # Binary + # StartSerialServer(context, + # identity=identity, + # framer=ModbusBinaryFramer, + # port='/dev/ttyp0', + # timeout=1) + + +if __name__ == "__main__": + asyncio.run(run_server()) + diff --git a/examples/common/callback_server.py b/examples/common/callback_server.py index d7f3a7bc4..325fbca56 100755 --- a/examples/common/callback_server.py +++ b/examples/common/callback_server.py @@ -132,7 +132,7 @@ def run_callback_server(): identity.VendorUrl = 'http://github.com/bashwork/pymodbus/' identity.ProductName = 'pymodbus Server' identity.ModelName = 'pymodbus Server' - identity.MajorMinorRevision = '2.2.0' + identity.MajorMinorRevision = '2.3.0' # ----------------------------------------------------------------------- # # run the server you want diff --git a/examples/common/custom_datablock.py b/examples/common/custom_datablock.py index 00b55fad0..089a27445 100755 --- a/examples/common/custom_datablock.py +++ b/examples/common/custom_datablock.py @@ -68,7 +68,7 @@ def run_custom_db_server(): identity.VendorUrl = 'http://github.com/bashwork/pymodbus/' identity.ProductName = 'pymodbus Server' identity.ModelName = 'pymodbus Server' - identity.MajorMinorRevision = '2.2.0' + identity.MajorMinorRevision = '2.3.0' # ----------------------------------------------------------------------- # # run the server you want diff --git a/examples/common/custom_synchronous_server.py b/examples/common/custom_synchronous_server.py index 191b6fac0..66f6f1b3c 100755 --- a/examples/common/custom_synchronous_server.py +++ b/examples/common/custom_synchronous_server.py @@ -101,7 +101,7 @@ def run_server(): identity.VendorUrl = 'http://github.com/riptideio/pymodbus/' identity.ProductName = 'Pymodbus Server' identity.ModelName = 'Pymodbus Server' - identity.MajorMinorRevision = '2.1.0' + identity.MajorMinorRevision = '2.3.0' # ----------------------------------------------------------------------- # # run the server you want diff --git a/examples/common/dbstore_update_server.py b/examples/common/dbstore_update_server.py index ba37520ac..ef467de0a 100644 --- a/examples/common/dbstore_update_server.py +++ b/examples/common/dbstore_update_server.py @@ -86,7 +86,7 @@ def run_dbstore_update_server(): identity.VendorUrl = 'http://github.com/bashwork/pymodbus/' identity.ProductName = 'pymodbus Server' identity.ModelName = 'pymodbus Server' - identity.MajorMinorRevision = '2.2.0' + identity.MajorMinorRevision = '2.3.0' # ----------------------------------------------------------------------- # # run the server you want diff --git a/examples/common/modbus_payload.py b/examples/common/modbus_payload.py index fad1f891f..aac1eee14 100755 --- a/examples/common/modbus_payload.py +++ b/examples/common/modbus_payload.py @@ -81,6 +81,8 @@ def run_binary_payload_ex(): builder.add_16bit_uint(0x1234) builder.add_32bit_int(-0x1234) builder.add_32bit_uint(0x12345678) + builder.add_16bit_float(12.34) + builder.add_16bit_float(-12.34) builder.add_32bit_float(22.34) builder.add_32bit_float(-22.34) builder.add_64bit_int(-0xDEADBEEF) @@ -144,6 +146,8 @@ def run_binary_payload_ex(): ('16uint', decoder.decode_16bit_uint()), ('32int', decoder.decode_32bit_int()), ('32uint', decoder.decode_32bit_uint()), + ('16float', decoder.decode_16bit_float()), + ('16float2', decoder.decode_16bit_float()), ('32float', decoder.decode_32bit_float()), ('32float2', decoder.decode_32bit_float()), ('64int', decoder.decode_64bit_int()), diff --git a/examples/common/modbus_payload_server.py b/examples/common/modbus_payload_server.py index d9d48d241..9f1cce5dc 100755 --- a/examples/common/modbus_payload_server.py +++ b/examples/common/modbus_payload_server.py @@ -78,7 +78,7 @@ def run_payload_server(): identity.VendorUrl = 'http://github.com/bashwork/pymodbus/' identity.ProductName = 'Pymodbus Server' identity.ModelName = 'Pymodbus Server' - identity.MajorMinorRevision = '2.2.0' + identity.MajorMinorRevision = '2.3.0' # ----------------------------------------------------------------------- # # run the server you want # ----------------------------------------------------------------------- # diff --git a/examples/common/synchronous_client.py b/examples/common/synchronous_client.py index 89f3f509d..e9fbfad5f 100755 --- a/examples/common/synchronous_client.py +++ b/examples/common/synchronous_client.py @@ -16,9 +16,9 @@ # --------------------------------------------------------------------------- # # import the various server implementations # --------------------------------------------------------------------------- # -# from pymodbus.client.sync import ModbusTcpClient as ModbusClient +from pymodbus.client.sync import ModbusTcpClient as ModbusClient # from pymodbus.client.sync import ModbusUdpClient as ModbusClient -from pymodbus.client.sync import ModbusSerialClient as ModbusClient +# from pymodbus.client.sync import ModbusSerialClient as ModbusClient # --------------------------------------------------------------------------- # # configure the client logging diff --git a/examples/common/synchronous_server.py b/examples/common/synchronous_server.py index d3e53b23a..e93d33a5f 100755 --- a/examples/common/synchronous_server.py +++ b/examples/common/synchronous_server.py @@ -12,6 +12,7 @@ # import the various server implementations # --------------------------------------------------------------------------- # from pymodbus.server.sync import StartTcpServer +from pymodbus.server.sync import StartTlsServer from pymodbus.server.sync import StartUdpServer from pymodbus.server.sync import StartSerialServer @@ -105,7 +106,7 @@ def run_server(): identity.VendorUrl = 'http://github.com/riptideio/pymodbus/' identity.ProductName = 'Pymodbus Server' identity.ModelName = 'Pymodbus Server' - identity.MajorMinorRevision = '2.2.0' + identity.MajorMinorRevision = '2.3.0' # ----------------------------------------------------------------------- # # run the server you want @@ -117,6 +118,10 @@ def run_server(): # StartTcpServer(context, identity=identity, # framer=ModbusRtuFramer, address=("0.0.0.0", 5020)) + # TLS + # StartTlsServer(context, identity=identity, certfile="server.crt", + # keyfile="server.key", address=("0.0.0.0", 8020)) + # Udp: # StartUdpServer(context, identity=identity, address=("0.0.0.0", 5020)) diff --git a/examples/common/updating_server.py b/examples/common/updating_server.py index 1393712a9..b5b04faa3 100755 --- a/examples/common/updating_server.py +++ b/examples/common/updating_server.py @@ -78,7 +78,7 @@ def run_updating_server(): identity.VendorUrl = 'http://github.com/bashwork/pymodbus/' identity.ProductName = 'pymodbus Server' identity.ModelName = 'pymodbus Server' - identity.MajorMinorRevision = '2.2.0' + identity.MajorMinorRevision = '2.3.0' # ----------------------------------------------------------------------- # # run the server you want diff --git a/examples/contrib/deviceinfo_showcase_server.py b/examples/contrib/deviceinfo_showcase_server.py index 53bc753a0..983bb7111 100755 --- a/examples/contrib/deviceinfo_showcase_server.py +++ b/examples/contrib/deviceinfo_showcase_server.py @@ -55,7 +55,7 @@ def run_server(): identity.VendorUrl = 'http://github.com/riptideio/pymodbus/' identity.ProductName = 'Pymodbus Server' identity.ModelName = 'Pymodbus Server' - identity.MajorMinorRevision = '2.2.0' + identity.MajorMinorRevision = '2.3.0' # ----------------------------------------------------------------------- # # Add an example which is long enough to force the ReadDeviceInformation diff --git a/examples/contrib/modbus_tls_client.py b/examples/contrib/modbus_tls_client.py new file mode 100755 index 000000000..98ad02a12 --- /dev/null +++ b/examples/contrib/modbus_tls_client.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +""" +Simple Modbus TCP over TLS client +--------------------------------------------------------------------------- + +This is a simple example of writing a modbus TCP over TLS client that uses +Python builtin module ssl - TLS/SSL wrapper for socket objects for the TLS +feature. +""" +# -------------------------------------------------------------------------- # +# import neccessary libraries +# -------------------------------------------------------------------------- # +import ssl +from pymodbus.client.sync import ModbusTlsClient + +# -------------------------------------------------------------------------- # +# the TLS detail security can be set in SSLContext which is the context here +# -------------------------------------------------------------------------- # +context = ssl.create_default_context() +context.options |= ssl.OP_NO_SSLv2 +context.options |= ssl.OP_NO_SSLv3 +context.options |= ssl.OP_NO_TLSv1 +context.options |= ssl.OP_NO_TLSv1_1 + +# -------------------------------------------------------------------------- # +# pass SSLContext which is the context here to ModbusTcpClient() +# -------------------------------------------------------------------------- # +client = ModbusTlsClient('test.host.com', 8020, sslctx=context) +client.connect() + +result = client.read_coils(1, 8) +print(result.bits) + +client.close() diff --git a/examples/gui/bottle/frontend.py b/examples/gui/bottle/frontend.py index c3e7c10c0..3e79e0b46 100644 --- a/examples/gui/bottle/frontend.py +++ b/examples/gui/bottle/frontend.py @@ -277,7 +277,7 @@ def RunDebugModbusFrontend(server, port=8080): identity.VendorUrl = 'http://github.com/bashwork/pymodbus/' identity.ProductName = 'Pymodbus Server' identity.ModelName = 'Pymodbus Server' - identity.MajorMinorRevision = '2.2.0' + identity.MajorMinorRevision = '2.3.0' # ------------------------------------------------------------ # initialize the datastore diff --git a/pymodbus/client/asynchronous/asyncio/__init__.py b/pymodbus/client/asynchronous/asyncio/__init__.py index baa9f71ec..d83f6eeee 100644 --- a/pymodbus/client/asynchronous/asyncio/__init__.py +++ b/pymodbus/client/asynchronous/asyncio/__init__.py @@ -115,6 +115,9 @@ def connected(self): """ return self._connected + def write_transport(self, packet): + return self.transport.write(packet) + def execute(self, request, **kwargs): """ Starts the producer to send the next request to @@ -123,7 +126,7 @@ def execute(self, request, **kwargs): request.transaction_id = self.transaction.getNextTID() packet = self.framer.buildPacket(request) _logger.debug("send: " + " ".join([hex(byte2int(x)) for x in packet])) - self.transport.write(packet) + self.write_transport(packet) return self._buildResponse(request.transaction_id) def _dataReceived(self, data): @@ -206,6 +209,9 @@ def __init__(self, host=None, port=0, **kwargs): def datagram_received(self, data, addr): self._dataReceived(data) + def write_transport(self, packet): + return self.transport.sendto(packet) + class ReconnectingAsyncioModbusTcpClient(object): """ @@ -713,7 +719,7 @@ def connect(self): yield from create_serial_connection( self.loop, self._create_protocol, self.port, baudrate=self.baudrate, - bytesize=self.bytesize, stopbits=self.stopbits + bytesize=self.bytesize, stopbits=self.stopbits, parity=self.parity ) yield from self._connected_event.wait() _logger.info('Connected to %s', self.port) diff --git a/pymodbus/client/asynchronous/factory/udp.py b/pymodbus/client/asynchronous/factory/udp.py index d6dc75ed5..6578732e3 100644 --- a/pymodbus/client/asynchronous/factory/udp.py +++ b/pymodbus/client/asynchronous/factory/udp.py @@ -69,7 +69,11 @@ def async_io_factory(host="127.0.0.1", port=Defaults.Port, framer=None, loop = kwargs.get("loop") or asyncio.get_event_loop() proto_cls = kwargs.get("proto_cls", None) cor = init_udp_client(proto_cls, loop, host, port) - client = loop.run_until_complete(asyncio.gather(cor))[0] + if not loop.is_running(): + client = loop.run_until_complete(asyncio.gather(cor))[0] + else: + client = asyncio.run_coroutine_threadsafe(cor, loop=loop) + client = client.result() return loop, client diff --git a/pymodbus/client/sync.py b/pymodbus/client/sync.py index 04d7778e3..b3b3e197f 100644 --- a/pymodbus/client/sync.py +++ b/pymodbus/client/sync.py @@ -2,6 +2,7 @@ import select import serial import time +import ssl import sys from functools import partial from pymodbus.constants import Defaults @@ -13,6 +14,7 @@ from pymodbus.transaction import DictTransactionManager from pymodbus.transaction import ModbusSocketFramer, ModbusBinaryFramer from pymodbus.transaction import ModbusAsciiFramer, ModbusRtuFramer +from pymodbus.transaction import ModbusTlsFramer from pymodbus.client.common import ModbusClientMixin # --------------------------------------------------------------------------- # @@ -260,18 +262,21 @@ def _recv(self, size): else: recv_size = size - data = b'' + data = [] + data_length = 0 time_ = time.time() end = time_ + timeout while recv_size > 0: ready = select.select([self.socket], [], [], end - time_) if ready[0]: - data += self.socket.recv(recv_size) + recv_data = self.socket.recv(recv_size) + data.append(recv_data) + data_length += len(recv_data) time_ = time.time() # If size isn't specified continue to read until timeout expires. if size: - recv_size = size - len(data) + recv_size = size - data_length # Timeout is reduced also if some data has been received in order # to avoid infinite loops when there isn't an expected response @@ -279,7 +284,7 @@ def _recv(self, size): if time_ > end: break - return data + return b"".join(data) def is_socket_open(self): return True if self.socket is not None else False @@ -297,6 +302,116 @@ def __repr__(self): "port={self.port}, timeout={self.timeout}>" ).format(self.__class__.__name__, hex(id(self)), self=self) +# --------------------------------------------------------------------------- # +# Modbus TLS Client Transport Implementation +# --------------------------------------------------------------------------- # + + +class ModbusTlsClient(ModbusTcpClient): + """ Implementation of a modbus tls client + """ + + def __init__(self, host='localhost', port=Defaults.TLSPort, sslctx=None, + framer=ModbusTlsFramer, **kwargs): + """ Initialize a client instance + + :param host: The host to connect to (default localhost) + :param port: The modbus port to connect to (default 802) + :param sslctx: The SSLContext to use for TLS (default None and auto create) + :param source_address: The source address tuple to bind to (default ('', 0)) + :param timeout: The timeout to use for this socket (default Defaults.Timeout) + :param framer: The modbus framer to use (default ModbusSocketFramer) + + .. note:: The host argument will accept ipv4 and ipv6 hosts + """ + self.sslctx = sslctx + if self.sslctx is None: + self.sslctx = ssl.create_default_context() + # According to MODBUS/TCP Security Protocol Specification, it is + # TLSv2 at least + self.sslctx.options |= ssl.OP_NO_TLSv1_1 + self.sslctx.options |= ssl.OP_NO_TLSv1 + self.sslctx.options |= ssl.OP_NO_SSLv3 + self.sslctx.options |= ssl.OP_NO_SSLv2 + ModbusTcpClient.__init__(self, host, port, framer, **kwargs) + + def connect(self): + """ Connect to the modbus tls server + + :returns: True if connection succeeded, False otherwise + """ + if self.socket: return True + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(self.source_address) + self.socket = self.sslctx.wrap_socket(sock, server_side=False, + server_hostname=self.host) + self.socket.settimeout(self.timeout) + self.socket.connect((self.host, self.port)) + except socket.error as msg: + _logger.error('Connection to (%s, %s) ' + 'failed: %s' % (self.host, self.port, msg)) + self.close() + return self.socket is not None + + def _recv(self, size): + """ Reads data from the underlying descriptor + + :param size: The number of bytes to read + :return: The bytes read + """ + if not self.socket: + raise ConnectionException(self.__str__()) + + # socket.recv(size) waits until it gets some data from the host but + # not necessarily the entire response that can be fragmented in + # many packets. + # To avoid the splitted responses to be recognized as invalid + # messages and to be discarded, loops socket.recv until full data + # is received or timeout is expired. + # If timeout expires returns the read data, also if its length is + # less than the expected size. + timeout = self.timeout + + # If size isn't specified read 1 byte at a time. + if size is None: + recv_size = 1 + else: + recv_size = size + + data = b'' + time_ = time.time() + end = time_ + timeout + while recv_size > 0: + data += self.socket.recv(recv_size) + time_ = time.time() + + # If size isn't specified continue to read until timeout expires. + if size: + recv_size = size - len(data) + + # Timeout is reduced also if some data has been received in order + # to avoid infinite loops when there isn't an expected response + # size and the slave sends noisy data continuosly. + if time_ > end: + break + + return data + + def __str__(self): + """ Builds a string representation of the connection + + :returns: The string representation + """ + return "ModbusTlsClient(%s:%s)" % (self.host, self.port) + + def __repr__(self): + return ( + "<{} at {} socket={self.socket}, ipaddr={self.host}, " + "port={self.port}, sslctx={self.sslctx}, timeout={self.timeout}>" + ).format(self.__class__.__name__, hex(id(self)), self=self) + + # --------------------------------------------------------------------------- # # Modbus UDP Client Transport Implementation # --------------------------------------------------------------------------- # @@ -470,6 +585,7 @@ def connect(self): :returns: True if connection succeeded, False otherwise """ + import serial if self.socket: return True try: @@ -593,5 +709,5 @@ def __repr__(self): __all__ = [ - "ModbusTcpClient", "ModbusUdpClient", "ModbusSerialClient" + "ModbusTcpClient", "ModbusTlsClient", "ModbusUdpClient", "ModbusSerialClient" ] diff --git a/pymodbus/constants.py b/pymodbus/constants.py index c05f0b555..fc26f2e07 100644 --- a/pymodbus/constants.py +++ b/pymodbus/constants.py @@ -15,6 +15,10 @@ class Defaults(Singleton): The default modbus tcp server port (502) + .. attribute:: TLSPort + + The default modbus tcp over tls server port (802) + .. attribute:: Retries The default number of times a client should retry the given @@ -99,6 +103,7 @@ class Defaults(Singleton): ''' Port = 502 + TLSPort = 802 Retries = 3 RetryOnEmpty = False Timeout = 3 diff --git a/pymodbus/diag_message.py b/pymodbus/diag_message.py index f6c02cb1c..b76782c38 100644 --- a/pymodbus/diag_message.py +++ b/pymodbus/diag_message.py @@ -120,6 +120,7 @@ def decode(self, data): word_len = len(data)//2 if len(data) % 2: word_len += 1 + data = data + b'0' data = struct.unpack('>' + 'H'*word_len, data) self.sub_function_code, self.message = data[0], data[1:] diff --git a/pymodbus/framer/__init__.py b/pymodbus/framer/__init__.py index 5d84cc412..4859d0104 100644 --- a/pymodbus/framer/__init__.py +++ b/pymodbus/framer/__init__.py @@ -8,6 +8,8 @@ # Transaction Id, Protocol ID, Length, Unit ID, Function Code SOCKET_FRAME_HEADER = BYTE_ORDER + 'HHH' + FRAME_HEADER +# Function Code +TLS_FRAME_HEADER = BYTE_ORDER + 'B' class ModbusFramer(IModbusFramer): """ diff --git a/pymodbus/framer/tls_framer.py b/pymodbus/framer/tls_framer.py new file mode 100644 index 000000000..33bd48d89 --- /dev/null +++ b/pymodbus/framer/tls_framer.py @@ -0,0 +1,185 @@ +import struct +from pymodbus.exceptions import ModbusIOException +from pymodbus.exceptions import InvalidMessageReceivedException +from pymodbus.utilities import hexlify_packets +from pymodbus.framer import ModbusFramer, TLS_FRAME_HEADER + +# --------------------------------------------------------------------------- # +# Logging +# --------------------------------------------------------------------------- # +import logging +_logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- # +# Modbus TLS Message +# --------------------------------------------------------------------------- # + + +class ModbusTlsFramer(ModbusFramer): + """ Modbus TLS Frame controller + + No prefix MBAP header before decrypted PDU is used as a message frame for + Modbus Security Application Protocol. It allows us to easily separate + decrypted messages which is PDU as follows: + + [ Function Code] [ Data ] + 1b Nb + """ + + def __init__(self, decoder, client=None): + """ Initializes a new instance of the framer + + :param decoder: The decoder factory implementation to use + """ + self._buffer = b'' + self._header = {} + self._hsize = 0x0 + self.decoder = decoder + self.client = client + + # ----------------------------------------------------------------------- # + # Private Helper Functions + # ----------------------------------------------------------------------- # + def checkFrame(self): + """ + Check and decode the next frame Return true if we were successful + """ + if self.isFrameReady(): + # we have at least a complete message, continue + if len(self._buffer) - self._hsize >= 1: + return True + # we don't have enough of a message yet, wait + return False + + def advanceFrame(self): + """ Skip over the current framed message + This allows us to skip over the current message after we have processed + it or determined that it contains an error. It also has to reset the + current frame header handle + """ + self._buffer = b'' + self._header = {} + + def isFrameReady(self): + """ Check if we should continue decode logic + This is meant to be used in a while loop in the decoding phase to let + the decoder factory know that there is still data in the buffer. + + :returns: True if ready, False otherwise + """ + return len(self._buffer) > self._hsize + + def addToFrame(self, message): + """ Adds new packet data to the current frame buffer + + :param message: The most recent packet + """ + self._buffer += message + + def getFrame(self): + """ Return the next frame from the buffered data + + :returns: The next full frame buffer + """ + return self._buffer[self._hsize:] + + def populateResult(self, result): + """ + Populates the modbus result with the transport specific header + information (no header before PDU in decrypted message) + + :param result: The response packet + """ + return + + # ----------------------------------------------------------------------- # + # Public Member Functions + # ----------------------------------------------------------------------- # + def decode_data(self, data): + if len(data) > self._hsize: + (fcode,) = struct.unpack(TLS_FRAME_HEADER, data[0:self._hsize+1]) + return dict(fcode=fcode) + return dict() + + def processIncomingPacket(self, data, callback, unit, **kwargs): + """ + The new packet processing pattern + + This takes in a new request packet, adds it to the current + packet stream, and performs framing on it. That is, checks + for complete messages, and once found, will process all that + exist. This handles the case when we read N + 1 or 1 // N + messages at a time instead of 1. + + The processed and decoded messages are pushed to the callback + function to process and send. + + :param data: The new packet data + :param callback: The function to send results to + :param unit: Process if unit id matches, ignore otherwise (could be a + list of unit ids (server) or single unit id(client/server) + :param single: True or False (If True, ignore unit address validation) + :return: + """ + if not isinstance(unit, (list, tuple)): + unit = [unit] + # no unit id for Modbus Security Application Protocol + single = kwargs.get("single", True) + _logger.debug("Processing: " + hexlify_packets(data)) + self.addToFrame(data) + + if self.isFrameReady(): + if self.checkFrame(): + if self._validate_unit_id(unit, single): + self._process(callback) + else: + _logger.debug("Not in valid unit id - {}, " + "ignoring!!".format(unit)) + self.resetFrame() + else: + _logger.debug("Frame check failed, ignoring!!") + self.resetFrame() + + def _process(self, callback, error=False): + """ + Process incoming packets irrespective error condition + """ + data = self.getRawFrame() if error else self.getFrame() + result = self.decoder.decode(data) + if result is None: + raise ModbusIOException("Unable to decode request") + elif error and result.function_code < 0x80: + raise InvalidMessageReceivedException(result) + else: + self.populateResult(result) + self.advanceFrame() + callback(result) # defer or push to a thread? + + def resetFrame(self): + """ + Reset the entire message frame. + This allows us to skip ovver errors that may be in the stream. + It is hard to know if we are simply out of sync or if there is + an error in the stream as we have no way to check the start or + end of the message (python just doesn't have the resolution to + check for millisecond delays). + """ + self._buffer = b'' + + def getRawFrame(self): + """ + Returns the complete buffer + """ + return self._buffer + + def buildPacket(self, message): + """ Creates a ready to send modbus packet + + :param message: The populated request/response to send + """ + data = message.encode() + packet = struct.pack(TLS_FRAME_HEADER, message.function_code) + packet += data + return packet + +# __END__ diff --git a/pymodbus/payload.py b/pymodbus/payload.py index f97d434a4..15aa66dc1 100644 --- a/pymodbus/payload.py +++ b/pymodbus/payload.py @@ -14,7 +14,7 @@ from pymodbus.utilities import unpack_bitstring from pymodbus.utilities import make_byte_string from pymodbus.exceptions import ParameterException -from pymodbus.compat import unicode_string +from pymodbus.compat import unicode_string, IS_PYTHON3, PYTHON_VERSION # --------------------------------------------------------------------------- # # Logging # --------------------------------------------------------------------------- # @@ -25,6 +25,7 @@ WC = { "b": 1, "h": 2, + "e": 2, "i": 4, "l": 4, "q": 8, @@ -229,6 +230,18 @@ def add_64bit_int(self, value): p_string = self._pack_words(fstring, value) self._payload.append(p_string) + def add_16bit_float(self, value): + """ Adds a 16 bit float to the buffer + + :param value: The value to add to the buffer + """ + if IS_PYTHON3 and PYTHON_VERSION.minor >= 6: + fstring = 'e' + p_string = self._pack_words(fstring, value) + self._payload.append(p_string) + else: + _logger.warning("float16 only supported on python3.6 and above!!!") + def add_32bit_float(self, value): """ Adds a 32 bit float to the buffer @@ -443,6 +456,18 @@ def decode_64bit_int(self): handle = self._unpack_words(fstring, handle) return unpack("!"+fstring, handle)[0] + def decode_16bit_float(self): + """ Decodes a 16 bit float from the buffer + """ + if IS_PYTHON3 and PYTHON_VERSION.minor >= 6: + self._pointer += 2 + fstring = 'e' + handle = self._payload[self._pointer - 2:self._pointer] + handle = self._unpack_words(fstring, handle) + return unpack("!"+fstring, handle)[0] + else: + _logger.warning("float16 only supported on python3.6 and above!!!") + def decode_32bit_float(self): """ Decodes a 32 bit float from the buffer """ diff --git a/pymodbus/pdu.py b/pymodbus/pdu.py index 8f5e3cea9..2c7c55d84 100644 --- a/pymodbus/pdu.py +++ b/pymodbus/pdu.py @@ -19,7 +19,7 @@ # --------------------------------------------------------------------------- # class ModbusPDU(object): """ - Base class for all Modbus mesages + Base class for all Modbus messages .. attribute:: transaction_id diff --git a/pymodbus/repl/README.md b/pymodbus/repl/README.md index d5e970861..bea1d25bc 100644 --- a/pymodbus/repl/README.md +++ b/pymodbus/repl/README.md @@ -41,6 +41,7 @@ Usage: pymodbus.console tcp [OPTIONS] Options: --host TEXT Modbus TCP IP --port INTEGER Modbus TCP port + --framer TEXT Override the default packet framer tcp|rtu --help Show this message and exit. diff --git a/pymodbus/repl/helper.py b/pymodbus/repl/helper.py index 7f255a7e6..38a29e9df 100644 --- a/pymodbus/repl/helper.py +++ b/pymodbus/repl/helper.py @@ -33,6 +33,7 @@ 'uint16': 'decode_16bit_uint', 'uint32': 'decode_32bit_uint', 'uint64': 'decode_64bit_int', + 'float16': 'decode_16bit_float', 'float32': 'decode_32bit_float', 'float64': 'decode_64bit_float', } diff --git a/pymodbus/repl/main.py b/pymodbus/repl/main.py index cd13f29d1..d8149368a 100644 --- a/pymodbus/repl/main.py +++ b/pymodbus/repl/main.py @@ -250,9 +250,19 @@ def main(ctx, verbose): type=int, help="Modbus TCP port", ) -def tcp(ctx, host, port): +@click.option( + "--framer", + default='tcp', + type=str, + help="Override the default packet framer tcp|rtu", +) +def tcp(ctx, host, port, framer): from pymodbus.repl.client import ModbusTcpClient - client = ModbusTcpClient(host=host, port=port) + kwargs = dict(host=host, port=port) + if framer == 'rtu': + from pymodbus.framer.rtu_framer import ModbusRtuFramer + kwargs['framer'] = ModbusRtuFramer + client = ModbusTcpClient(**kwargs) cli(client) diff --git a/pymodbus/server/asyncio.py b/pymodbus/server/asyncio.py new file mode 100755 index 000000000..50ccf97d1 --- /dev/null +++ b/pymodbus/server/asyncio.py @@ -0,0 +1,642 @@ +""" +Implementation of a Threaded Modbus Server +------------------------------------------ + +""" +from binascii import b2a_hex +import socket +import traceback + +import asyncio +from pymodbus.compat import PYTHON_VERSION +from pymodbus.constants import Defaults +from pymodbus.utilities import hexlify_packets +from pymodbus.factory import ServerDecoder +from pymodbus.datastore import ModbusServerContext +from pymodbus.device import ModbusControlBlock +from pymodbus.device import ModbusDeviceIdentification +from pymodbus.transaction import * +from pymodbus.exceptions import NotImplementedException, NoSuchSlaveException +from pymodbus.pdu import ModbusExceptions as merror +from pymodbus.compat import socketserver, byte2int + +# --------------------------------------------------------------------------- # +# Logging +# --------------------------------------------------------------------------- # +import logging +_logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- # +# Protocol Handlers +# --------------------------------------------------------------------------- # + +class ModbusBaseRequestHandler(asyncio.BaseProtocol): + """ Implements modbus slave wire protocol + This uses the asyncio.Protocol to implement the client handler. + + When a connection is established, the asyncio.Protocol.connection_made + callback is called. This callback will setup the connection and + create and schedule an asyncio.Task and assign it to running_task. + + running_task will be canceled upon connection_lost event. + """ + def __init__(self, owner): + self.server = owner + self.running = False + self.receive_queue = asyncio.Queue() + self.handler_task = None # coroutine to be run on asyncio loop + + def connection_made(self, transport): + """ + asyncio.BaseProtocol callback for socket establish + + For streamed protocols (TCP) this will also correspond to an + entire conversation; however for datagram protocols (UDP) this + corresponds to the socket being opened + """ + try: + _logger.debug("Socket [%s:%s] opened" % transport.get_extra_info('sockname')) + self.transport = transport + self.running = True + self.framer = self.server.framer(self.server.decoder, client=None) + + # schedule the connection handler on the event loop + if PYTHON_VERSION >= (3, 7): + self.handler_task = asyncio.create_task(self.handle()) + else: + self.handler_task = asyncio.ensure_future(self.handle()) + except Exception as ex: # pragma: no cover + _logger.debug("Datastore unable to fulfill request: " + "%s; %s", ex, traceback.format_exc()) + + def connection_lost(self, exc): + """ + asyncio.BaseProtocol callback for socket tear down + + For streamed protocols any break in the network connection will + be reported here; for datagram protocols, only a teardown of the + socket itself will result in this call. + """ + try: + self.handler_task.cancel() + + if exc is None: + if hasattr(self, "client_address"): # TCP connection + _logger.debug("Disconnected from client [%s:%s]" % self.client_address) + else: + _logger.debug("Disconnected from client [%s]" % self.transport.get_extra_info("peername")) + else: # pragma: no cover + __logger.debug("Client Disconnection [%s:%s] due to %s" % (*self.client_address, exc)) + + self.running = False + + except Exception as ex: # pragma: no cover + _logger.debug("Datastore unable to fulfill request: " + "%s; %s", ex, traceback.format_exc()) + + async def handle(self): + """Asyncio coroutine which represents a single conversation between + the modbus slave and master + + Once the client connection is established, the data chunks will be + fed to this coroutine via the asyncio.Queue object which is fed by + the ModbusBaseRequestHandler class's callback Future. + + This callback future gets data from either asyncio.DatagramProtocol.datagram_received + or from asyncio.BaseProtocol.data_received. + + This function will execute without blocking in the while-loop and + yield to the asyncio event loop when the frame is exhausted. + As a result, multiple clients can be interleaved without any + interference between them. + + For ModbusConnectedRequestHandler, each connection will be given an + instance of the handle() coroutine and this instance will be put in the + active_connections dict. Calling server_close will individually cancel + each running handle() task. + + For ModbusDisconnectedRequestHandler, a single handle() coroutine will + be started and maintained. Calling server_close will cancel that task. + + """ + reset_frame = False + while self.running: + try: + units = self.server.context.slaves() + data = await self._recv_() # this is an asyncio.Queue await, it will never fail + if isinstance(data, tuple): + data, *addr = data # addr is populated when talking over UDP + else: + addr = (None,) # empty tuple + + if not isinstance(units, (list, tuple)): + units = [units] + # if broadcast is enabled make sure to + # process requests to address 0 + if self.server.broadcast_enable: # pragma: no cover + if 0 not in units: + units.append(0) + + if _logger.isEnabledFor(logging.DEBUG): + _logger.debug('Handling data: ' + hexlify_packets(data)) + + single = self.server.context.single + self.framer.processIncomingPacket(data=data, + callback=lambda x: self.execute(x, *addr), + unit=units, + single=single) + + except asyncio.CancelledError: + # catch and ignore cancelation errors + if isinstance(self, ModbusConnectedRequestHandler): + _logger.debug("Handler for stream [%s:%s] has been canceled" % self.client_address) + else: + _logger.debug("Handler for UDP socket [%s] has been canceled" % self.protocol._sock.getsockname()[1]) + + except Exception as e: + # force TCP socket termination as processIncomingPacket should handle applicaiton layer errors + # for UDP sockets, simply reset the frame + if isinstance(self, ModbusConnectedRequestHandler): + _logger.info("Unknown exception '%s' on stream [%s:%s] forcing disconnect" % (e, *self.client_address)) + self.transport.close() + else: + _logger.error("Unknown error occurred %s" % e) + reset_frame = True # graceful recovery + finally: + if reset_frame: + self.framer.resetFrame() + reset_frame = False + + def execute(self, request, *addr): + """ The callback to call with the resulting message + + :param request: The decoded request message + """ + broadcast = False + try: + if self.server.broadcast_enable and request.unit_id == 0: + broadcast = True + # if broadcasting then execute on all slave contexts, note response will be ignored + for unit_id in self.server.context.slaves(): + response = request.execute(self.server.context[unit_id]) + else: + context = self.server.context[request.unit_id] + response = request.execute(context) + except NoSuchSlaveException as ex: + _logger.debug("requested slave does " + "not exist: %s" % request.unit_id ) + if self.server.ignore_missing_slaves: + return # the client will simply timeout waiting for a response + response = request.doException(merror.GatewayNoResponse) + except Exception as ex: + _logger.debug("Datastore unable to fulfill request: " + "%s; %s", ex, traceback.format_exc()) + response = request.doException(merror.SlaveFailure) + # no response when broadcasting + if not broadcast: + response.transaction_id = request.transaction_id + response.unit_id = request.unit_id + self.send(response, *addr) + + + def send(self, message, *addr): + if message.should_respond: + # self.server.control.Counter.BusMessage += 1 + pdu = self.framer.buildPacket(message) + if _logger.isEnabledFor(logging.DEBUG): + _logger.debug('send: [%s]- %s' % (message, b2a_hex(pdu))) + if addr == (None,): + self._send_(pdu) + else: + self._send_(pdu, *addr) + + # ----------------------------------------------------------------------- # + # Derived class implementations + # ----------------------------------------------------------------------- # + + def _send_(self, data): # pragma: no cover + """ Send a request (string) to the network + + :param message: The unencoded modbus response + """ + raise NotImplementedException("Method not implemented " + "by derived class") + async def _recv_(self): # pragma: no cover + """ Receive data from the network + + :return: + """ + raise NotImplementedException("Method not implemented " + "by derived class") + + +class ModbusConnectedRequestHandler(ModbusBaseRequestHandler,asyncio.Protocol): + """ Implements the modbus server protocol + + This uses asyncio.Protocol to implement + the client handler for a connected protocol (TCP). + """ + + def connection_made(self, transport): + """ asyncio.BaseProtocol: Called when a connection is made. """ + super().connection_made(transport) + + self.client_address = transport.get_extra_info('peername') + self.server.active_connections[self.client_address] = self + _logger.debug("TCP client connection established [%s:%s]" % self.client_address) + + def connection_lost(self, exc): + """ asyncio.BaseProtocol: Called when the connection is lost or closed.""" + super().connection_lost(exc) + _logger.debug("TCP client disconnected [%s:%s]" % self.client_address) + if self.client_address in self.server.active_connections: + self.server.active_connections.pop(self.client_address) + + + def data_received(self,data): + """ + asyncio.Protocol: (TCP) Called when some data is received. + data is a non-empty bytes object containing the incoming data. + """ + self.receive_queue.put_nowait(data) + + async def _recv_(self): + return await self.receive_queue.get() + + def _send_(self, data): + """ tcp send """ + self.transport.write(data) + + +class ModbusDisconnectedRequestHandler(ModbusBaseRequestHandler, asyncio.DatagramProtocol): + """ Implements the modbus server protocol + + This uses the socketserver.BaseRequestHandler to implement + the client handler for a disconnected protocol (UDP). The + only difference is that we have to specify who to send the + resulting packet data to. + """ + def __init__(self,owner): + super().__init__(owner) + self.server.on_connection_terminated = asyncio.get_event_loop().create_future() + + def connection_lost(self,exc): + super().connection_lost(exc) + self.server.on_connection_terminated.set_result(True) + + def datagram_received(self,data, addr): + """ + asyncio.DatagramProtocol: Called when a datagram is received. + data is a bytes object containing the incoming data. addr + is the address of the peer sending the data; the exact + format depends on the transport. + """ + self.receive_queue.put_nowait((data, addr)) + + def error_received(self,exc): # pragma: no cover + """ + asyncio.DatagramProtocol: Called when a previous send + or receive operation raises an OSError. exc is the + OSError instance. + + This method is called in rare conditions, + when the transport (e.g. UDP) detects that a datagram could + not be delivered to its recipient. In many conditions + though, undeliverable datagrams will be silently dropped. + """ + _logger.error("datagram connection error [%s]" % exc) + + async def _recv_(self): + return await self.receive_queue.get() + + def _send_(self, data, addr): + self.transport.sendto(data, addr=addr) + +class ModbusServerFactory: + """ + Builder class for a modbus server + + This also holds the server datastore so that it is persisted between connections + """ + + def __init__(self, store, framer=None, identity=None, **kwargs): + import warnings + warnings.warn("deprecated API for asyncio. ServerFactory's are a twisted construct and don't have an equivalent in asyncio", + DeprecationWarning) + + +# --------------------------------------------------------------------------- # +# Server Implementations +# --------------------------------------------------------------------------- # +class ModbusTcpServer: + """ + A modbus threaded tcp socket server + + We inherit and overload the socket server so that we + can control the client threads as well as have a single + server context instance. + """ + + def __init__(self, + context, + framer=None, + identity=None, + address=None, + handler=None, + allow_reuse_address=False, + allow_reuse_port=False, + defer_start=False, + backlog=20, + loop=None, + **kwargs): + """ Overloaded initializer for the socket server + + If the identify structure is not passed in, the ModbusControlBlock + uses its own empty structure. + + :param context: The ModbusServerContext datastore + :param framer: The framer strategy to use + :param identity: An optional identify structure + :param address: An optional (interface, port) to bind to. + :param handler: A handler for each client session; default is + ModbusConnectedRequestHandler. The handler class + receives connection create/teardown events + :param allow_reuse_address: Whether the server will allow the + reuse of an address. + :param allow_reuse_port: Whether the server will allow the + reuse of a port. + :param backlog: is the maximum number of queued connections + passed to listen(). Defaults to 20, increase if many + connections are being made and broken to your Modbus slave + :param loop: optional asyncio event loop to run in. Will default to + asyncio.get_event_loop() supplied value if None. + :param ignore_missing_slaves: True to not send errors on a request + to a missing slave + :param broadcast_enable: True to treat unit_id 0 as broadcast address, + False to treat 0 as any other unit_id + """ + self.active_connections = {} + self.loop = loop or asyncio.get_event_loop() + self.allow_reuse_address = allow_reuse_address + self.decoder = ServerDecoder() + self.framer = framer or ModbusSocketFramer + self.context = context or ModbusServerContext() + self.control = ModbusControlBlock() + self.address = address or ("", Defaults.Port) + self.handler = handler or ModbusConnectedRequestHandler + self.handler.server = self + self.ignore_missing_slaves = kwargs.get('ignore_missing_slaves', + Defaults.IgnoreMissingSlaves) + self.broadcast_enable = kwargs.get('broadcast_enable', + Defaults.broadcast_enable) + + if isinstance(identity, ModbusDeviceIdentification): + self.control.Identity.update(identity) + + self.serving = self.loop.create_future() # asyncio future that will be done once server has started + self.server = None # constructors cannot be declared async, so we have to defer the initialization of the server + if PYTHON_VERSION >= (3, 7): + # start_serving is new in version 3.7 + self.server_factory = self.loop.create_server(lambda : self.handler(self), + *self.address, + reuse_address=allow_reuse_address, + reuse_port=allow_reuse_port, + backlog=backlog, + start_serving=not defer_start) + else: + self.server_factory = self.loop.create_server(lambda : self.handler(self), + *self.address, + reuse_address=allow_reuse_address, + reuse_port=allow_reuse_port, + backlog=backlog) + + async def serve_forever(self): + if self.server is None: + self.server = await self.server_factory + self.serving.set_result(True) + await self.server.serve_forever() + else: + raise RuntimeError("Can't call serve_forever on an already running server object") + + def server_close(self): + for k,v in self.active_connections.items(): + _logger.warning("aborting active session {}".format(k)) + v.handler_task.cancel() + self.active_connections = {} + self.server.close() + + +class ModbusUdpServer: + """ + A modbus threaded udp socket server + + We inherit and overload the socket server so that we + can control the client threads as well as have a single + server context instance. + """ + + def __init__(self, context, framer=None, identity=None, address=None, + handler=None, allow_reuse_address=False, + allow_reuse_port=False, + defer_start=False, + backlog=20, + loop=None, + **kwargs): + """ Overloaded initializer for the socket server + + If the identify structure is not passed in, the ModbusControlBlock + uses its own empty structure. + + :param context: The ModbusServerContext datastore + :param framer: The framer strategy to use + :param identity: An optional identify structure + :param address: An optional (interface, port) to bind to. + :param handler: A handler for each client session; default is + ModbusDisonnectedRequestHandler + :param ignore_missing_slaves: True to not send errors on a request + to a missing slave + :param broadcast_enable: True to treat unit_id 0 as broadcast address, + False to treat 0 as any other unit_id + """ + self.loop = loop or asyncio.get_event_loop() + self.decoder = ServerDecoder() + self.framer = framer or ModbusSocketFramer + self.context = context or ModbusServerContext() + self.control = ModbusControlBlock() + self.address = address or ("", Defaults.Port) + self.handler = handler or ModbusDisconnectedRequestHandler + self.ignore_missing_slaves = kwargs.get('ignore_missing_slaves', + Defaults.IgnoreMissingSlaves) + self.broadcast_enable = kwargs.get('broadcast_enable', + Defaults.broadcast_enable) + + if isinstance(identity, ModbusDeviceIdentification): + self.control.Identity.update(identity) + + self.protocol = None + self.endpoint = None + self.on_connection_terminated = None + self.stop_serving = self.loop.create_future() + self.serving = self.loop.create_future() # asyncio future that will be done once server has started + self.server_factory = self.loop.create_datagram_endpoint(lambda: self.handler(self), + local_addr=self.address, + reuse_address=allow_reuse_address, + reuse_port=allow_reuse_port, + allow_broadcast=True) + + async def serve_forever(self): + if self.protocol is None: + self.protocol, self.endpoint = await self.server_factory + self.serving.set_result(True) + await self.stop_serving + else: + raise RuntimeError("Can't call serve_forever on an already running server object") + + def server_close(self): + self.stop_serving.set_result(True) + if self.endpoint.handler_task is not None: + self.endpoint.handler_task.cancel() + + self.protocol.close() + + + +class ModbusSerialServer(object): + """ + A modbus threaded serial socket server + + We inherit and overload the socket server so that we + can control the client threads as well as have a single + server context instance. + """ + + handler = None + + def __init__(self, context, framer=None, identity=None, **kwargs): # pragma: no cover + """ Overloaded initializer for the socket server + + If the identify structure is not passed in, the ModbusControlBlock + uses its own empty structure. + + :param context: The ModbusServerContext datastore + :param framer: The framer strategy to use + :param identity: An optional identify structure + :param port: The serial port to attach to + :param stopbits: The number of stop bits to use + :param bytesize: The bytesize of the serial messages + :param parity: Which kind of parity to use + :param baudrate: The baud rate to use for the serial device + :param timeout: The timeout to use for the serial device + :param ignore_missing_slaves: True to not send errors on a request + to a missing slave + :param broadcast_enable: True to treat unit_id 0 as broadcast address, + False to treat 0 as any other unit_id + """ + raise NotImplementedException + + +# --------------------------------------------------------------------------- # +# Creation Factories +# --------------------------------------------------------------------------- # +async def StartTcpServer(context=None, identity=None, address=None, + custom_functions=[], defer_start=True, **kwargs): + """ A factory to start and run a tcp modbus server + + :param context: The ModbusServerContext datastore + :param identity: An optional identify structure + :param address: An optional (interface, port) to bind to. + :param custom_functions: An optional list of custom function classes + supported by server instance. + :param defer_start: if set, a coroutine which can be started and stopped + will be returned. Otherwise, the server will be immediately spun + up without the ability to shut it off from within the asyncio loop + :param ignore_missing_slaves: True to not send errors on a request to a + missing slave + :return: an initialized but inactive server object coroutine + """ + framer = kwargs.pop("framer", ModbusSocketFramer) + server = ModbusTcpServer(context, framer, identity, address, **kwargs) + + for f in custom_functions: + server.decoder.register(f) # pragma: no cover + + if not defer_start: + await server.serve_forever() + + return server + + + + +async def StartUdpServer(context=None, identity=None, address=None, + custom_functions=[], defer_start=True, **kwargs): + """ A factory to start and run a udp modbus server + + :param context: The ModbusServerContext datastore + :param identity: An optional identify structure + :param address: An optional (interface, port) to bind to. + :param custom_functions: An optional list of custom function classes + supported by server instance. + :param framer: The framer to operate with (default ModbusSocketFramer) + :param ignore_missing_slaves: True to not send errors on a request + to a missing slave + """ + framer = kwargs.pop('framer', ModbusSocketFramer) + server = ModbusUdpServer(context, framer, identity, address, **kwargs) + + for f in custom_functions: + server.decoder.register(f) # pragma: no cover + + if not defer_start: + await server.serve_forever() # pragma: no cover + + return server + + + +def StartSerialServer(context=None, identity=None, custom_functions=[], + **kwargs):# pragma: no cover + """ A factory to start and run a serial modbus server + + :param context: The ModbusServerContext datastore + :param identity: An optional identify structure + :param custom_functions: An optional list of custom function classes + supported by server instance. + :param framer: The framer to operate with (default ModbusAsciiFramer) + :param port: The serial port to attach to + :param stopbits: The number of stop bits to use + :param bytesize: The bytesize of the serial messages + :param parity: Which kind of parity to use + :param baudrate: The baud rate to use for the serial device + :param timeout: The timeout to use for the serial device + :param ignore_missing_slaves: True to not send errors on a request to a + missing slave + """ + raise NotImplementedException + import serial + framer = kwargs.pop('framer', ModbusAsciiFramer) + server = ModbusSerialServer(context, framer, identity, **kwargs) + for f in custom_functions: + server.decoder.register(f) + server.serve_forever() + +def StopServer(): + """ + Helper method to stop Async Server + """ + import warnings + warnings.warn("deprecated API for asyncio. Call server_close() on server object returned by StartXxxServer", + DeprecationWarning) + + + +# --------------------------------------------------------------------------- # +# Exported symbols +# --------------------------------------------------------------------------- # + + +__all__ = [ + "StartTcpServer", "StartUdpServer", "StartSerialServer" +] + diff --git a/pymodbus/server/sync.py b/pymodbus/server/sync.py index 9492265f7..f7b22454f 100644 --- a/pymodbus/server/sync.py +++ b/pymodbus/server/sync.py @@ -6,6 +6,7 @@ from binascii import b2a_hex import serial import socket +import ssl import traceback from pymodbus.constants import Defaults @@ -364,6 +365,63 @@ def server_close(self): thread.running = False +class ModbusTlsServer(ModbusTcpServer): + """ + A modbus threaded TLS server + + We inherit and overload the ModbusTcpServer so that we + can control the client threads as well as have a single + server context instance. + """ + + def __init__(self, context, framer=None, identity=None, + address=None, handler=None, allow_reuse_address=False, + sslctx=None, certfile=None, keyfile=None, **kwargs): + """ Overloaded initializer for the ModbusTcpServer + + If the identify structure is not passed in, the ModbusControlBlock + uses its own empty structure. + + :param context: The ModbusServerContext datastore + :param framer: The framer strategy to use + :param identity: An optional identify structure + :param address: An optional (interface, port) to bind to. + :param handler: A handler for each client session; default is + ModbusConnectedRequestHandler + :param allow_reuse_address: Whether the server will allow the + reuse of an address. + :param sslctx: The SSLContext to use for TLS (default None and auto + create) + :param certfile: The cert file path for TLS (used if sslctx is None) + :param keyfile: The key file path for TLS (used if sslctx is None) + :param ignore_missing_slaves: True to not send errors on a request + to a missing slave + :param broadcast_enable: True to treat unit_id 0 as broadcast address, + False to treat 0 as any other unit_id + """ + self.sslctx = sslctx + if self.sslctx is None: + self.sslctx = ssl.create_default_context() + self.sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile) + # According to MODBUS/TCP Security Protocol Specification, it is + # TLSv2 at least + self.sslctx.options |= ssl.OP_NO_TLSv1_1 + self.sslctx.options |= ssl.OP_NO_TLSv1 + self.sslctx.options |= ssl.OP_NO_SSLv3 + self.sslctx.options |= ssl.OP_NO_SSLv2 + self.sslctx.verify_mode = ssl.CERT_OPTIONAL + self.sslctx.check_hostname = False + + ModbusTcpServer.__init__(self, context, framer, identity, address, + handler, allow_reuse_address, **kwargs) + + def server_activate(self): + """ Callback for starting listening over TLS connection + """ + self.socket = self.sslctx.wrap_socket(self.socket, server_side=True) + socketserver.ThreadingTCPServer.server_activate(self) + + class ModbusUdpServer(socketserver.ThreadingUDPServer): """ A modbus threaded udp socket server @@ -562,6 +620,30 @@ def StartTcpServer(context=None, identity=None, address=None, server.serve_forever() +def StartTlsServer(context=None, identity=None, address=None, sslctx=None, + certfile=None, keyfile=None, custom_functions=[], **kwargs): + """ A factory to start and run a tls modbus server + + :param context: The ModbusServerContext datastore + :param identity: An optional identify structure + :param address: An optional (interface, port) to bind to. + :param sslctx: The SSLContext to use for TLS (default None and auto create) + :param certfile: The cert file path for TLS (used if sslctx is None) + :param keyfile: The key file path for TLS (used if sslctx is None) + :param custom_functions: An optional list of custom function classes + supported by server instance. + :param ignore_missing_slaves: True to not send errors on a request to a + missing slave + """ + framer = kwargs.pop("framer", ModbusTlsFramer) + server = ModbusTlsServer(context, framer, identity, address, sslctx=sslctx, + certfile=certfile, keyfile=keyfile, **kwargs) + + for f in custom_functions: + server.decoder.register(f) + server.serve_forever() + + def StartUdpServer(context=None, identity=None, address=None, custom_functions=[], **kwargs): """ A factory to start and run a udp modbus server @@ -612,6 +694,6 @@ def StartSerialServer(context=None, identity=None, custom_functions=[], __all__ = [ - "StartTcpServer", "StartUdpServer", "StartSerialServer" + "StartTcpServer", "StartTlsServer", "StartUdpServer", "StartSerialServer" ] diff --git a/pymodbus/transaction.py b/pymodbus/transaction.py index 579fcd86d..06a80ac17 100644 --- a/pymodbus/transaction.py +++ b/pymodbus/transaction.py @@ -14,6 +14,7 @@ from pymodbus.framer.ascii_framer import ModbusAsciiFramer from pymodbus.framer.rtu_framer import ModbusRtuFramer from pymodbus.framer.socket_framer import ModbusSocketFramer +from pymodbus.framer.tls_framer import ModbusTlsFramer from pymodbus.framer.binary_framer import ModbusBinaryFramer from pymodbus.utilities import hexlify_packets, ModbusTransactionState from pymodbus.compat import iterkeys, byte2int @@ -37,7 +38,7 @@ # The Global Transaction Manager # --------------------------------------------------------------------------- # class ModbusTransactionManager(object): - """ Impelements a transaction for a manager + """ Implements a transaction for a manager The transaction protocol can be represented by the following pseudo code:: @@ -78,6 +79,8 @@ def _set_adu_size(self): self.base_adu_size = 7 # start(1)+ Address(2), LRC(2) + end(2) elif isinstance(self.client.framer, ModbusBinaryFramer): self.base_adu_size = 5 # start(1) + Address(1), CRC(2) + end(1) + elif isinstance(self.client.framer, ModbusTlsFramer): + self.base_adu_size = 0 # no header and footer else: self.base_adu_size = -1 @@ -91,7 +94,8 @@ def _calculate_exception_length(self): """ Returns the length of the Modbus Exception Response according to the type of Framer. """ - if isinstance(self.client.framer, ModbusSocketFramer): + if isinstance(self.client.framer, (ModbusSocketFramer, + ModbusTlsFramer)): return self.base_adu_size + 2 # Fcode(1), ExcecptionCode(1) elif isinstance(self.client.framer, ModbusAsciiFramer): return self.base_adu_size + 4 # Fcode(2), ExcecptionCode(2) @@ -459,6 +463,6 @@ def delTransaction(self, tid): __all__ = [ "FifoTransactionManager", "DictTransactionManager", - "ModbusSocketFramer", "ModbusRtuFramer", + "ModbusSocketFramer", "ModbusTlsFramer", "ModbusRtuFramer", "ModbusAsciiFramer", "ModbusBinaryFramer", ] diff --git a/pymodbus/version.py b/pymodbus/version.py index 51da88745..869f8e344 100644 --- a/pymodbus/version.py +++ b/pymodbus/version.py @@ -41,7 +41,7 @@ def __str__(self): return '[%s, version %s]' % (self.package, self.short()) -version = Version('pymodbus', 2, 2, 0) +version = Version('pymodbus', 2, 3, 0) version.__name__ = 'pymodbus' # fix epydoc error diff --git a/requirements-tests.txt b/requirements-tests.txt index 510306703..2ca42aa2d 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -6,7 +6,6 @@ mock >= 1.0.1 pyserial-asyncio==0.4.0;python_version>="3.4" pep8>=1.7.0 pyasn1>=0.2.3 -pycrypto>=2.6.1 pyserial>=3.4 pytest-cov>=2.5.1 pytest>=3.5.0 @@ -17,3 +16,4 @@ verboselogs >= 1.5 tornado==4.5.3 Twisted>=17.1.0 zope.interface>=4.4.0 +asynctest>=0.10.0 diff --git a/requirements.txt b/requirements.txt index ae349e55c..c44e3c1dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -six==1.11.0 +six>=1.11.0 # ------------------------------------------------------------------- # if want to use the pymodbus serial stack, uncomment these # ------------------------------------------------------------------- @@ -16,7 +16,6 @@ six==1.11.0 #Twisted==17.1.0 #zope.interface==4.4.0 #pyasn1==0.2.3 -#pycrypto==2.6.1 #wsgiref==0.1.2 #cryptography==1.8.1 @@ -43,4 +42,4 @@ six==1.11.0 # if you want to use pymodbus REPL # ------------------------------------------------------------------- # click>=6.7 -# prompt-toolkit==2.0.4 \ No newline at end of file +# prompt-toolkit==2.0.4 diff --git a/setup.py b/setup.py index 9069eda16..38396aa02 100644 --- a/setup.py +++ b/setup.py @@ -40,8 +40,8 @@ version=__version__, description="A fully featured modbus protocol stack in python", long_description=""" - Pymodbus aims to be a fully implemented modbus protocol stack - implemented using twisted/asyncio/tornado. + Pymodbus aims to be a fully implemented modbus protocol stack + implemented using twisted/asyncio/tornado. Its orignal goal was to allow simulation of thousands of modbus devices on a single machine for monitoring software testing. """, @@ -64,7 +64,7 @@ maintainer=__maintainer__, maintainer_email='otlasanju@gmail.com', url='https://github.com/riptideio/pymodbus/', - license='BSD', + license='BSD-3-Clause', packages=find_packages(exclude=['examples', 'test']), exclude_package_data={'': ['examples', 'test', 'tools', 'doc']}, py_modules=['ez_setup'], @@ -85,7 +85,6 @@ 'twisted': [ 'twisted >= 12.2.0', 'pyasn1 >= 0.1.4', - 'pycrypto >= 2.6' ], 'tornado': [ 'tornado >= 4.5.3' diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 000000000..748fd4b7a --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,3 @@ +from pymodbus.compat import IS_PYTHON3, PYTHON_VERSION +if not IS_PYTHON3 or IS_PYTHON3 and PYTHON_VERSION.minor < 7: + collect_ignore = ["test_server_asyncio.py"] diff --git a/test/test_client_sync.py b/test/test_client_sync.py index 9c9496283..1014a9382 100644 --- a/test/test_client_sync.py +++ b/test/test_client_sync.py @@ -8,9 +8,11 @@ from mock import patch, Mock, MagicMock import socket import serial +import ssl from pymodbus.client.sync import ModbusTcpClient, ModbusUdpClient from pymodbus.client.sync import ModbusSerialClient, BaseModbusClient +from pymodbus.client.sync import ModbusTlsClient from pymodbus.exceptions import ConnectionException, NotImplementedException from pymodbus.exceptions import ParameterException from pymodbus.transaction import ModbusAsciiFramer, ModbusRtuFramer @@ -157,7 +159,6 @@ def testUdpClientRepr(self): ) self.assertEqual(repr(client), rep) - # -----------------------------------------------------------------------# # Test TCP Client # -----------------------------------------------------------------------# @@ -249,6 +250,97 @@ class CustomeRequest: client.framer = Mock() client.register(CustomeRequest) assert client.framer.decoder.register.called_once_with(CustomeRequest) + + # -----------------------------------------------------------------------# + # Test TLS Client + # -----------------------------------------------------------------------# + + def testSyncTlsClientInstantiation(self): + # default SSLContext + client = ModbusTlsClient() + self.assertNotEqual(client, None) + self.assertTrue(client.sslctx) + + # user defined SSLContext + context = ssl.create_default_context() + client = ModbusTlsClient(sslctx=context) + self.assertNotEqual(client, None) + self.assertEqual(client.sslctx, context) + + def testBasicSyncTlsClient(self): + ''' Test the basic methods for the tls sync client''' + + # receive/send + client = ModbusTlsClient() + client.socket = mockSocket() + self.assertEqual(0, client._send(None)) + self.assertEqual(1, client._send(b'\x00')) + self.assertEqual(b'\x00', client._recv(1)) + + # connect/disconnect + self.assertTrue(client.connect()) + client.close() + + # already closed socket + client.socket = False + client.close() + + self.assertEqual("ModbusTlsClient(localhost:802)", str(client)) + + def testTlsClientConnect(self): + ''' Test the tls client connection method''' + with patch.object(ssl.SSLSocket, 'connect') as mock_method: + client = ModbusTlsClient() + self.assertTrue(client.connect()) + + with patch.object(socket, 'create_connection') as mock_method: + mock_method.side_effect = socket.error() + client = ModbusTlsClient() + self.assertFalse(client.connect()) + + def testTlsClientSend(self): + ''' Test the tls client send method''' + client = ModbusTlsClient() + self.assertRaises(ConnectionException, lambda: client._send(None)) + + client.socket = mockSocket() + self.assertEqual(0, client._send(None)) + self.assertEqual(4, client._send('1234')) + + def testTlsClientRecv(self): + ''' Test the tls client receive method''' + client = ModbusTlsClient() + self.assertRaises(ConnectionException, lambda: client._recv(1024)) + + client.socket = mockSocket() + self.assertEqual(b'', client._recv(0)) + self.assertEqual(b'\x00' * 4, client._recv(4)) + + mock_socket = MagicMock() + mock_socket.recv.side_effect = iter([b'\x00', b'\x01', b'\x02']) + client.socket = mock_socket + client.timeout = 1 + self.assertEqual(b'\x00\x01\x02', client._recv(3)) + mock_socket.recv.side_effect = iter([b'\x00', b'\x01', b'\x02']) + self.assertEqual(b'\x00\x01', client._recv(2)) + + def testTlsClientRpr(self): + client = ModbusTlsClient() + rep = "<{} at {} socket={}, ipaddr={}, port={}, sslctx={}, " \ + "timeout={}>".format( + client.__class__.__name__, hex(id(client)), client.socket, + client.host, client.port, client.sslctx, client.timeout + ) + self.assertEqual(repr(client), rep) + + def testTlsClientRegister(self): + class CustomeRequest: + function_code = 79 + client = ModbusTlsClient() + client.framer = Mock() + client.register(CustomeRequest) + assert client.framer.decoder.register.called_once_with(CustomeRequest) + # -----------------------------------------------------------------------# # Test Serial Client # -----------------------------------------------------------------------# diff --git a/test/test_server_asyncio.py b/test/test_server_asyncio.py new file mode 100755 index 000000000..372c96479 --- /dev/null +++ b/test/test_server_asyncio.py @@ -0,0 +1,618 @@ +#!/usr/bin/env python +from pymodbus.compat import IS_PYTHON3, PYTHON_VERSION +import pytest +import asynctest +import asyncio +import logging +_logger = logging.getLogger() +if IS_PYTHON3: # Python 3 + from asynctest.mock import patch, Mock, MagicMock + +from pymodbus.device import ModbusDeviceIdentification +from pymodbus.factory import ServerDecoder +from pymodbus.server.asynchronous import ModbusTcpProtocol, ModbusUdpProtocol +from pymodbus.server.asyncio import StartTcpServer, StartUdpServer, StartSerialServer, StopServer, ModbusServerFactory +from pymodbus.server.asyncio import ModbusConnectedRequestHandler, ModbusBaseRequestHandler +from pymodbus.datastore import ModbusSequentialDataBlock +from pymodbus.datastore import ModbusSlaveContext, ModbusServerContext +from pymodbus.compat import byte2int +from pymodbus.transaction import ModbusSocketFramer +from pymodbus.exceptions import NoSuchSlaveException, ModbusIOException + +import sys +#---------------------------------------------------------------------------# +# Fixture +#---------------------------------------------------------------------------# +import platform +from distutils.version import LooseVersion + +IS_DARWIN = platform.system().lower() == "darwin" +OSX_SIERRA = LooseVersion("10.12") +if IS_DARWIN: + IS_HIGH_SIERRA_OR_ABOVE = LooseVersion(platform.mac_ver()[0]) + SERIAL_PORT = '/dev/ptyp0' if not IS_HIGH_SIERRA_OR_ABOVE else '/dev/ttyp0' +else: + IS_HIGH_SIERRA_OR_ABOVE = False + SERIAL_PORT = "/dev/ptmx" + +@pytest.mark.skipif(not IS_PYTHON3, reason="requires python3.4 or above") +class AsyncioServerTest(asynctest.TestCase): + ''' + This is the unittest for the pymodbus.server.asyncio module + + The scope of this unit test is the life-cycle management of the network + connections and server objects. + + This unittest suite does not attempt to test any of the underlying protocol details + ''' + + #-----------------------------------------------------------------------# + # Setup/TearDown + #-----------------------------------------------------------------------# + def setUp(self): + ''' + Initialize the test environment by setting up a dummy store and context + ''' + self.store = ModbusSlaveContext( di=ModbusSequentialDataBlock(0, [17]*100), + co=ModbusSequentialDataBlock(0, [17]*100), + hr=ModbusSequentialDataBlock(0, [17]*100), + ir=ModbusSequentialDataBlock(0, [17]*100)) + self.context = ModbusServerContext(slaves=self.store, single=True) + + def tearDown(self): + ''' Cleans up the test environment ''' + pass + + #-----------------------------------------------------------------------# + # Test ModbusConnectedRequestHandler + #-----------------------------------------------------------------------# + @asyncio.coroutine + def testStartTcpServer(self): + ''' Test that the modbus tcp asyncio server starts correctly ''' + identity = ModbusDeviceIdentification(info={0x00: 'VendorName'}) + self.loop = asynctest.Mock(self.loop) + server = yield from StartTcpServer(context=self.context,loop=self.loop,identity=identity) + self.assertEqual(server.control.Identity.VendorName, 'VendorName') + if PYTHON_VERSION >= (3, 6): + self.loop.create_server.assert_called_once() + + @pytest.mark.skipif(PYTHON_VERSION < (3, 7), reason="requires python3.7 or above") + @asyncio.coroutine + def testTcpServerServeNoDefer(self): + ''' Test StartTcpServer without deferred start (immediate execution of server) ''' + with patch('asyncio.base_events.Server.serve_forever', new_callable=asynctest.CoroutineMock) as serve: + server = yield from StartTcpServer(context=self.context,address=("127.0.0.1", 0), loop=self.loop, defer_start=False) + serve.assert_awaited() + + @pytest.mark.skipif(PYTHON_VERSION < (3, 7), reason="requires python3.7 or above") + @asyncio.coroutine + def testTcpServerServeForever(self): + ''' Test StartTcpServer serve_forever() method ''' + with patch('asyncio.base_events.Server.serve_forever', new_callable=asynctest.CoroutineMock) as serve: + server = yield from StartTcpServer(context=self.context,address=("127.0.0.1", 0), loop=self.loop) + yield from server.serve_forever() + serve.assert_awaited() + + @asyncio.coroutine + def testTcpServerServeForeverTwice(self): + ''' Call on serve_forever() twice should result in a runtime error ''' + server = yield from StartTcpServer(context=self.context,address=("127.0.0.1", 0), loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + with self.assertRaises(RuntimeError): + yield from server.serve_forever() + server.server_close() + + @asyncio.coroutine + def testTcpServerReceiveData(self): + ''' Test data sent on socket is received by internals - doesn't not process data ''' + data = b'\x01\x00\x00\x00\x00\x06\x01\x03\x00\x00\x00\x19' + server = yield from StartTcpServer(context=self.context,address=("127.0.0.1", 0),loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + with patch('pymodbus.transaction.ModbusSocketFramer.processIncomingPacket', new_callable=Mock) as process: + # process = server.framer.processIncomingPacket = Mock() + connected = self.loop.create_future() + random_port = server.server.sockets[0].getsockname()[1] # get the random server port + + class BasicClient(asyncio.BaseProtocol): + def connection_made(self, transport): + self.transport = transport + self.transport.write(data) + connected.set_result(True) + + def eof_received(self): + pass + + transport, protocol = yield from self.loop.create_connection(BasicClient, host='127.0.0.1',port=random_port) + yield from asyncio.sleep(0.1) # this may be better done by making an internal hook in the actual implementation + # if this unit test fails on a machine, see if increasing the sleep time makes a difference, if it does + # blame author for a fix + + if PYTHON_VERSION >= (3, 6): + process.assert_called_once() + self.assertTrue( process.call_args[1]["data"] == data ) + server.server_close() + + @asyncio.coroutine + def testTcpServerRoundtrip(self): + ''' Test sending and receiving data on tcp socket ''' + data = b"\x01\x00\x00\x00\x00\x06\x01\x03\x00\x00\x00\x01" # unit 1, read register + expected_response = b'\x01\x00\x00\x00\x00\x05\x01\x03\x02\x00\x11' # value of 17 as per context + server = yield from StartTcpServer(context=self.context,address=("127.0.0.1", 0),loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + + random_port = server.server.sockets[0].getsockname()[1] # get the random server port + + connected, done = self.loop.create_future(),self.loop.create_future() + received_value = None + + class BasicClient(asyncio.BaseProtocol): + def connection_made(self, transport): + self.transport = transport + self.transport.write(data) + connected.set_result(True) + + def data_received(self, data): + nonlocal received_value, done + received_value = data + done.set_result(True) + + def eof_received(self): + pass + + transport, protocol = yield from self.loop.create_connection(BasicClient, host='127.0.0.1',port=random_port) + yield from asyncio.wait_for(done, timeout=0.1) + + self.assertEqual(received_value, expected_response) + + transport.close() + yield from asyncio.sleep(0) + server.server_close() + + @asyncio.coroutine + def testTcpServerConnectionLost(self): + ''' Test tcp stream interruption ''' + data = b"\x01\x00\x00\x00\x00\x06\x01\x01\x00\x00\x00\x01" + server = yield from StartTcpServer(context=self.context,address=("127.0.0.1", 0),loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + + random_port = server.server.sockets[0].getsockname()[1] # get the random server port + + step1 = self.loop.create_future() + done = self.loop.create_future() + received_value = None + + class BasicClient(asyncio.BaseProtocol): + def connection_made(self, transport): + self.transport = transport + step1.set_result(True) + + transport, protocol = yield from self.loop.create_connection(BasicClient, host='127.0.0.1',port=random_port) + yield from step1 + + self.assertTrue( len(server.active_connections) == 1 ) + + protocol.transport.close() # close isn't synchronous and there's no notification that it's done + # so we have to wait a bit + yield from asyncio.sleep(0.1) + self.assertTrue( len(server.active_connections) == 0 ) + server.server_close() + + @asyncio.coroutine + def testTcpServerCloseActiveConnection(self): + ''' Test server_close() while there are active TCP connections ''' + data = b"\x01\x00\x00\x00\x00\x06\x01\x01\x00\x00\x00\x01" + server = yield from StartTcpServer(context=self.context,address=("127.0.0.1", 0),loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + + random_port = server.server.sockets[0].getsockname()[1] # get the random server port + + step1 = self.loop.create_future() + done = self.loop.create_future() + received_value = None + + class BasicClient(asyncio.BaseProtocol): + def connection_made(self, transport): + self.transport = transport + step1.set_result(True) + + transport, protocol = yield from self.loop.create_connection(BasicClient, host='127.0.0.1',port=random_port) + yield from step1 + + server.server_close() + + # close isn't synchronous and there's no notification that it's done + # so we have to wait a bit + yield from asyncio.sleep(0.0) + self.assertTrue( len(server.active_connections) == 0 ) + + @asyncio.coroutine + def testTcpServerException(self): + ''' Sending garbage data on a TCP socket should drop the connection ''' + garbage = b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF' + server = yield from StartTcpServer(context=self.context,address=("127.0.0.1", 0),loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + with patch('pymodbus.transaction.ModbusSocketFramer.processIncomingPacket', + new_callable=lambda : Mock(side_effect=Exception)) as process: + connect, receive, eof = self.loop.create_future(),self.loop.create_future(),self.loop.create_future() + received_data = None + random_port = server.server.sockets[0].getsockname()[1] # get the random server port + + class BasicClient(asyncio.BaseProtocol): + def connection_made(self, transport): + _logger.debug("Client connected") + self.transport = transport + transport.write(garbage) + connect.set_result(True) + + def data_received(self, data): + _logger.debug("Client received data") + receive.set_result(True) + received_data = data + + def eof_received(self): + _logger.debug("Client stream eof") + eof.set_result(True) + + transport, protocol = yield from self.loop.create_connection(BasicClient, host='127.0.0.1',port=random_port) + yield from asyncio.wait_for(connect, timeout=0.1) + yield from asyncio.wait_for(eof, timeout=0.1) + # neither of these should timeout if the test is successful + server.server_close() + + @asyncio.coroutine + def testTcpServerNoSlave(self): + ''' Test unknown slave unit exception ''' + context = ModbusServerContext(slaves={0x01: self.store, 0x02: self.store }, single=False) + data = b"\x01\x00\x00\x00\x00\x06\x05\x03\x00\x00\x00\x01" # get slave 5 function 3 (holding register) + server = yield from StartTcpServer(context=context,address=("127.0.0.1", 0),loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + connect, receive, eof = self.loop.create_future(),self.loop.create_future(),self.loop.create_future() + received_data = None + random_port = server.server.sockets[0].getsockname()[1] # get the random server port + + class BasicClient(asyncio.BaseProtocol): + def connection_made(self, transport): + _logger.debug("Client connected") + self.transport = transport + transport.write(data) + connect.set_result(True) + + def data_received(self, data): + _logger.debug("Client received data") + receive.set_result(True) + received_data = data + + def eof_received(self): + _logger.debug("Client stream eof") + eof.set_result(True) + + transport, protocol = yield from self.loop.create_connection(BasicClient, host='127.0.0.1',port=random_port) + yield from asyncio.wait_for(connect, timeout=0.1) + self.assertFalse(eof.done()) + server.server_close() + + @asyncio.coroutine + def testTcpServerModbusError(self): + ''' Test sending garbage data on a TCP socket should drop the connection ''' + data = b"\x01\x00\x00\x00\x00\x06\x01\x03\x00\x00\x00\x01" # get slave 5 function 3 (holding register) + server = yield from StartTcpServer(context=self.context,address=("127.0.0.1", 0),loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + with patch("pymodbus.register_read_message.ReadHoldingRegistersRequest.execute", + side_effect=NoSuchSlaveException): + connect, receive, eof = self.loop.create_future(),self.loop.create_future(),self.loop.create_future() + received_data = None + random_port = server.server.sockets[0].getsockname()[1] # get the random server port + + class BasicClient(asyncio.BaseProtocol): + def connection_made(self, transport): + _logger.debug("Client connected") + self.transport = transport + transport.write(data) + connect.set_result(True) + + def data_received(self, data): + _logger.debug("Client received data") + receive.set_result(True) + received_data = data + + def eof_received(self): + _logger.debug("Client stream eof") + eof.set_result(True) + + transport, protocol = yield from self.loop.create_connection(BasicClient, host='127.0.0.1',port=random_port) + yield from asyncio.wait_for(connect, timeout=0.1) + yield from asyncio.wait_for(receive, timeout=0.1) + self.assertFalse(eof.done()) + transport.close() + server.server_close() + + @asyncio.coroutine + def testTcpServerInternalException(self): + ''' Test sending garbage data on a TCP socket should drop the connection ''' + data = b"\x01\x00\x00\x00\x00\x06\x01\x03\x00\x00\x00\x01" # get slave 5 function 3 (holding register) + server = yield from StartTcpServer(context=self.context,address=("127.0.0.1", 0),loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + with patch("pymodbus.register_read_message.ReadHoldingRegistersRequest.execute", + side_effect=Exception): + connect, receive, eof = self.loop.create_future(),self.loop.create_future(),self.loop.create_future() + received_data = None + random_port = server.server.sockets[0].getsockname()[1] # get the random server port + + class BasicClient(asyncio.BaseProtocol): + def connection_made(self, transport): + _logger.debug("Client connected") + self.transport = transport + transport.write(data) + connect.set_result(True) + + def data_received(self, data): + _logger.debug("Client received data") + receive.set_result(True) + received_data = data + + def eof_received(self): + _logger.debug("Client stream eof") + eof.set_result(True) + + transport, protocol = yield from self.loop.create_connection(BasicClient, host='127.0.0.1',port=random_port) + yield from asyncio.wait_for(connect, timeout=0.1) + yield from asyncio.wait_for(receive, timeout=0.1) + self.assertFalse(eof.done()) + + transport.close() + server.server_close() + + + + #-----------------------------------------------------------------------# + # Test ModbusUdpProtocol + #-----------------------------------------------------------------------# + + @asyncio.coroutine + def testStartUdpServer(self): + ''' Test that the modbus udp asyncio server starts correctly ''' + identity = ModbusDeviceIdentification(info={0x00: 'VendorName'}) + self.loop = asynctest.Mock(self.loop) + server = yield from StartUdpServer(context=self.context,loop=self.loop,identity=identity) + self.assertEqual(server.control.Identity.VendorName, 'VendorName') + if PYTHON_VERSION >= (3, 6): + self.loop.create_datagram_endpoint.assert_called_once() + + # async def testUdpServerServeNoDefer(self): + # ''' Test StartUdpServer without deferred start - NOT IMPLEMENTED - this test is hard to do without additional + # internal plumbing added to the implementation ''' + # asyncio.base_events.Server.serve_forever = asynctest.CoroutineMock() + # server = yield from StartUdpServer(address=("127.0.0.1", 0), loop=self.loop, defer_start=False) + # server.server.serve_forever.assert_awaited() + + @pytest.mark.skipif(PYTHON_VERSION < (3, 7), reason="requires python3.7 or above") + @asyncio.coroutine + def testUdpServerServeForeverStart(self): + ''' Test StartUdpServer serve_forever() method ''' + with patch('asyncio.base_events.Server.serve_forever', new_callable=asynctest.CoroutineMock) as serve: + server = yield from StartTcpServer(context=self.context,address=("127.0.0.1", 0), loop=self.loop) + yield from server.serve_forever() + serve.assert_awaited() + + @asyncio.coroutine + def testUdpServerServeForeverClose(self): + ''' Test StartUdpServer serve_forever() method ''' + server = yield from StartUdpServer(context=self.context,address=("127.0.0.1", 0), loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + + self.assertTrue(asyncio.isfuture(server.on_connection_terminated)) + self.assertFalse(server.on_connection_terminated.done()) + + server.server_close() + self.assertTrue(server.protocol.is_closing()) + + @asyncio.coroutine + def testUdpServerServeForeverTwice(self): + ''' Call on serve_forever() twice should result in a runtime error ''' + identity = ModbusDeviceIdentification(info={0x00: 'VendorName'}) + server = yield from StartUdpServer(context=self.context,address=("127.0.0.1", 0), + loop=self.loop,identity=identity) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + with self.assertRaises(RuntimeError): + yield from server.serve_forever() + server.server_close() + + @asyncio.coroutine + def testUdpServerReceiveData(self): + ''' Test that the sending data on datagram socket gets data pushed to framer ''' + server = yield from StartUdpServer(context=self.context,address=("127.0.0.1", 0),loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + with patch('pymodbus.transaction.ModbusSocketFramer.processIncomingPacket',new_callable=Mock) as process: + + server.endpoint.datagram_received(data=b"12345", addr=("127.0.0.1", 12345)) + yield from asyncio.sleep(0.1) + process.seal() + + if PYTHON_VERSION >= (3, 6): + process.assert_called_once() + self.assertTrue( process.call_args[1]["data"] == b"12345" ) + + server.server_close() + + @asyncio.coroutine + def testUdpServerSendData(self): + ''' Test that the modbus udp asyncio server correctly sends data outbound ''' + identity = ModbusDeviceIdentification(info={0x00: 'VendorName'}) + data = b'x\01\x00\x00\x00\x00\x06\x01\x03\x00\x00\x00\x19' + server = yield from StartUdpServer(context=self.context,address=("127.0.0.1", 0)) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + random_port = server.protocol._sock.getsockname()[1] + received = server.endpoint.datagram_received = Mock(wraps=server.endpoint.datagram_received) + done = self.loop.create_future() + received_value = None + + class BasicClient(asyncio.DatagramProtocol): + def connection_made(self, transport): + self.transport = transport + self.transport.sendto(data) + + def datagram_received(self, data, addr): + nonlocal received_value, done + print("received") + received_value = data + done.set_result(True) + self.transport.close() + + transport, protocol = yield from self.loop.create_datagram_endpoint( BasicClient, + remote_addr=('127.0.0.1', random_port)) + + yield from asyncio.sleep(0.1) + + if PYTHON_VERSION >= (3, 6): + received.assert_called_once() + self.assertEqual(received.call_args[0][0], data) + + server.server_close() + + self.assertTrue(server.protocol.is_closing()) + yield from asyncio.sleep(0.1) + + @asyncio.coroutine + def testUdpServerRoundtrip(self): + ''' Test sending and receiving data on udp socket''' + data = b"\x01\x00\x00\x00\x00\x06\x01\x03\x00\x00\x00\x01" # unit 1, read register + expected_response = b'\x01\x00\x00\x00\x00\x05\x01\x03\x02\x00\x11' # value of 17 as per context + server = yield from StartUdpServer(context=self.context,address=("127.0.0.1", 0),loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + + random_port = server.protocol._sock.getsockname()[1] + + connected, done = self.loop.create_future(),self.loop.create_future() + received_value = None + + class BasicClient(asyncio.DatagramProtocol): + def connection_made(self, transport): + self.transport = transport + self.transport.sendto(data) + + def datagram_received(self, data, addr): + nonlocal received_value, done + print("received") + received_value = data + done.set_result(True) + + transport, protocol = yield from self.loop.create_datagram_endpoint( BasicClient, + remote_addr=('127.0.0.1', random_port)) + yield from asyncio.wait_for(done, timeout=0.1) + + self.assertEqual(received_value, expected_response) + + transport.close() + yield from asyncio.sleep(0) + server.server_close() + + @asyncio.coroutine + def testUdpServerException(self): + ''' Test sending garbage data on a TCP socket should drop the connection ''' + garbage = b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF' + server = yield from StartUdpServer(context=self.context,address=("127.0.0.1", 0),loop=self.loop) + if PYTHON_VERSION >= (3, 7): + server_task = asyncio.create_task(server.serve_forever()) + else: + server_task = asyncio.ensure_future(server.serve_forever()) + yield from server.serving + with patch('pymodbus.transaction.ModbusSocketFramer.processIncomingPacket', + new_callable=lambda: Mock(side_effect=Exception)) as process: + connect, receive, eof = self.loop.create_future(),self.loop.create_future(),self.loop.create_future() + received_data = None + random_port = server.protocol._sock.getsockname()[1] # get the random server port + + class BasicClient(asyncio.DatagramProtocol): + def connection_made(self, transport): + _logger.debug("Client connected") + self.transport = transport + transport.sendto(garbage) + connect.set_result(True) + + def datagram_received(self, data, addr): + nonlocal receive + _logger.debug("Client received data") + receive.set_result(True) + received_data = data + + transport, protocol = yield from self.loop.create_datagram_endpoint(BasicClient, + remote_addr=('127.0.0.1', random_port)) + yield from asyncio.wait_for(connect, timeout=0.1) + self.assertFalse(receive.done()) + self.assertFalse(server.protocol._sock._closed) + server.server_close() + + # -----------------------------------------------------------------------# + # Test ModbusServerFactory + # -----------------------------------------------------------------------# + def testModbusServerFactory(self): + ''' Test the base class for all the clients ''' + with self.assertWarns(DeprecationWarning): + factory = ModbusServerFactory(store=None) + + def testStopServer(self): + with self.assertWarns(DeprecationWarning): + StopServer() + + +# --------------------------------------------------------------------------- # +# Main +# --------------------------------------------------------------------------- # +if __name__ == "__main__": + asynctest.main() diff --git a/test/test_server_sync.py b/test/test_server_sync.py index 8134cf5fc..74ba0cfa5 100644 --- a/test/test_server_sync.py +++ b/test/test_server_sync.py @@ -7,14 +7,15 @@ from mock import patch, Mock import serial import socket +import ssl from pymodbus.device import ModbusDeviceIdentification from pymodbus.server.sync import ModbusBaseRequestHandler from pymodbus.server.sync import ModbusSingleRequestHandler from pymodbus.server.sync import ModbusConnectedRequestHandler from pymodbus.server.sync import ModbusDisconnectedRequestHandler -from pymodbus.server.sync import ModbusTcpServer, ModbusUdpServer, ModbusSerialServer -from pymodbus.server.sync import StartTcpServer, StartUdpServer, StartSerialServer +from pymodbus.server.sync import ModbusTcpServer, ModbusTlsServer, ModbusUdpServer, ModbusSerialServer +from pymodbus.server.sync import StartTcpServer, StartTlsServer, StartUdpServer, StartSerialServer from pymodbus.exceptions import NotImplementedException from pymodbus.bit_read_message import ReadCoilsRequest, ReadCoilsResponse from pymodbus.datastore import ModbusServerContext @@ -274,6 +275,44 @@ def testTcpServerProcess(self): server.process_request('request', 'client') self.assertTrue(mock_server.process_request.called) + #-----------------------------------------------------------------------# + # Test TLS Server + #-----------------------------------------------------------------------# + def testTlsServerInit(self): + ''' test that the synchronous TLS server intial correctly ''' + with patch.object(socket.socket, 'bind') as mock_socket: + with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: + identity = ModbusDeviceIdentification(info={0x00: 'VendorName'}) + server = ModbusTlsServer(context=None, identity=identity) + self.assertIsNotNone(server.sslctx) + self.assertEqual(type(server.socket), ssl.SSLSocket) + server.server_close() + sslctx = ssl.create_default_context() + server = ModbusTlsServer(context=None, identity=identity, + sslctx=sslctx) + self.assertEqual(server.sslctx, sslctx) + self.assertEqual(type(server.socket), ssl.SSLSocket) + server.server_close() + + def testTlsServerClose(self): + ''' test that the synchronous TLS server closes correctly ''' + with patch.object(socket.socket, 'bind') as mock_socket: + with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: + identity = ModbusDeviceIdentification(info={0x00: 'VendorName'}) + server = ModbusTlsServer(context=None, identity=identity) + server.threads.append(Mock(**{'running': True})) + server.server_close() + self.assertEqual(server.control.Identity.VendorName, 'VendorName') + self.assertFalse(server.threads[0].running) + + def testTlsServerProcess(self): + ''' test that the synchronous TLS server processes requests ''' + with patch('pymodbus.compat.socketserver.ThreadingTCPServer') as mock_server: + with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: + server = ModbusTlsServer(None) + server.process_request('request', 'client') + self.assertTrue(mock_server.process_request.called) + #-----------------------------------------------------------------------# # Test UDP Server #-----------------------------------------------------------------------# @@ -347,6 +386,13 @@ def testStartTcpServer(self): with patch.object(socketserver.TCPServer, 'server_bind') as mock_binder: StartTcpServer() + def testStartTlsServer(self): + ''' Test the tls server starting factory ''' + with patch.object(ModbusTlsServer, 'serve_forever') as mock_server: + with patch.object(socketserver.TCPServer, 'server_bind') as mock_binder: + with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method: + StartTlsServer() + def testStartUdpServer(self): ''' Test the udp server starting factory ''' with patch.object(ModbusUdpServer, 'serve_forever') as mock_server: diff --git a/test/test_transaction.py b/test/test_transaction.py index 85e00223c..7c25a1f8c 100644 --- a/test/test_transaction.py +++ b/test/test_transaction.py @@ -5,8 +5,8 @@ from pymodbus.pdu import * from pymodbus.transaction import * from pymodbus.transaction import ( - ModbusTransactionManager, ModbusSocketFramer, ModbusAsciiFramer, - ModbusRtuFramer, ModbusBinaryFramer + ModbusTransactionManager, ModbusSocketFramer, ModbusTlsFramer, + ModbusAsciiFramer, ModbusRtuFramer, ModbusBinaryFramer ) from pymodbus.factory import ServerDecoder from pymodbus.compat import byte2int @@ -29,6 +29,7 @@ def setUp(self): self.client = None self.decoder = ServerDecoder() self._tcp = ModbusSocketFramer(decoder=self.decoder, client=None) + self._tls = ModbusTlsFramer(decoder=self.decoder, client=None) self._rtu = ModbusRtuFramer(decoder=self.decoder, client=None) self._ascii = ModbusAsciiFramer(decoder=self.decoder, client=None) self._binary = ModbusBinaryFramer(decoder=self.decoder, client=None) @@ -40,6 +41,7 @@ def tearDown(self): """ Cleans up the test environment """ del self._manager del self._tcp + del self._tls del self._rtu del self._ascii @@ -60,6 +62,7 @@ def testCalculateExceptionLength(self): ('binary', 7), ('rtu', 5), ('tcp', 9), + ('tls', 2), ('dummy', None)]: self._tm.client = MagicMock() if framer == "ascii": @@ -70,6 +73,8 @@ def testCalculateExceptionLength(self): self._tm.client.framer = self._rtu elif framer == "tcp": self._tm.client.framer = self._tcp + elif framer == "tls": + self._tm.client.framer = self._tls else: self._tm.client.framer = MagicMock() @@ -303,6 +308,137 @@ def testTCPFramerPacket(self): self.assertEqual(expected, actual) ModbusRequest.encode = old_encode + # ----------------------------------------------------------------------- # + # TLS tests + # ----------------------------------------------------------------------- # + def testTLSFramerTransactionReady(self): + """ Test a tls frame transaction """ + msg = b"\x01\x12\x34\x00\x08" + self.assertFalse(self._tls.isFrameReady()) + self.assertFalse(self._tls.checkFrame()) + self._tls.addToFrame(msg) + self.assertTrue(self._tls.isFrameReady()) + self.assertTrue(self._tls.checkFrame()) + self._tls.advanceFrame() + self.assertFalse(self._tls.isFrameReady()) + self.assertFalse(self._tls.checkFrame()) + self.assertEqual(b'', self._tls.getFrame()) + + def testTLSFramerTransactionFull(self): + """ Test a full tls frame transaction """ + msg = b"\x01\x12\x34\x00\x08" + self._tls.addToFrame(msg) + self.assertTrue(self._tls.checkFrame()) + result = self._tls.getFrame() + self.assertEqual(msg[0:], result) + self._tls.advanceFrame() + + def testTLSFramerTransactionHalf(self): + """ Test a half completed tls frame transaction """ + msg1 = b"" + msg2 = b"\x01\x12\x34\x00\x08" + self._tls.addToFrame(msg1) + self.assertFalse(self._tls.checkFrame()) + result = self._tls.getFrame() + self.assertEqual(b'', result) + self._tls.addToFrame(msg2) + self.assertTrue(self._tls.checkFrame()) + result = self._tls.getFrame() + self.assertEqual(msg2[0:], result) + self._tls.advanceFrame() + + def testTLSFramerTransactionShort(self): + """ Test that we can get back on track after an invalid message """ + msg1 = b"" + msg2 = b"\x01\x12\x34\x00\x08" + self._tls.addToFrame(msg1) + self.assertFalse(self._tls.checkFrame()) + result = self._tls.getFrame() + self.assertEqual(b'', result) + self._tls.advanceFrame() + self._tls.addToFrame(msg2) + self.assertEqual(5, len(self._tls._buffer)) + self.assertTrue(self._tls.checkFrame()) + result = self._tls.getFrame() + self.assertEqual(msg2[0:], result) + self._tls.advanceFrame() + + def testTLSFramerDecode(self): + """ Testmessage decoding """ + msg1 = b"" + msg2 = b"\x01\x12\x34\x00\x08" + result = self._tls.decode_data(msg1) + self.assertEqual(dict(), result); + result = self._tls.decode_data(msg2) + self.assertEqual(dict(fcode=1), result); + self._tls.advanceFrame() + + def testTLSIncomingPacket(self): + msg = b"\x01\x12\x34\x00\x08" + + unit = 0x01 + def mock_callback(self): + pass + + self._tls._process = MagicMock() + self._tls.isFrameReady = MagicMock(return_value=False) + self._tls.processIncomingPacket(msg, mock_callback, unit) + self.assertEqual(msg, self._tls.getRawFrame()) + self._tls.advanceFrame() + + self._tls.isFrameReady = MagicMock(return_value=True) + self._tls._validate_unit_id = MagicMock(return_value=False) + self._tls.processIncomingPacket(msg, mock_callback, unit) + self.assertEqual(b'', self._tls.getRawFrame()) + self._tls.advanceFrame() + + self._tls._validate_unit_id = MagicMock(return_value=True) + self._tls.processIncomingPacket(msg, mock_callback, unit) + self.assertEqual(msg, self._tls.getRawFrame()) + self._tls.advanceFrame() + + def testTLSProcess(self): + class MockResult(object): + def __init__(self, code): + self.function_code = code + + def mock_callback(self): + pass + + self._tls.decoder.decode = MagicMock(return_value=None) + self.assertRaises(ModbusIOException, + lambda: self._tls._process(mock_callback)) + + result = MockResult(0x01) + self._tls.decoder.decode = MagicMock(return_value=result) + self.assertRaises(InvalidMessageReceivedException, + lambda: self._tls._process(mock_callback, error=True)) + + self._tls._process(mock_callback) + self.assertEqual(b'', self._tls.getRawFrame()) + + def testTLSFramerPopulate(self): + """ Test a tls frame packet build """ + expected = ModbusRequest() + msg = b"\x01\x12\x34\x00\x08" + self._tls.addToFrame(msg) + self.assertTrue(self._tls.checkFrame()) + actual = ModbusRequest() + result = self._tls.populateResult(actual) + self.assertEqual(None, result) + self._tls.advanceFrame() + + def testTLSFramerPacket(self): + """ Test a tls frame packet build """ + old_encode = ModbusRequest.encode + ModbusRequest.encode = lambda self: b'' + message = ModbusRequest() + message.function_code = 0x01 + expected = b"\x01" + actual = self._tls.buildPacket(message) + self.assertEqual(expected, actual) + ModbusRequest.encode = old_encode + # ----------------------------------------------------------------------- # # RTU tests # ----------------------------------------------------------------------- # diff --git a/tox.ini b/tox.ini index f6b90e308..909d6a74d 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # directory. [tox] -envlist = py27, py34, py35, py36, py37, pypy +envlist = py27, py35, py36, py37, pypy [testenv] deps = -r requirements-tests.txt