Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add server requiring client's cert in TLS handshake feature #1

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymodbus/client/asynchronous/async_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def start(self, host, port=802, sslctx=None, server_hostname=None):
"""
self.sslctx = sslctx
if self.sslctx is None:
self.sslctx = ssl.create_default_context()
self.sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# According to MODBUS/TCP Security Protocol Specification, it is
# TLSv2 at least
self.sslctx.options |= ssl.OP_NO_TLSv1_1
Expand Down
2 changes: 1 addition & 1 deletion pymodbus/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def __init__(self, host='localhost', port=Defaults.TLSPort, sslctx=None,
"""
self.sslctx = sslctx
if self.sslctx is None:
self.sslctx = ssl.create_default_context()
self.sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# According to MODBUS/TCP Security Protocol Specification, it is
# TLSv2 at least
self.sslctx.options |= ssl.OP_NO_TLSv1_1
Expand Down
13 changes: 8 additions & 5 deletions pymodbus/server/async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def __init__(self,
sslctx=None,
certfile=None,
keyfile=None,
reqclicert=False,
handler=None,
allow_reuse_address=False,
allow_reuse_port=False,
Expand All @@ -537,6 +538,7 @@ def __init__(self,
create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param reqclicert: Force the sever request client's certificate
:param handler: A handler for each client session; default is
ModbusConnectedRequestHandler. The handler class
receives connection create/teardown events
Expand Down Expand Up @@ -577,16 +579,16 @@ def __init__(self,

self.sslctx = sslctx
if self.sslctx is None:
self.sslctx = ssl.create_default_context()
self.sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
self.sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile)
# According to MODBUS/TCP Security Protocol Specification, it is
# TLSv2 at least
self.sslctx.options |= ssl.OP_NO_TLSv1_1
self.sslctx.options |= ssl.OP_NO_TLSv1
self.sslctx.options |= ssl.OP_NO_SSLv3
self.sslctx.options |= ssl.OP_NO_SSLv2
self.sslctx.verify_mode = ssl.CERT_OPTIONAL
self.sslctx.check_hostname = False
if reqclicert:
self.sslctx.verify_mode = ssl.CERT_REQUIRED
# asyncio future that will be done once server has started
self.serving = self.loop.create_future()
# constructors cannot be declared async, so we have to
Expand Down Expand Up @@ -838,7 +840,7 @@ async def StartTcpServer(context=None, identity=None, address=None,

async def StartTlsServer(context=None, identity=None, address=None,
sslctx=None,
certfile=None, keyfile=None,
certfile=None, keyfile=None, reqclicert=False,
allow_reuse_address=False,
allow_reuse_port=False,
custom_functions=[],
Expand All @@ -851,6 +853,7 @@ async def StartTlsServer(context=None, identity=None, address=None,
:param sslctx: The SSLContext to use for TLS (default None and auto create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param reqclicert: Force the sever request client's certificate
:param allow_reuse_address: Whether the server will allow the reuse of an
address.
:param allow_reuse_port: Whether the server will allow the reuse of a port.
Expand All @@ -865,7 +868,7 @@ async def StartTlsServer(context=None, identity=None, address=None,
"""
framer = kwargs.pop("framer", ModbusTlsFramer)
server = ModbusTlsServer(context, framer, identity, address, sslctx,
certfile, keyfile,
certfile, keyfile, reqclicert=reqclicert,
allow_reuse_address=allow_reuse_address,
allow_reuse_port=allow_reuse_port, **kwargs)

Expand Down
17 changes: 11 additions & 6 deletions pymodbus/server/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ class ModbusTlsServer(ModbusTcpServer):

def __init__(self, context, framer=None, identity=None,
address=None, handler=None, allow_reuse_address=False,
sslctx=None, certfile=None, keyfile=None, **kwargs):
sslctx=None, certfile=None, keyfile=None, reqclicert=False,
**kwargs):
""" Overloaded initializer for the ModbusTcpServer

If the identify structure is not passed in, the ModbusControlBlock
Expand All @@ -394,23 +395,24 @@ def __init__(self, context, framer=None, identity=None,
create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param reqclicert: Force the sever request client's certificate
:param ignore_missing_slaves: True to not send errors on a request
to a missing slave
:param broadcast_enable: True to treat unit_id 0 as broadcast address,
False to treat 0 as any other unit_id
"""
self.sslctx = sslctx
if self.sslctx is None:
self.sslctx = ssl.create_default_context()
self.sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
self.sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile)
# According to MODBUS/TCP Security Protocol Specification, it is
# TLSv2 at least
self.sslctx.options |= ssl.OP_NO_TLSv1_1
self.sslctx.options |= ssl.OP_NO_TLSv1
self.sslctx.options |= ssl.OP_NO_SSLv3
self.sslctx.options |= ssl.OP_NO_SSLv2
self.sslctx.verify_mode = ssl.CERT_OPTIONAL
self.sslctx.check_hostname = False
if reqclicert:
self.sslctx.verify_mode = ssl.CERT_REQUIRED

ModbusTcpServer.__init__(self, context, framer, identity, address,
handler, allow_reuse_address, **kwargs)
Expand Down Expand Up @@ -621,7 +623,8 @@ def StartTcpServer(context=None, identity=None, address=None,


def StartTlsServer(context=None, identity=None, address=None, sslctx=None,
certfile=None, keyfile=None, custom_functions=[], **kwargs):
certfile=None, keyfile=None, reqclicert=False,
custom_functions=[], **kwargs):
""" A factory to start and run a tls modbus server

:param context: The ModbusServerContext datastore
Expand All @@ -630,14 +633,16 @@ def StartTlsServer(context=None, identity=None, address=None, sslctx=None,
:param sslctx: The SSLContext to use for TLS (default None and auto create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param reqclicert: Force the sever request client's certificate
:param custom_functions: An optional list of custom function classes
supported by server instance.
:param ignore_missing_slaves: True to not send errors on a request to a
missing slave
"""
framer = kwargs.pop("framer", ModbusTlsFramer)
server = ModbusTlsServer(context, framer, identity, address, sslctx=sslctx,
certfile=certfile, keyfile=keyfile, **kwargs)
certfile=certfile, keyfile=keyfile,
reqclicert=reqclicert, **kwargs)

for f in custom_functions:
server.decoder.register(f)
Expand Down
10 changes: 9 additions & 1 deletion test/test_client_sync.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
import unittest
from pymodbus.compat import IS_PYTHON3
from pymodbus.compat import IS_PYTHON3, PYTHON_VERSION
import pytest

if IS_PYTHON3: # Python 3
from unittest.mock import patch, Mock, MagicMock
Expand Down Expand Up @@ -257,6 +258,7 @@ class CustomeRequest:
# Test TLS Client
# -----------------------------------------------------------------------#

@pytest.mark.skipif(PYTHON_VERSION < (3, 6), reason="requires python3.6 or above")
def testSyncTlsClientInstantiation(self):
# default SSLContext
client = ModbusTlsClient()
Expand All @@ -269,6 +271,7 @@ def testSyncTlsClientInstantiation(self):
self.assertNotEqual(client, None)
self.assertEqual(client.sslctx, context)

@pytest.mark.skipif(PYTHON_VERSION < (3, 6), reason="requires python3.6 or above")
def testBasicSyncTlsClient(self):
''' Test the basic methods for the tls sync client'''

Expand All @@ -289,6 +292,7 @@ def testBasicSyncTlsClient(self):

self.assertEqual("ModbusTlsClient(localhost:802)", str(client))

@pytest.mark.skipif(PYTHON_VERSION < (3, 6), reason="requires python3.6 or above")
def testTlsClientConnect(self):
''' Test the tls client connection method'''
with patch.object(ssl.SSLSocket, 'connect') as mock_method:
Expand All @@ -300,6 +304,7 @@ def testTlsClientConnect(self):
client = ModbusTlsClient()
self.assertFalse(client.connect())

@pytest.mark.skipif(PYTHON_VERSION < (3, 6), reason="requires python3.6 or above")
def testTlsClientSend(self):
''' Test the tls client send method'''
client = ModbusTlsClient()
Expand All @@ -309,6 +314,7 @@ def testTlsClientSend(self):
self.assertEqual(0, client._send(None))
self.assertEqual(4, client._send('1234'))

@pytest.mark.skipif(PYTHON_VERSION < (3, 6), reason="requires python3.6 or above")
def testTlsClientRecv(self):
''' Test the tls client receive method'''
client = ModbusTlsClient()
Expand All @@ -326,6 +332,7 @@ def testTlsClientRecv(self):
mock_socket.recv.side_effect = iter([b'\x00', b'\x01', b'\x02'])
self.assertEqual(b'\x00\x01', client._recv(2))

@pytest.mark.skipif(PYTHON_VERSION < (3, 6), reason="requires python3.6 or above")
def testTlsClientRpr(self):
client = ModbusTlsClient()
rep = "<{} at {} socket={}, ipaddr={}, port={}, sslctx={}, " \
Expand All @@ -335,6 +342,7 @@ def testTlsClientRpr(self):
)
self.assertEqual(repr(client), rep)

@pytest.mark.skipif(PYTHON_VERSION < (3, 6), reason="requires python3.6 or above")
def testTlsClientRegister(self):
class CustomeRequest:
function_code = 79
Expand Down
13 changes: 10 additions & 3 deletions test/test_server_sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python
from pymodbus.compat import IS_PYTHON3
from pymodbus.compat import IS_PYTHON3, PYTHON_VERSION
import pytest
import unittest
if IS_PYTHON3: # Python 3
from unittest.mock import patch, Mock
Expand Down Expand Up @@ -278,22 +279,26 @@ def testTcpServerProcess(self):
#-----------------------------------------------------------------------#
# Test TLS Server
#-----------------------------------------------------------------------#
@pytest.mark.skipif(PYTHON_VERSION < (3, 6), reason="requires python3.6 or above")
def testTlsServerInit(self):
''' test that the synchronous TLS server intial correctly '''
with patch.object(socket.socket, 'bind') as mock_socket:
with patch.object(ssl.SSLContext, 'load_cert_chain') as mock_method:
identity = ModbusDeviceIdentification(info={0x00: 'VendorName'})
server = ModbusTlsServer(context=None, identity=identity)
server = ModbusTlsServer(context=None, identity=identity,
reqclicert=True)
self.assertIsNotNone(server.sslctx)
self.assertEqual(server.sslctx.verify_mode, ssl.CERT_REQUIRED)
self.assertEqual(type(server.socket), ssl.SSLSocket)
server.server_close()
sslctx = ssl.create_default_context()
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server = ModbusTlsServer(context=None, identity=identity,
sslctx=sslctx)
self.assertEqual(server.sslctx, sslctx)
self.assertEqual(type(server.socket), ssl.SSLSocket)
server.server_close()

@pytest.mark.skipif(PYTHON_VERSION < (3, 6), reason="requires python3.6 or above")
def testTlsServerClose(self):
''' test that the synchronous TLS server closes correctly '''
with patch.object(socket.socket, 'bind') as mock_socket:
Expand All @@ -305,6 +310,7 @@ def testTlsServerClose(self):
self.assertEqual(server.control.Identity.VendorName, 'VendorName')
self.assertFalse(server.threads[0].running)

@pytest.mark.skipif(PYTHON_VERSION < (3, 6), reason="requires python3.6 or above")
def testTlsServerProcess(self):
''' test that the synchronous TLS server processes requests '''
with patch('pymodbus.compat.socketserver.ThreadingTCPServer') as mock_server:
Expand Down Expand Up @@ -386,6 +392,7 @@ def testStartTcpServer(self):
with patch.object(socketserver.TCPServer, 'server_bind') as mock_binder:
StartTcpServer()

@pytest.mark.skipif(PYTHON_VERSION < (3, 6), reason="requires python3.6 or above")
def testStartTlsServer(self):
''' Test the tls server starting factory '''
with patch.object(ModbusTlsServer, 'serve_forever') as mock_server:
Expand Down