Skip to content

Commit

Permalink
Add server requiring client's cert in TLS handshake feature
Browse files Browse the repository at this point in the history
This patch adds server requiring client's certificate feature which is
mentioned in the 6th step CertificateRequest to 9th step
VerifyClientCertSig in Table 5 TLS Full Handshake Protocol of MODBUS/TCP
Security Protocol Specification [1],

This feature is implemented with an optional argument "reqclicert" of
StartTlsServer() in both sync and async_io. So, users can force server
require client's certificate, or according to the SSL Context's original
behavior [2].

This fixes part of pymodbus-dev#606

[1]: http://modbus.org/docs/MB-TCP-Security-v21_2018-07-24.pdf
[2]: https://docs.python.org/3/library/ssl.html#ssl.SSLContext.verify_mode
  • Loading branch information
starnight committed Mar 7, 2021
1 parent 7e8e7cf commit 96458c8
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
11 changes: 7 additions & 4 deletions pymodbus/server/async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def __init__(self,
sslctx=None,
certfile=None,
keyfile=None,
reqclicert=False,
handler=None,
allow_reuse_address=False,
allow_reuse_port=False,
Expand All @@ -537,6 +538,7 @@ def __init__(self,
create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param reqclicert: Force the sever request client's certificate
:param handler: A handler for each client session; default is
ModbusConnectedRequestHandler. The handler class
receives connection create/teardown events
Expand Down Expand Up @@ -585,8 +587,8 @@ def __init__(self,
self.sslctx.options |= ssl.OP_NO_TLSv1
self.sslctx.options |= ssl.OP_NO_SSLv3
self.sslctx.options |= ssl.OP_NO_SSLv2
self.sslctx.verify_mode = ssl.CERT_OPTIONAL
self.sslctx.check_hostname = False
if reqclicert:
self.sslctx.verify_mode = ssl.CERT_REQUIRED
# asyncio future that will be done once server has started
self.serving = self.loop.create_future()
# constructors cannot be declared async, so we have to
Expand Down Expand Up @@ -838,7 +840,7 @@ async def StartTcpServer(context=None, identity=None, address=None,

async def StartTlsServer(context=None, identity=None, address=None,
sslctx=None,
certfile=None, keyfile=None,
certfile=None, keyfile=None, reqclicert=False,
allow_reuse_address=False,
allow_reuse_port=False,
custom_functions=[],
Expand All @@ -851,6 +853,7 @@ async def StartTlsServer(context=None, identity=None, address=None,
:param sslctx: The SSLContext to use for TLS (default None and auto create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param reqclicert: Force the sever request client's certificate
:param allow_reuse_address: Whether the server will allow the reuse of an
address.
:param allow_reuse_port: Whether the server will allow the reuse of a port.
Expand All @@ -865,7 +868,7 @@ async def StartTlsServer(context=None, identity=None, address=None,
"""
framer = kwargs.pop("framer", ModbusTlsFramer)
server = ModbusTlsServer(context, framer, identity, address, sslctx,
certfile, keyfile,
certfile, keyfile, reqclicert=reqclicert,
allow_reuse_address=allow_reuse_address,
allow_reuse_port=allow_reuse_port, **kwargs)

Expand Down
15 changes: 10 additions & 5 deletions pymodbus/server/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ class ModbusTlsServer(ModbusTcpServer):

def __init__(self, context, framer=None, identity=None,
address=None, handler=None, allow_reuse_address=False,
sslctx=None, certfile=None, keyfile=None, **kwargs):
sslctx=None, certfile=None, keyfile=None, reqclicert=False,
**kwargs):
""" Overloaded initializer for the ModbusTcpServer
If the identify structure is not passed in, the ModbusControlBlock
Expand All @@ -394,6 +395,7 @@ def __init__(self, context, framer=None, identity=None,
create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param reqclicert: Force the sever request client's certificate
:param ignore_missing_slaves: True to not send errors on a request
to a missing slave
:param broadcast_enable: True to treat unit_id 0 as broadcast address,
Expand All @@ -409,8 +411,8 @@ def __init__(self, context, framer=None, identity=None,
self.sslctx.options |= ssl.OP_NO_TLSv1
self.sslctx.options |= ssl.OP_NO_SSLv3
self.sslctx.options |= ssl.OP_NO_SSLv2
self.sslctx.verify_mode = ssl.CERT_OPTIONAL
self.sslctx.check_hostname = False
if reqclicert:
self.sslctx.verify_mode = ssl.CERT_REQUIRED

ModbusTcpServer.__init__(self, context, framer, identity, address,
handler, allow_reuse_address, **kwargs)
Expand Down Expand Up @@ -621,7 +623,8 @@ def StartTcpServer(context=None, identity=None, address=None,


def StartTlsServer(context=None, identity=None, address=None, sslctx=None,
certfile=None, keyfile=None, custom_functions=[], **kwargs):
certfile=None, keyfile=None, reqclicert=False,
custom_functions=[], **kwargs):
""" A factory to start and run a tls modbus server
:param context: The ModbusServerContext datastore
Expand All @@ -630,14 +633,16 @@ def StartTlsServer(context=None, identity=None, address=None, sslctx=None,
:param sslctx: The SSLContext to use for TLS (default None and auto create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param reqclicert: Force the sever request client's certificate
:param custom_functions: An optional list of custom function classes
supported by server instance.
:param ignore_missing_slaves: True to not send errors on a request to a
missing slave
"""
framer = kwargs.pop("framer", ModbusTlsFramer)
server = ModbusTlsServer(context, framer, identity, address, sslctx=sslctx,
certfile=certfile, keyfile=keyfile, **kwargs)
certfile=certfile, keyfile=keyfile,
reqclicert=reqclicert, **kwargs)

for f in custom_functions:
server.decoder.register(f)
Expand Down
5 changes: 4 additions & 1 deletion test/test_server_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,10 @@ def testTlsServerInit(self):
with patch.object(socket.socket, 'bind') as mock_socket:
with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method:
identity = ModbusDeviceIdentification(info={0x00: 'VendorName'})
server = ModbusTlsServer(context=None, identity=identity)
server = ModbusTlsServer(context=None, identity=identity,
reqclicert=True)
self.assertIsNotNone(server.sslctx)
self.assertEqual(server.sslctx.verify_mode, ssl.CERT_REQUIRED)
self.assertEqual(type(server.socket), ssl.SSLSocket)
server.server_close()
sslctx = ssl.create_default_context()
Expand Down Expand Up @@ -386,6 +388,7 @@ def testStartTcpServer(self):
with patch.object(socketserver.TCPServer, 'server_bind') as mock_binder:
StartTcpServer()

@patch.dict(ssl.SSLContext, {'server_hostname': 'test.com'})
def testStartTlsServer(self):
''' Test the tls server starting factory '''
with patch.object(ModbusTlsServer, 'serve_forever') as mock_server:
Expand Down

0 comments on commit 96458c8

Please sign in to comment.