Skip to content

Commit

Permalink
Fix pymodbus TLS module conflicts in 3.0.0 (#661)
Browse files Browse the repository at this point in the history
Developers add/fix features at the same time, then produce the conflicts
in pymodbus' TLS module. This patch tries to fix the conflicts.
  • Loading branch information
starnight authored Aug 9, 2021
1 parent 9e37031 commit e6674ec
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 39 deletions.
14 changes: 5 additions & 9 deletions pymodbus/client/asynchronous/async_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,8 @@ def __init__(self, protocol_class=None, loop=None, framer=None):
self.framer = framer
ReconnectingAsyncioModbusTcpClient.__init__(self, protocol_class, loop)

async def start(self, host, port=802, sslctx=None,
server_hostname=None, certfile=None, keyfile=None,
password=None, **kwargs):
async def start(self, host='localhost', port=802, sslctx=None,
certfile=None, keyfile=None, password=None, **kwargs):
"""
Initiates connection to start client
:param host: The host to connect to (default localhost)
Expand All @@ -463,7 +462,6 @@ async def start(self, host, port=802, sslctx=None,
:param password: The password for for decrypting client's private key file
"""
self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password)
self.server_hostname = server_hostname
return await ReconnectingAsyncioModbusTcpClient.start(self, host, port)

async def _connect(self):
Expand Down Expand Up @@ -840,8 +838,8 @@ async def init_tcp_client(proto_cls, loop, host, port, **kwargs):


async def init_tls_client(proto_cls, loop, host, port, sslctx=None,
server_hostname=None, certfile=None, keyfile=None,
password=None, framer=None, **kwargs):
certfile=None, keyfile=None, password=None,
framer=None, **kwargs):
"""
Helper function to initialize tcp client
:param proto_cls:
Expand All @@ -858,9 +856,7 @@ async def init_tls_client(proto_cls, loop, host, port, sslctx=None,
"""
client = ReconnectingAsyncioModbusTlsClient(protocol_class=proto_cls,
loop=loop, framer=framer)
await client.start(host, port, sslctx, server_hostname=server_hostname,
certfile=certfile, keyfile=keyfile, password=password,
**kwargs)
await client.start(host, port, sslctx, certfile, keyfile, password)
return client


Expand Down
4 changes: 3 additions & 1 deletion pymodbus/client/asynchronous/factory/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def async_io_factory(host="127.0.0.1", port=Defaults.TLSPort, sslctx=None,
framer=framer)
client = loop.run_until_complete(asyncio.gather(cor))[0]
elif loop is asyncio.get_event_loop():
return loop, init_tls_client(proto_cls, loop, host, port)
return loop, init_tls_client(proto_cls, loop, host, port,
sslctx, certfile, keyfile, password,
framer)
else:
cor = init_tls_client(proto_cls, loop, host, port,
sslctx, certfile, keyfile, password, framer)
Expand Down
40 changes: 18 additions & 22 deletions pymodbus/server/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,10 @@ class ModbusTlsServer(ModbusTcpServer):
server context instance.
"""

def __init__(self, context, framer=None, identity=None,
address=None, handler=None, allow_reuse_address=False,
sslctx=None, certfile=None, keyfile=None, **kwargs):
def __init__(self, context, framer=None, identity=None, address=None,
sslctx=None, certfile=None, keyfile=None, password=None,
reqclicert=False, handler=None, allow_reuse_address=False,
**kwargs):
""" Overloaded initializer for the ModbusTcpServer
If the identify structure is not passed in, the ModbusControlBlock
Expand All @@ -388,32 +389,24 @@ def __init__(self, context, framer=None, identity=None,
:param framer: The framer strategy to use
:param identity: An optional identify structure
:param address: An optional (interface, port) to bind to.
:param handler: A handler for each client session; default is
ModbusConnectedRequestHandler
:param allow_reuse_address: Whether the server will allow the
reuse of an address.
:param sslctx: The SSLContext to use for TLS (default None and auto
create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param password: The password for for decrypting the private key file
:param reqclicert: Force the sever request client's certificate
:param handler: A handler for each client session; default is
ModbusConnectedRequestHandler
:param allow_reuse_address: Whether the server will allow the
reuse of an address.
:param ignore_missing_slaves: True to not send errors on a request
to a missing slave
:param broadcast_enable: True to treat unit_id 0 as broadcast address,
False to treat 0 as any other unit_id
"""
framer = framer or ModbusTlsFramer
self.sslctx = sslctx
if self.sslctx is None:
self.sslctx = ssl.create_default_context()
self.sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile)
# According to MODBUS/TCP Security Protocol Specification, it is
# TLSv2 at least
self.sslctx.options |= ssl.OP_NO_TLSv1_1
self.sslctx.options |= ssl.OP_NO_TLSv1
self.sslctx.options |= ssl.OP_NO_SSLv3
self.sslctx.options |= ssl.OP_NO_SSLv2
self.sslctx.verify_mode = ssl.CERT_OPTIONAL
self.sslctx.check_hostname = False
self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password,
reqclicert)

ModbusTcpServer.__init__(self, context, framer, identity, address,
handler, allow_reuse_address, **kwargs)
Expand Down Expand Up @@ -627,7 +620,8 @@ def StartTcpServer(context=None, identity=None, address=None,


def StartTlsServer(context=None, identity=None, address=None, sslctx=None,
certfile=None, keyfile=None, custom_functions=[], **kwargs):
certfile=None, keyfile=None, password=None, reqclicert=False,
custom_functions=[], **kwargs):
""" A factory to start and run a tls modbus server
:param context: The ModbusServerContext datastore
Expand All @@ -636,14 +630,16 @@ def StartTlsServer(context=None, identity=None, address=None, sslctx=None,
:param sslctx: The SSLContext to use for TLS (default None and auto create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param password: The password for for decrypting the private key file
:param reqclicert: Force the sever request client's certificate
:param custom_functions: An optional list of custom function classes
supported by server instance.
:param ignore_missing_slaves: True to not send errors on a request to a
missing slave
"""
framer = kwargs.pop("framer", ModbusTlsFramer)
server = ModbusTlsServer(context, framer, identity, address, sslctx=sslctx,
certfile=certfile, keyfile=keyfile, **kwargs)
server = ModbusTlsServer(context, framer, identity, address, sslctx,
certfile, keyfile, password, reqclicert, **kwargs)

for f in custom_functions:
server.decoder.register(f)
Expand Down
22 changes: 15 additions & 7 deletions test/test_server_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pymodbus.server.sync import ModbusDisconnectedRequestHandler
from pymodbus.server.sync import ModbusTcpServer, ModbusTlsServer, ModbusUdpServer, ModbusSerialServer
from pymodbus.server.sync import StartTcpServer, StartTlsServer, StartUdpServer, StartSerialServer
from pymodbus.server.tls_helper import sslctx_provider
from pymodbus.exceptions import NotImplementedException
from pymodbus.bit_read_message import ReadCoilsRequest, ReadCoilsResponse
from pymodbus.datastore import ModbusServerContext
Expand Down Expand Up @@ -278,23 +279,30 @@ def testTcpServerProcess(self):
#-----------------------------------------------------------------------#
# Test TLS Server
#-----------------------------------------------------------------------#
def testTlsSSLCTX_Provider(self):
''' test that sslctx_provider() produce SSLContext correctly '''
with patch.object(ssl.SSLContext, 'load_cert_chain'):
sslctx = sslctx_provider(reqclicert=True)
self.assertIsNotNone(sslctx)
self.assertEqual(type(sslctx), ssl.SSLContext)
self.assertEqual(sslctx.verify_mode, ssl.CERT_REQUIRED)

sslctx_old = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
sslctx_new = sslctx_provider(sslctx=sslctx_old)
self.assertEqual(sslctx_new, sslctx_old)

def testTlsServerInit(self):
''' test that the synchronous TLS server initial correctly '''
with patch.object(socketserver.TCPServer, 'server_activate'):
with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method:
identity = ModbusDeviceIdentification(info={0x00: 'VendorName'})
server = ModbusTlsServer(context=None, identity=identity,
reqclicert=True,
bind_and_activate=False)
self.assertIs(server.framer, ModbusTlsFramer)
server.server_activate()
self.assertIsNotNone(server.sslctx)
self.assertEqual(type(server.socket), ssl.SSLSocket)
server.server_close()
sslctx = ssl.create_default_context()
server = ModbusTlsServer(context=None, identity=identity,
sslctx=sslctx, bind_and_activate=False)
server.server_activate()
self.assertEqual(server.sslctx, sslctx)
self.assertEqual(server.sslctx.verify_mode, ssl.CERT_REQUIRED)
self.assertEqual(type(server.socket), ssl.SSLSocket)
server.server_close()

Expand Down

0 comments on commit e6674ec

Please sign in to comment.