Skip to content

Commit

Permalink
Wean of tcp transport bits in ws transport
Browse files Browse the repository at this point in the history
  • Loading branch information
dwoz committed Aug 14, 2023
1 parent 3b975e5 commit 9c842e9
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 95 deletions.
19 changes: 9 additions & 10 deletions salt/transport/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,7 @@ async def recv(self, timeout=None):
await asyncio.sleep(0.001)
if timeout == 0:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg[b"body"]
try:
events, _, _ = select.select([self._stream.socket], [], [], 0)
except TimeoutError:
Expand All @@ -395,8 +394,7 @@ async def recv(self, timeout=None):
return
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg[b"body"]
elif timeout:
try:
return await asyncio.wait_for(self.recv(), timeout=timeout)
Expand All @@ -410,8 +408,7 @@ async def recv(self, timeout=None):
return
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg[b"body"]
while not self._closing:
async with self._read_in_progress:
try:
Expand All @@ -428,8 +425,7 @@ async def recv(self, timeout=None):
continue
self.unpacker.feed(byts)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg[b"body"]

async def on_recv_handler(self, callback):
while not self._stream:
Expand All @@ -438,7 +434,7 @@ async def on_recv_handler(self, callback):
while True:
msg = await self.recv()
if msg:
callback(msg)
await callback(msg)

def on_recv(self, callback):
"""
Expand Down Expand Up @@ -1457,7 +1453,10 @@ def pre_fork(self, process_manager):
primarily be used to create IPC channels and create our daemon process to
do the actual publishing
"""
process_manager.add_process(self.publish_daemon, name=self.__class__.__name__)
process_manager.add_process(
self.publish_daemon,
args=[self.publish_payload],
name=self.__class__.__name__)

async def publish_payload(self, payload, *args):
return await self.pub_server.publish_payload(payload)
Expand Down
138 changes: 73 additions & 65 deletions salt/transport/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import socket
import time
import warnings
import functools

import aiohttp
import aiohttp.web
Expand All @@ -16,11 +17,9 @@
from salt.transport.tcp import (
USE_LOAD_BALANCER,
LoadBalancerServer,
TCPPuller,
_get_bind_addr,
_get_socket,
_set_tcp_keepalive,
_TCPPubServerPublisher,
)

log = logging.getLogger(__name__)
Expand All @@ -46,13 +45,13 @@ class PublishClient(salt.transport.base.PublishClient):
def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231
self.opts = opts
self.io_loop = io_loop
self.message_client = None
self.unpacker = salt.utils.msgpack.Unpacker()

self.connected = False
self._closing = False
self._stream = None
self._closing = False
self._closed = False

self.backoff = opts.get("tcp_reconnect_backoff", 1)
self.resolver = kwargs.get("resolver")
self._read_in_progress = Lock()
Expand All @@ -75,15 +74,19 @@ def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231
self._ws = None
self._session = None
self._closing = False
self._closed = False
self.on_recv_task = None

async def _close(self):
if self._ws is not None:
await self._ws.close()
self._ws = None
if self._session is not None:
await self._session.close()
self._session = None
if self.on_recv_task:
self.on_recv_task.cancel()
await self.on_recv_task
self.on_recv_task = None
if self._ws is not None:
await self._ws.close()
self._ws = None
self._closed = True

def close(self):
Expand Down Expand Up @@ -149,8 +152,6 @@ async def getstream(self, **kwargs):
async def _connect(self, timeout=None):
if self._ws is None:
self._ws, self._session = await self.getstream(timeout=timeout)
# if not self._stream_return_running:
# self.io_loop.spawn_callback(self._stream_return)
if self.connect_callback:
self.connect_callback(True)
self.connected = True
Expand All @@ -170,16 +171,6 @@ async def connect(
self.disconnect_callback = None
await self._connect(timeout=timeout)

def _decode_messages(self, messages):
if not isinstance(messages, dict):
# TODO: For some reason we need to decode here for things
# to work. Fix this.
body = salt.utils.msgpack.loads(messages)
body = salt.transport.frame.decode_embedded_strs(body)
else:
body = messages
return body

async def send(self, msg):
await self.message_client.send(msg, reply=False)

Expand All @@ -189,8 +180,7 @@ async def recv(self, timeout=None):
await asyncio.sleep(0.001)
if timeout == 0:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg
try:
raw_msg = await asyncio.wait_for(self._ws.receive(), 0.0001)
except TimeoutError:
Expand All @@ -201,8 +191,7 @@ async def recv(self, timeout=None):
if raw_msg.type == aiohttp.WSMsgType.BINARY:
self.unpacker.feed(raw_msg.data)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg
elif raw_msg.type == aiohttp.WSMsgType.ERROR:
log.error(
"ws connection closed with exception %s", self._ws.exception()
Expand All @@ -211,39 +200,43 @@ async def recv(self, timeout=None):
return await asyncio.wait_for(self.recv(), timeout=timeout)
else:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg
while True:
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg
raw_msg = await self._ws.receive()
if raw_msg.type == aiohttp.WSMsgType.TEXT:
if raw_msg.data == "close":
await self._ws.close()
if raw_msg.type == aiohttp.WSMsgType.BINARY:
self.unpacker.feed(raw_msg.data)
for msg in self.unpacker:
framed_msg = salt.transport.frame.decode_embedded_strs(msg)
return framed_msg["body"]
return msg
elif raw_msg.type == aiohttp.WSMsgType.ERROR:
log.error(
"ws connection closed with exception %s",
self._ws.exception(),
)

async def handle_on_recv(self, callback):
async def on_recv_handler(self, callback):
while not self._ws:
await asyncio.sleep(0.003)
while True:
msg = await self.recv()
callback(msg)
await callback(msg)

def on_recv(self, callback):
"""
Register a callback for received messages (that we didn't initiate)
"""
self.io_loop.spawn_callback(self.handle_on_recv, callback)
if self.on_recv_task:
# XXX: We are not awaiting this canceled task. This still needs to
# be addressed.
self.on_recv_task.cancel()
if callback is None:
self.on_recv_task = None
else:
self.on_recv_task = asyncio.create_task(self.on_recv_handler(callback))

def __enter__(self):
return self
Expand Down Expand Up @@ -287,7 +280,9 @@ def __init__(
self.ssl = ssl
self.clients = set()
self._run = None
self.pub_sock = None
self.pub_writer = None
self.pub_reader = None
self._connecting = None

@property
def topic_support(self):
Expand Down Expand Up @@ -333,6 +328,7 @@ def publish_daemon(
except (KeyboardInterrupt, SystemExit):
pass
finally:
print("CLOSE")
self.close()

async def publisher(
Expand Down Expand Up @@ -369,33 +365,41 @@ async def publisher(
await runner.setup()
site = aiohttp.web.SockSite(runner, sock, ssl_context=ctx)
log.info("Publisher binding to socket %s:%s", self.pub_host, self.pub_port)
print('start site')
await site.start()
print('start puller')

self._pub_payload = publish_payload
if self.pull_path:
pull_uri = self.pull_path
with salt.utils.files.set_umask(0o177):
self.puller = await asyncio.start_unix_server(self.pull_handler, self.pull_path)
else:
pull_uri = self.pull_port

self.pull_sock = TCPPuller(
pull_uri,
io_loop=io_loop,
payload_handler=publish_payload,
)
# Securely create socket
log.warning("Starting the Salt Puller on %s", pull_uri)
with salt.utils.files.set_umask(0o177):
self.pull_sock.start()
self.puller = await asyncio.start_server(self.pull_handler, self.pull_host, self.pull_port)
print('puller started')
while self._run.is_set():
await asyncio.sleep(0.3)
await server.stop()
await self.server.stop()
await self.puller.wait_closed()

async def pull_handler(self, reader, writer):
print("puller got connection")
unpacker = salt.utils.msgpack.Unpacker()
while True:
data = await reader.read(1024)
unpacker.feed(data)
for msg in unpacker:
await self._pub_payload(msg)

def pre_fork(self, process_manager):
"""
Do anything necessary pre-fork. Since this is on the master side this will
primarily be used to create IPC channels and create our daemon process to
do the actual publishing
"""
process_manager.add_process(self.publish_daemon, name=self.__class__.__name__)
process_manager.add_process(
self.publish_daemon,
args=[self.publish_payload],
name=self.__class__.__name__)

async def handle_request(self, request):
try:
Expand All @@ -412,41 +416,45 @@ async def handle_request(self, request):
while True:
await asyncio.sleep(1)

async def _connect(self):
if self.pull_path:
self.pub_reader, self.pub_writer = await asyncio.open_unix_connection(self.pull_path)
else:
self.pub_reader, self.pub_writer = await asyncio.open_connection(self.pull_host, self.pull_port)
self._connecting = None

def connect(self):
log.debug("Connect pusher %s", self.pull_path)
self.pub_sock = salt.utils.asynchronous.SyncWrapper(
_TCPPubServerPublisher,
(
self.pull_host,
self.pull_port,
self.pull_path,
),
loop_kwarg="io_loop",
)
self.pub_sock.connect()
if self._connecting is None:
self._connecting = asyncio.create_task(self._connect())
return self._connecting

async def publish(self, payload, **kwargs):
"""
Publish "load" to minions
"""
if not self.pub_sock:
self.connect()
self.pub_sock.send(payload)
if not self.pub_writer:
await self.connect()
self.pub_writer.write(salt.payload.dumps(payload))
await self.pub_writer.drain()

async def publish_payload(self, package, *args):
payload = salt.transport.frame.frame_msg(package)
payload = salt.payload.dumps(package)
for ws in list(self.clients):
try:
await ws.send_bytes(payload)
except ConnectionResetError:
self.clients.discard(ws)

def close(self):
if self.pub_sock:
self.pub_sock.close()
self.pub_sock = None
if self.pub_writer:
self.pub_writer.close()
self.pub_writer = None
self.pub_reader = None
if self._run is not None:
self._run.clear()
if self._connecting:
self._connecting.cancel()


class RequestServer(salt.transport.base.DaemonizedRequestServer):
Expand Down
20 changes: 20 additions & 0 deletions tests/pytests/functional/transport/server/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import salt.utils.process

import pytest

def transport_ids(value):
return "Transport({})".format(value)


@pytest.fixture(params=("zeromq", "tcp", "ws"), ids=transport_ids)
def transport(request):
return request.param


@pytest.fixture
def process_manager():
pm = salt.utils.process.ProcessManager()
try:
yield pm
finally:
pm.terminate()
Loading

0 comments on commit 9c842e9

Please sign in to comment.