Skip to content

Commit

Permalink
Better testing of ssl opts and ws transport
Browse files Browse the repository at this point in the history
  • Loading branch information
dwoz committed Aug 12, 2023
1 parent 0e67653 commit 4919283
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 49 deletions.
43 changes: 23 additions & 20 deletions salt/transport/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import hashlib
import logging
import os
import ssl

import salt.utils.stringutils

log = logging.getLogger(__name__)

TRANSPORTS = (
"zeromq",
"tcp",
Expand Down Expand Up @@ -125,17 +129,17 @@ def publish_client(
elif "transport" in opts.get("pillar", {}).get("master", {}):
ttype = opts["pillar"]["master"]["transport"]

ssl = None
ssl_opts = None
if "ssl" in kwargs:
ssl = kwargs["ssl"]
ssl_opts = kwargs["ssl"]
elif opts.get("ssl", None) is not None:
ssl = opts["ssl"]
ssl_opts = opts["ssl"]

# switch on available ttypes
if ttype == "zeromq":
import salt.transport.zeromq

if ssl:
if ssl_opts:
log.warning("TLS not supported with zeromq transport")
return salt.transport.zeromq.PublishClient(
opts, io_loop, host=host, port=port, path=path
Expand All @@ -149,7 +153,7 @@ def publish_client(
host=host,
port=port,
path=path,
ssl=ssl,
ssl=ssl_opts,
)
elif ttype == "ws":
import salt.transport.ws
Expand All @@ -160,7 +164,7 @@ def publish_client(
host=host,
port=port,
path=path,
ssl=ssl,
ssl=ssl_opts,
)

raise Exception(f"Transport type not found: {ttype}")
Expand Down Expand Up @@ -392,8 +396,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):


def ssl_context(ssl_options, server_side=False):
if isinstance(ssl_options, ssl.SSLContext):
return ssl_options
default_version = ssl.PROTOCOL_TLS
if server_side:
default_version = ssl.PROTOCOL_TLS_SERVER
Expand All @@ -405,27 +407,28 @@ def ssl_context(ssl_options, server_side=False):
ssl_options["certfile"], ssl_options.get("keyfile", None)
)
if "cert_reqs" in ssl_options:
if ssl_options["cert_reqs"] == ssl.CERT_NONE:
if ssl_options["cert_reqs"].upper() == "CERT_NONE":
# This may have been set automatically by PROTOCOL_TLS_CLIENT but is
# incompatible with CERT_NONE so we must manually clear it.
context.check_hostname = False
context.verify_mode = getattr(ssl, VerifyMode, ssl_options["cert_reqs"])
context.verify_mode = getattr(ssl.VerifyMode, ssl_options["cert_reqs"])
if "ca_certs" in ssl_options:
context.load_verify_locations(ssl_options["ca_certs"])
if "verify_locations" in ssl_options:
for _ in ssl_options["verify_locations"]:
if _.lower().startswith("cafile:"):
cafile = _[7:]
context.load_verify_locations(cafile=cafile)
elif _.lower().startswith("capath:"):
capath = _[7:]
context.load_verify_locations(capath=capath)
elif _.lower().startswith("cadata:"):
cadata = _[7:]
context.load_verify_locations(cadata=cadata)
if isinstance(_, dict):
for key in _:
if key.lower() == "cafile":
context.load_verify_locations(cafile=_[key])
elif key.lower() == "capath":
context.load_verify_locations(capath=_[key])
elif key.lower() == "cadata":
context.load_verify_locations(cadata=_[key])
else:
log.warning("Unkown verify location type: %s", key)
else:
cafile = _
context.load_verify_locations(cafile=cafile)
context.load_verify_locations(cafile=_)
if "verify_flags" in ssl_options:
for flag in ssl_options["verify_flags"]:
context.verify_flags |= getattr(ssl.VerifyFlags, flag.upper())
Expand Down
5 changes: 2 additions & 3 deletions salt/transport/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import time
import urllib
import warnings
import ssl

import tornado
import tornado.concurrent
Expand All @@ -30,8 +29,8 @@

import salt.master
import salt.payload
import salt.transport.frame
import salt.transport.base
import salt.transport.frame
import salt.utils.asynchronous
import salt.utils.files
import salt.utils.msgpack
Expand Down Expand Up @@ -1133,7 +1132,7 @@ async def publish_payload(self, package, topic_list=None):
to_remove = []
if topic_list:
for topic in topic_list:
sent = Falses
sent = False
for client in list(self.clients):
if topic == client.id_:
try:
Expand Down
10 changes: 5 additions & 5 deletions salt/transport/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import socket
import time
import warnings
import ssl

import aiohttp
import aiohttp.web
Expand Down Expand Up @@ -109,6 +108,7 @@ async def getstream(self, **kwargs):
start = time.monotonic()
timeout = kwargs.get("timeout", None)
while ws is None and (not self._closed and not self._closing):
session = None
try:
ctx = None
if self.ssl is not None:
Expand Down Expand Up @@ -139,6 +139,8 @@ async def getstream(self, **kwargs):
exc,
self.backoff,
)
if session:
await session.close()
if timeout and time.monotonic() - start > timeout:
break
await asyncio.sleep(self.backoff)
Expand Down Expand Up @@ -374,9 +376,7 @@ async def publisher(
runner = aiohttp.web.ServerRunner(server)
await runner.setup()
site = aiohttp.web.SockSite(runner, sock, ssl_context=ctx)
log.info(
"Publisher binding to socket %s:%s", (self.pub_host, self.pub_port)
)
log.info("Publisher binding to socket %s:%s", self.pub_host, self.pub_port)
await site.start()

if self.pull_path:
Expand Down Expand Up @@ -548,7 +548,7 @@ def __init__(self, opts, io_loop): # pylint: disable=W0231
self.io_loop = io_loop
self._closing = False
self._closed = False
self.ssl = self.opts("ssl", None)
self.ssl = self.opts.get("ssl", None)

async def connect(self):
ctx = None
Expand Down
2 changes: 1 addition & 1 deletion tests/pytests/functional/channel/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def transport_ids(value):
return "transport({})".format(value)


@pytest.fixture(params=["tcp", "zeromq"], ids=transport_ids)
@pytest.fixture(params=["ws", "tcp", "zeromq"], ids=transport_ids)
def transport(request):
return request.param

Expand Down
70 changes: 70 additions & 0 deletions tests/pytests/unit/transport/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import ssl
import contextlib

import salt.transport.base

from tests.support.helpers import dedent

import pytest
from tests.support.mock import patch, Mock


@patch('ssl.SSLContext')
def test_ssl_context_legacy_opts(mock):
ctx = salt.transport.base.ssl_context({
'certfile': "server.crt",
'keyfile': "server.key",
'cert_reqs': "CERT_NONE",
"ca_certs": "ca.crt",
})
ctx.load_cert_chain.assert_called_with(
"server.crt",
"server.key",
)
ctx.load_verify_locations.assert_called_with(
"ca.crt"
)
assert ssl.VerifyMode.CERT_NONE == ctx.verify_mode
assert not ctx.check_hostname


@patch('ssl.SSLContext')
def test_ssl_context_opts(mock):
mock.verify_flags = ssl.VerifyFlags.VERIFY_X509_TRUSTED_FIRST
ctx = salt.transport.base.ssl_context({
'certfile': "server.crt",
'keyfile': "server.key",
'cert_reqs': "CERT_OPTIONAL",
"verify_locations": [
"ca.crt",
{"cafile": "crl.pem"},
{"capath": "/tmp/mycapathsdf"},
{"cadata": "mycadataother"},
{"CADATA": "mycadatasdf"},
],
"verify_flags": [
"VERIFY_CRL_CHECK_CHAIN",
]
})
ctx.load_cert_chain.assert_called_with(
"server.crt",
"server.key",
)
ctx.load_verify_locations.assert_any_call(
cafile="ca.crt"
)
ctx.load_verify_locations.assert_any_call(
cafile="crl.pem"
)
ctx.load_verify_locations.assert_any_call(
capath="/tmp/mycapathsdf"
)
ctx.load_verify_locations.assert_any_call(
cadata="mycadataother"
)
ctx.load_verify_locations.assert_called_with(
cadata="mycadatasdf"
)
assert ssl.VerifyMode.CERT_OPTIONAL == ctx.verify_mode
assert ctx.check_hostname
assert ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN & ctx.verify_flags
73 changes: 53 additions & 20 deletions tests/pytests/unit/transport/test_publish_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
:codeauthor: Thomas Jackson <[email protected]>
"""

import asyncio
import hashlib
import logging

Expand All @@ -10,6 +11,7 @@

import salt.crypt
import salt.transport.tcp
import salt.transport.ws
import salt.transport.zeromq
import salt.utils.stringutils
from tests.support.mock import MagicMock, patch
Expand All @@ -23,10 +25,10 @@


def transport_ids(value):
return "Transport({})".format(value)
return f"Transport({value})"


@pytest.fixture(params=("zeromq", "tcp"), ids=transport_ids)
@pytest.fixture(params=("zeromq", "tcp", "ws"), ids=transport_ids)
def transport(request):
return request.param

Expand Down Expand Up @@ -168,31 +170,28 @@ async def test_publish_client_connect_server_down(transport, io_loop):
await client.connect()
assert client._socket
elif transport == "tcp":
client = salt.transport.tcp.TCPPubClient(opts, io_loop, host=host, port=port)
try:
# XXX: This is an implimentation detail of the tcp transport.
# await client.connect(port)
io_loop.spawn_callback(client.connect)
except TimeoutError:
pass
except Exception: # pylint: disable=broad-except
log.error("Got exception", exc_info=True)
client = salt.transport.tcp.PublishClient(opts, io_loop, host=host, port=port)
io_loop.spawn_callback(client.connect)
assert client._stream is None
elif transport == "ws":
client = salt.transport.ws.PublishClient(opts, io_loop, host=host, port=port)
io_loop.spawn_callback(client.connect)
assert client._ws is None
assert client._session is None
client.close()
await asyncio.sleep(0.03)


async def test_publish_client_connect_server_comes_up(transport, io_loop):
opts = {"master_ip": "127.0.0.1"}
host = "127.0.0.1"
port = 11122
msg = salt.payload.dumps({"meh": 123})
if transport == "zeromq":
import asyncio

import zmq

ctx = zmq.asyncio.Context()
uri = f"tcp://{opts['master_ip']}:{port}"
msg = salt.payload.dumps({"meh": 123})
log.debug("TEST - Senging %r", msg)
client = salt.transport.zeromq.PublishClient(
opts, io_loop, host=host, port=port
Expand All @@ -213,6 +212,7 @@ async def recv():
task = asyncio.create_task(recv())
# Sleep to allow zmq to do it's thing.
await socket.send(msg)
await asyncio.sleep(0.03)
await task
response = task.result()
assert response
Expand All @@ -221,10 +221,9 @@ async def recv():
await asyncio.sleep(0.03)
ctx.term()
elif transport == "tcp":
import asyncio
import socket

client = salt.transport.tcp.TCPPubClient(opts, io_loop, host=host, port=port)
client = salt.transport.tcp.PublishClient(opts, io_loop, host=host, port=port)
# XXX: This is an implimentation detail of the tcp transport.
# await client.connect(port)
io_loop.spawn_callback(client.connect)
Expand All @@ -238,11 +237,45 @@ async def recv():
sock.listen(128)
await asyncio.sleep(0.03)

msg = salt.payload.dumps({"meh": 123})
msg = salt.transport.frame.frame_msg(msg, header=None)
data = salt.transport.frame.frame_msg(msg, header=None)
conn, addr = sock.accept()
conn.send(msg)
conn.send(data)
response = await client.recv()
assert response
assert msg == response
elif transport == "ws":
import socket

import aiohttp

client = salt.transport.ws.PublishClient(opts, io_loop, host=host, port=port)
io_loop.spawn_callback(client.connect)
assert client._ws is None
assert client._session is None
await asyncio.sleep(2)

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setblocking(0)
sock.bind((opts["master_ip"], port))
sock.listen(128)

async def handler(request):
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(request)
data = salt.transport.frame.frame_msg(msg, header=None)
await ws.send_bytes(data)

server = aiohttp.web.Server(handler)
runner = aiohttp.web.ServerRunner(server)
await runner.setup()
site = aiohttp.web.SockSite(runner, sock)
await site.start()

await asyncio.sleep(0.03)

response = await client.recv()
assert msg == response
else:
raise Exception(f"Unknown transport {transport}")
client.close()
await asyncio.sleep(0.03)

0 comments on commit 4919283

Please sign in to comment.