Skip to content

Commit

Permalink
Add server requiring client's cert in TLS handshake feature
Browse files Browse the repository at this point in the history
This patch adds server requiring client's certificate feature which is
mentioned in the 6th step CertificateRequest to 9th step
VerifyClientCertSig in Table 5 TLS Full Handshake Protocol of MODBUS/TCP
Security Protocol Specification [1].

This patch implements the feature within both sync and async_io version.

* Server side:
Add an optional argument "reqclicert" of StartTlsServer(). So, users can
force server require client's certificate for TLS full handshake, or
according to the SSL Context's original behavior [2].

* Client side:
Add optional arguments "certfile" and "keyfile" for replying, if the
server requires client's certificate for TLS full handshake.

Besides, also add an optional argument "password" on both server and
client side for decrypting the private keyfile.

This fixes part of pymodbus-dev#606

[1]: http://modbus.org/docs/MB-TCP-Security-v21_2018-07-24.pdf
[2]: https://docs.python.org/3/library/ssl.html#ssl.SSLContext.verify_mode
  • Loading branch information
starnight committed Mar 10, 2021
1 parent 1da8e5e commit 14c77e5
Show file tree
Hide file tree
Showing 14 changed files with 207 additions and 107 deletions.
8 changes: 7 additions & 1 deletion examples/common/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,13 @@ async def run_server():

# Tls:
# await StartTlsServer(context, identity=identity, address=("localhost", 8020),
# certfile="server.crt", keyfile="server.key",
# certfile="server.crt", keyfile="server.key", password="pwd",
# allow_reuse_address=True, allow_reuse_port=True,
# defer_start=False)

# Tls and force require client's certificate for TLS full handshake:
# await StartTlsServer(context, identity=identity, address=("localhost", 8020),
# certfile="server.crt", keyfile="server.key", password="pwd", reqclicert=True,
# allow_reuse_address=True, allow_reuse_port=True,
# defer_start=False)

Expand Down
10 changes: 8 additions & 2 deletions examples/common/synchronous_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,14 @@ def run_server():
# framer=ModbusRtuFramer, address=("0.0.0.0", 5020))

# TLS
# StartTlsServer(context, identity=identity, certfile="server.crt",
# keyfile="server.key", address=("0.0.0.0", 8020))
# StartTlsServer(context, identity=identity,
# certfile="server.crt", keyfile="server.key", password="pwd",
# address=("0.0.0.0", 8020))

# Tls and force require client's certificate for TLS full handshake:
# StartTlsServer(context, identity=identity,
# certfile="server.crt", keyfile="server.key", password="pwd", reqclicert=True,
# address=("0.0.0.0", 8020))

# Udp:
# StartUdpServer(context, identity=identity, address=("0.0.0.0", 5020))
Expand Down
14 changes: 8 additions & 6 deletions examples/contrib/asynchronous_asyncio_modbus_tls_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
# -------------------------------------------------------------------------- #
# the TLS detail security can be set in SSLContext which is the context here
# -------------------------------------------------------------------------- #
context = ssl.create_default_context()
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_TLSv1
context.options |= ssl.OP_NO_TLSv1_1
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
sslctx.verify_mode = ssl.CERT_REQUIRED
sslctx.check_hostname = True

# Prepare client's certificate which the server requires for TLS full handshake
# sslctx.load_cert_chain(certfile="client.crt", keyfile="client.key",
# password="pwd")

async def start_async_test(client):
result = await client.read_coils(1, 8)
Expand All @@ -35,6 +37,6 @@ async def start_async_test(client):
# pass SSLContext which is the context here to ModbusTcpClient()
# -------------------------------------------------------------------------- #
loop, client = AsyncModbusTLSClient(ASYNC_IO, 'test.host.com', 8020,
sslctx=context)
sslctx=sslctx)
loop.run_until_complete(start_async_test(client.protocol))
loop.close()
14 changes: 8 additions & 6 deletions examples/contrib/modbus_tls_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@
# -------------------------------------------------------------------------- #
# the TLS detail security can be set in SSLContext which is the context here
# -------------------------------------------------------------------------- #
context = ssl.create_default_context()
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_TLSv1
context.options |= ssl.OP_NO_TLSv1_1
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
sslctx.verify_mode = ssl.CERT_REQUIRED
sslctx.check_hostname = True

# Prepare client's certificate which the server requires for TLS full handshake
# sslctx.load_cert_chain(certfile="client.crt", keyfile="client.key",
# password="pwd")

# -------------------------------------------------------------------------- #
# pass SSLContext which is the context here to ModbusTcpClient()
# -------------------------------------------------------------------------- #
client = ModbusTlsClient('test.host.com', 8020, sslctx=context)
client = ModbusTlsClient('test.host.com', 8020, sslctx=sslctx)
client.connect()

result = client.read_coils(1, 8)
Expand Down
37 changes: 17 additions & 20 deletions pymodbus/client/asynchronous/async_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ssl
from pymodbus.exceptions import ConnectionException
from pymodbus.client.asynchronous.mixins import AsyncModbusClientMixin
from pymodbus.client.tls_helper import sslctx_provider
from pymodbus.utilities import hexlify_packets
from pymodbus.transaction import FifoTransactionManager
import logging
Expand Down Expand Up @@ -449,25 +450,18 @@ def __init__(self, protocol_class=None, loop=None, framer=None):
ReconnectingAsyncioModbusTcpClient.__init__(self, protocol_class, loop)

@asyncio.coroutine
def start(self, host, port=802, sslctx=None, server_hostname=None):
def start(self, host='localhost', port=802, sslctx=None,
certfile=None, keyfile=None, password=None):
"""
Initiates connection to start client
:param host:
:param port:
:param sslctx:
:param server_hostname:
:return:
:param host: The host to connect to (default localhost)
:param port: Port to connect
:param sslctx:The SSLContext to use for TLS (default None and auto create)
:param certfile: The optional client's cert file path for TLS server request
:param keyfile: The optional client's key file path for TLS server request
:param password: The password for for decrypting client's private key file
"""
self.sslctx = sslctx
if self.sslctx is None:
self.sslctx = ssl.create_default_context()
# According to MODBUS/TCP Security Protocol Specification, it is
# TLSv2 at least
self.sslctx.options |= ssl.OP_NO_TLSv1_1
self.sslctx.options |= ssl.OP_NO_TLSv1
self.sslctx.options |= ssl.OP_NO_SSLv3
self.sslctx.options |= ssl.OP_NO_SSLv2
self.server_hostname = server_hostname
self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password)
yield from ReconnectingAsyncioModbusTcpClient.start(self, host, port)

@asyncio.coroutine
Expand All @@ -478,7 +472,7 @@ def _connect(self):
self.host,
self.port,
ssl=self.sslctx,
server_hostname=self.server_hostname)
server_hostname=self.host)
except Exception as ex:
_logger.warning('Failed to connect: %s' % ex)
asyncio.ensure_future(self._reconnect(), loop=self.loop)
Expand Down Expand Up @@ -849,22 +843,25 @@ def init_tcp_client(proto_cls, loop, host, port, **kwargs):

@asyncio.coroutine
def init_tls_client(proto_cls, loop, host, port, sslctx=None,
server_hostname=None, framer=None, **kwargs):
certfile=None, keyfile=None, password=None,
framer=None,**kwargs):
"""
Helper function to initialize tcp client
:param proto_cls:
:param loop:
:param host:
:param port:
:param sslctx:
:param server_hostname:
:param certfile: The optional client's cert file path for TLS server request
:param keyfile: The optional client's key file path for TLS server request
:param password: The password for for decrypting client's private key file
:param framer:
:param kwargs:
:return:
"""
client = ReconnectingAsyncioModbusTlsClient(protocol_class=proto_cls,
loop=loop, framer=framer)
yield from client.start(host, port, sslctx, server_hostname)
yield from client.start(host, port, sslctx, certfile, keyfile, password)
return client


Expand Down
17 changes: 10 additions & 7 deletions pymodbus/client/asynchronous/factory/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
LOGGER = logging.getLogger(__name__)

def async_io_factory(host="127.0.0.1", port=Defaults.TLSPort, sslctx=None,
server_hostname=None, framer=None, source_address=None,
certfile=None, keyfile=None, password=None,
framer=None, source_address=None,
timeout=None, **kwargs):
"""
Factory to create asyncio based asynchronous tls clients
:param host: Host IP address
:param host: Target server's name, also matched for certificate
:param port: Port
:param sslctx: The SSLContext to use for TLS (default None and auto create)
:param server_hostname: Target server's name matched for certificate
:param certfile: The optional client's cert file path for TLS server request
:param keyfile: The optional client's key file path for TLS server request
:param password: The password for for decrypting client's private key file
:param framer: Modbus Framer
:param source_address: Bind address
:param timeout: Timeout in seconds
Expand All @@ -33,12 +36,12 @@ def async_io_factory(host="127.0.0.1", port=Defaults.TLSPort, sslctx=None,
proto_cls = kwargs.get("proto_cls", None)
if not loop.is_running():
asyncio.set_event_loop(loop)
cor = init_tls_client(proto_cls, loop, host, port, sslctx, server_hostname,
framer)
cor = init_tls_client(proto_cls, loop, host, port,
sslctx, certfile, keyfile, password, framer)
client = loop.run_until_complete(asyncio.gather(cor))[0]
else:
cor = init_tls_client(proto_cls, loop, host, port, sslctx, server_hostname,
framer)
cor = init_tls_client(proto_cls, loop, host, port,
sslctx, certfile, keyfile, password, framer)
future = asyncio.run_coroutine_threadsafe(cor, loop=loop)
client = future.result()

Expand Down
15 changes: 9 additions & 6 deletions pymodbus/client/asynchronous/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,19 @@ class AsyncModbusTLSClient(object):
from pymodbus.client.asynchronous.tls import AsyncModbusTLSClient
"""
def __new__(cls, scheduler, host="127.0.0.1", port=Defaults.TLSPort,
framer=None, sslctx=None, server_hostname=None,
source_address=None, timeout=None, **kwargs):
framer=None, sslctx=None, certfile=None, keyfile=None,
password=None, source_address=None, timeout=None, **kwargs):
"""
Scheduler to use:
- async_io (asyncio)
:param scheduler: Backend to use
:param host: Host IP address
:param host: Target server's name, also matched for certificate
:param port: Port
:param framer: Modbus Framer to use
:param sslctx: The SSLContext to use for TLS (default None and auto create)
:param server_hostname: Target server's name matched for certificate
:param certfile: The optional client's cert file path for TLS server request
:param keyfile: The optional client's key file path for TLS server request
:param password: The password for for decrypting client's private key file
:param source_address: source address specific to underlying backend
:param timeout: Time out in seconds
:param kwargs: Other extra args specific to Backend being used
Expand All @@ -45,8 +47,9 @@ def __new__(cls, scheduler, host="127.0.0.1", port=Defaults.TLSPort,
framer = framer or ModbusTlsFramer(ClientDecoder())
factory_class = get_factory(scheduler)
yieldable = factory_class(host=host, port=port, sslctx=sslctx,
server_hostname=server_hostname,
framer=framer, source_address=source_address,
certfile=certfile, keyfile=keyfile,
password=password, framer=framer,
source_address=source_address,
timeout=timeout, **kwargs)
return yieldable

17 changes: 7 additions & 10 deletions pymodbus/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pymodbus.transaction import ModbusAsciiFramer, ModbusRtuFramer
from pymodbus.transaction import ModbusTlsFramer
from pymodbus.client.common import ModbusClientMixin
from pymodbus.client.tls_helper import sslctx_provider

# --------------------------------------------------------------------------- #
# Logging
Expand Down Expand Up @@ -368,27 +369,23 @@ class ModbusTlsClient(ModbusTcpClient):
"""

def __init__(self, host='localhost', port=Defaults.TLSPort, sslctx=None,
framer=ModbusTlsFramer, **kwargs):
certfile=None, keyfile=None, password=None, framer=ModbusTlsFramer,
**kwargs):
""" Initialize a client instance
:param host: The host to connect to (default localhost)
:param port: The modbus port to connect to (default 802)
:param sslctx: The SSLContext to use for TLS (default None and auto create)
:param certfile: The optional client's cert file path for TLS server request
:param keyfile: The optional client's key file path for TLS server request
:param password: The password for for decrypting client's private key file
:param source_address: The source address tuple to bind to (default ('', 0))
:param timeout: The timeout to use for this socket (default Defaults.Timeout)
:param framer: The modbus framer to use (default ModbusSocketFramer)
.. note:: The host argument will accept ipv4 and ipv6 hosts
"""
self.sslctx = sslctx
if self.sslctx is None:
self.sslctx = ssl.create_default_context()
# According to MODBUS/TCP Security Protocol Specification, it is
# TLSv2 at least
self.sslctx.options |= ssl.OP_NO_TLSv1_1
self.sslctx.options |= ssl.OP_NO_TLSv1
self.sslctx.options |= ssl.OP_NO_SSLv3
self.sslctx.options |= ssl.OP_NO_SSLv2
self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password)
ModbusTcpClient.__init__(self, host, port, framer, **kwargs)

def connect(self):
Expand Down
31 changes: 31 additions & 0 deletions pymodbus/client/tls_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
TLS helper for Modbus TLS Client
------------------------------------------
"""
import ssl

def sslctx_provider(sslctx=None, certfile=None, keyfile=None, password=None):
""" Provide the SSLContext for ModbusTlsClient
If the user defined SSLContext is not passed in, sslctx_provider will
produce a default one.
:param sslctx: The user defined SSLContext to use for TLS (default None and
auto create)
:param certfile: The optional client's cert file path for TLS server request
:param keyfile: The optional client's key file path for TLS server request
:param password: The password for for decrypting client's private key file
"""
if sslctx is None:
# According to MODBUS/TCP Security Protocol Specification, it is
# TLSv2 at least
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
sslctx.verify_mode = ssl.CERT_REQUIRED
sslctx.check_hostname = True

if certfile and keyfile:
sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile,
password=password)

return sslctx
27 changes: 13 additions & 14 deletions pymodbus/server/async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pymodbus.exceptions import NotImplementedException, NoSuchSlaveException
from pymodbus.pdu import ModbusExceptions as merror
from pymodbus.compat import socketserver, byte2int
from pymodbus.server.tls_helper import sslctx_provider

# --------------------------------------------------------------------------- #
# Logging
Expand Down Expand Up @@ -524,6 +525,8 @@ def __init__(self,
sslctx=None,
certfile=None,
keyfile=None,
password=None,
reqclicert=False,
handler=None,
allow_reuse_address=False,
allow_reuse_port=False,
Expand All @@ -544,6 +547,8 @@ def __init__(self,
create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param password: The password for for decrypting the private key file
:param reqclicert: Force the sever request client's certificate
:param handler: A handler for each client session; default is
ModbusConnectedRequestHandler. The handler class
receives connection create/teardown events
Expand Down Expand Up @@ -582,18 +587,9 @@ def __init__(self,
if isinstance(identity, ModbusDeviceIdentification):
self.control.Identity.update(identity)

self.sslctx = sslctx
if self.sslctx is None:
self.sslctx = ssl.create_default_context()
self.sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile)
# According to MODBUS/TCP Security Protocol Specification, it is
# TLSv2 at least
self.sslctx.options |= ssl.OP_NO_TLSv1_1
self.sslctx.options |= ssl.OP_NO_TLSv1
self.sslctx.options |= ssl.OP_NO_SSLv3
self.sslctx.options |= ssl.OP_NO_SSLv2
self.sslctx.verify_mode = ssl.CERT_OPTIONAL
self.sslctx.check_hostname = False
self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password,
reqclicert)

# asyncio future that will be done once server has started
self.serving = self.loop.create_future()
# constructors cannot be declared async, so we have to
Expand Down Expand Up @@ -845,7 +841,8 @@ async def StartTcpServer(context=None, identity=None, address=None,

async def StartTlsServer(context=None, identity=None, address=None,
sslctx=None,
certfile=None, keyfile=None,
certfile=None, keyfile=None, password=None,
reqclicert=False,
allow_reuse_address=False,
allow_reuse_port=False,
custom_functions=[],
Expand All @@ -858,6 +855,8 @@ async def StartTlsServer(context=None, identity=None, address=None,
:param sslctx: The SSLContext to use for TLS (default None and auto create)
:param certfile: The cert file path for TLS (used if sslctx is None)
:param keyfile: The key file path for TLS (used if sslctx is None)
:param password: The password for for decrypting the private key file
:param reqclicert: Force the sever request client's certificate
:param allow_reuse_address: Whether the server will allow the reuse of an
address.
:param allow_reuse_port: Whether the server will allow the reuse of a port.
Expand All @@ -872,7 +871,7 @@ async def StartTlsServer(context=None, identity=None, address=None,
"""
framer = kwargs.pop("framer", ModbusTlsFramer)
server = ModbusTlsServer(context, framer, identity, address, sslctx,
certfile, keyfile,
certfile, keyfile, password, reqclicert,
allow_reuse_address=allow_reuse_address,
allow_reuse_port=allow_reuse_port, **kwargs)

Expand Down
Loading

0 comments on commit 14c77e5

Please sign in to comment.