Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dwoz committed Aug 13, 2023
1 parent 4919283 commit 802b420
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 93 deletions.
74 changes: 33 additions & 41 deletions salt/transport/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,23 +184,40 @@ async def send(self, msg):
await self.message_client.send(msg, reply=False)

async def recv(self, timeout=None):
try:
await self._read_in_progress.acquire(timeout=0.001)
except tornado.gen.TimeoutError:
log.error("Unable to acquire read lock")
return
try:
if timeout == 0:
if not self._ws:
await asyncio.sleep(0.001)
return
while self._ws is None:
await self.connect()
await asyncio.sleep(0.001)
if timeout == 0:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
try:
raw_msg = await asyncio.wait_for(self._ws.receive(), 0.0001)
except TimeoutError:
return
if raw_msg.type == aiohttp.WSMsgType.TEXT:
if raw_msg.data == "close":
await self._ws.close()
if raw_msg.type == aiohttp.WSMsgType.BINARY:
self.unpacker.feed(raw_msg.data)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
try:
raw_msg = await asyncio.wait_for(self._ws.receive(), 0.0001)
except TimeoutError:
return
elif raw_msg.type == aiohttp.WSMsgType.ERROR:
log.error(
"ws connection closed with exception %s", self._ws.exception()
)
elif timeout:
return await asyncio.wait_for(self.recv(), timeout=timeout)
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
while True:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
raw_msg = await self._ws.receive()
if raw_msg.type == aiohttp.WSMsgType.TEXT:
if raw_msg.data == "close":
await self._ws.close()
Expand All @@ -211,34 +228,9 @@ async def recv(self, timeout=None):
return framed_msg["body"]
elif raw_msg.type == aiohttp.WSMsgType.ERROR:
log.error(
"ws connection closed with exception %s", self._ws.exception()
"ws connection closed with exception %s",
self._ws.exception(),
)
elif timeout:
return await asyncio.wait_for(self.recv(), timeout=timeout)
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
while True:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
raw_msg = await self._ws.receive()
if raw_msg.type == aiohttp.WSMsgType.TEXT:
if raw_msg.data == "close":
await self._ws.close()
if raw_msg.type == aiohttp.WSMsgType.BINARY:
self.unpacker.feed(raw_msg.data)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
elif raw_msg.type == aiohttp.WSMsgType.ERROR:
log.error(
"ws connection closed with exception %s",
self._ws.exception(),
)
finally:
self._read_in_progress.release()

async def handle_on_recv(self, callback):
while not self._ws:
Expand Down
2 changes: 1 addition & 1 deletion tests/pytests/functional/channel/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def root_dir(tmp_path):


def transport_ids(value):
return "transport({})".format(value)
return f"transport({value})"


@pytest.fixture(params=["ws", "tcp", "zeromq"], ids=transport_ids)
Expand Down
4 changes: 2 additions & 2 deletions tests/pytests/functional/cli/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _ret(self, jid, minion_id, fun, _return=True, _retcode=0):
},
use_bin_type=True,
)
tag = "salt/job/{}/ret".format(jid).encode()
tag = f"salt/job/{{jid.encode()}/ret"
return b"".join([tag, b"\n\n", dumped])

def connect(self, timeout=None):
Expand Down Expand Up @@ -170,7 +170,7 @@ def returns_for_job(jid):
"extension_modules": "",
"failhard": True,
}
with patch("salt.transport.tcp.TCPPubClient", MockSubscriber):
with patch("salt.transport.tcp.PublishClient", MockSubscriber):
batch = salt.cli.batch.Batch(opts, quiet=True)
with patch.object(batch.local, "pub", Mock(side_effect=mock_pub)):
with patch.object(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def handle_stream(self, stream, address):

@pytest.fixture
def client(io_loop, config):
client = salt.transport.tcp.TCPPubClient(
client = salt.transport.tcp.PublishClient(
config.copy(), io_loop, host=config["master_ip"], port=config["publish_port"]
)
try:
Expand Down
81 changes: 34 additions & 47 deletions tests/pytests/unit/transport/test_base.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,57 @@
import ssl
import contextlib

import salt.transport.base
from tests.support.mock import patch

from tests.support.helpers import dedent

import pytest
from tests.support.mock import patch, Mock


@patch('ssl.SSLContext')
@patch("ssl.SSLContext")
def test_ssl_context_legacy_opts(mock):
ctx = salt.transport.base.ssl_context({
'certfile': "server.crt",
'keyfile': "server.key",
'cert_reqs': "CERT_NONE",
"ca_certs": "ca.crt",
})
ctx = salt.transport.base.ssl_context(
{
"certfile": "server.crt",
"keyfile": "server.key",
"cert_reqs": "CERT_NONE",
"ca_certs": "ca.crt",
}
)
ctx.load_cert_chain.assert_called_with(
"server.crt",
"server.key",
)
ctx.load_verify_locations.assert_called_with(
"ca.crt"
)
ctx.load_verify_locations.assert_called_with("ca.crt")
assert ssl.VerifyMode.CERT_NONE == ctx.verify_mode
assert not ctx.check_hostname


@patch('ssl.SSLContext')
@patch("ssl.SSLContext")
def test_ssl_context_opts(mock):
mock.verify_flags = ssl.VerifyFlags.VERIFY_X509_TRUSTED_FIRST
ctx = salt.transport.base.ssl_context({
'certfile': "server.crt",
'keyfile': "server.key",
'cert_reqs': "CERT_OPTIONAL",
"verify_locations": [
"ca.crt",
{"cafile": "crl.pem"},
{"capath": "/tmp/mycapathsdf"},
{"cadata": "mycadataother"},
{"CADATA": "mycadatasdf"},
],
"verify_flags": [
"VERIFY_CRL_CHECK_CHAIN",
]
})
ctx = salt.transport.base.ssl_context(
{
"certfile": "server.crt",
"keyfile": "server.key",
"cert_reqs": "CERT_OPTIONAL",
"verify_locations": [
"ca.crt",
{"cafile": "crl.pem"},
{"capath": "/tmp/mycapathsdf"},
{"cadata": "mycadataother"},
{"CADATA": "mycadatasdf"},
],
"verify_flags": [
"VERIFY_CRL_CHECK_CHAIN",
],
}
)
ctx.load_cert_chain.assert_called_with(
"server.crt",
"server.key",
)
ctx.load_verify_locations.assert_any_call(
cafile="ca.crt"
)
ctx.load_verify_locations.assert_any_call(
cafile="crl.pem"
)
ctx.load_verify_locations.assert_any_call(
capath="/tmp/mycapathsdf"
)
ctx.load_verify_locations.assert_any_call(
cadata="mycadataother"
)
ctx.load_verify_locations.assert_called_with(
cadata="mycadatasdf"
)
ctx.load_verify_locations.assert_any_call(cafile="ca.crt")
ctx.load_verify_locations.assert_any_call(cafile="crl.pem")
ctx.load_verify_locations.assert_any_call(capath="/tmp/mycapathsdf")
ctx.load_verify_locations.assert_any_call(cadata="mycadataother")
ctx.load_verify_locations.assert_called_with(cadata="mycadatasdf")
assert ssl.VerifyMode.CERT_OPTIONAL == ctx.verify_mode
assert ctx.check_hostname
assert ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN & ctx.verify_flags
2 changes: 1 addition & 1 deletion tests/pytests/unit/transport/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ async def test_async_tcp_pub_channel_connect_publish_port(
future.set_result(True)
with patch("salt.crypt.AsyncAuth.gen_token", patch_auth), patch(
"salt.crypt.AsyncAuth.authenticated", patch_auth
), patch("salt.transport.tcp.TCPPubClient", transport):
), patch("salt.transport.tcp.PublishClient", transport):
channel = salt.channel.client.AsyncPubChannel.factory(opts)
with channel:
# We won't be able to succeed the connection because we're not mocking the tornado coroutine
Expand Down

0 comments on commit 802b420

Please sign in to comment.