Skip to content

Commit

Permalink
feat: Replace flask with quart (#43)
Browse files Browse the repository at this point in the history
Closes #42 

### Summary of Changes

- replaces flask with quart
- rewrite some tests
- added test against crashing
- solution should be more scalable as it replaces wsgi with asgi

---------

Co-authored-by: megalinter-bot <[email protected]>
  • Loading branch information
WinPlay02 and megalinter-bot authored Jan 26, 2024
1 parent bbb23a5 commit 5520b68
Show file tree
Hide file tree
Showing 6 changed files with 520 additions and 567 deletions.
457 changes: 128 additions & 329 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ safe-ds-runner = "safeds_runner.main:main"
[tool.poetry.dependencies]
python = "^3.11,<3.12"
safe-ds = ">=0.17,<0.18"
flask = "^3.0.0"
flask-cors = "^4.0.0"
flask-sock = "^0.7.0"
gevent = "^23.9.1"
hypercorn = "^0.16.0"
quart = "^0.19.4"

[tool.poetry.dev-dependencies]
pytest = "^7.4.4"
pytest-cov = "^4.1.0"
pytest-timeout = "^2.2.0"
pytest-asyncio = "^0.23.3"
simple-websocket = "^1.0.0"

[tool.poetry.group.docs.dependencies]
mkdocs = "^1.4.3"
Expand Down
13 changes: 2 additions & 11 deletions src/safeds_runner/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging

from safeds_runner.server.pipeline_manager import PipelineManager
from safeds_runner.server.server import SafeDsServer


def start_server(port: int) -> None:
Expand All @@ -14,15 +14,6 @@ def start_server(port: int) -> None:
builtins.print = functools.partial(print, flush=True) # type: ignore[assignment]

logging.getLogger().setLevel(logging.DEBUG)
# Startup early, so our multiprocessing setup works
app_pipeline_manager = PipelineManager()
app_pipeline_manager.startup()
from gevent.monkey import patch_all

# Patch WebSockets to work in parallel
patch_all()

from safeds_runner.server.server import SafeDsServer

safeds_server = SafeDsServer(app_pipeline_manager) # pragma: no cover
safeds_server = SafeDsServer()
safeds_server.listen(port) # pragma: no cover
38 changes: 20 additions & 18 deletions src/safeds_runner/server/pipeline_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Module that contains the infrastructure for pipeline execution in child processes."""

import asyncio
import json
import logging
import multiprocessing
Expand All @@ -12,7 +13,6 @@
from pathlib import Path
from typing import Any

import simple_websocket
import stack_data

from safeds_runner.server.messages import (
Expand Down Expand Up @@ -41,7 +41,7 @@ class PipelineManager:
def __init__(self) -> None:
"""Create a new PipelineManager object, which is lazily started, when needed."""
self._placeholder_map: dict = {}
self._websocket_target: list[simple_websocket.Server] = []
self._websocket_target: list[asyncio.Queue] = []

@cached_property
def _multiprocessing_manager(self) -> SyncManager:
Expand All @@ -55,10 +55,7 @@ def _messages_queue(self) -> queue.Queue[Message]:

@cached_property
def _messages_queue_thread(self) -> threading.Thread:
return threading.Thread(
target=self._handle_queue_messages,
daemon=True,
)
return threading.Thread(target=self._handle_queue_messages, daemon=True, args=(asyncio.get_event_loop(),))

@cached_property
def _memoization_map(self) -> MemoizationMap:
Expand All @@ -79,43 +76,48 @@ def startup(self) -> None:
if not self._messages_queue_thread.is_alive():
self._messages_queue_thread.start()

def _handle_queue_messages(self) -> None:
def _handle_queue_messages(self, event_loop: asyncio.AbstractEventLoop) -> None:
"""
Relay messages from pipeline processes to the currently connected websocket endpoint.
Should be used in a dedicated thread.
Parameters
----------
event_loop : asyncio.AbstractEventLoop
Event Loop that handles websocket connections.
"""
try:
while self._messages_queue is not None:
message = self._messages_queue.get()
message_encoded = json.dumps(message.to_dict())
# only send messages to the same connection once
for connection in set(self._websocket_target):
connection.send(message_encoded)
asyncio.run_coroutine_threadsafe(connection.put(message_encoded), event_loop)
except BaseException as error: # noqa: BLE001 # pragma: no cover
logging.warning("Message queue terminated: %s", error.__repr__()) # pragma: no cover

def connect(self, websocket_connection: simple_websocket.Server) -> None:
def connect(self, websocket_connection_queue: asyncio.Queue) -> None:
"""
Add a websocket connection to relay event messages to, which are occurring during pipeline execution.
Add a websocket connection queue to relay event messages to, which are occurring during pipeline execution.
Parameters
----------
websocket_connection : simple_websocket.Server
New websocket connection.
websocket_connection_queue : asyncio.Queue
Message Queue for a websocket connection.
"""
self._websocket_target.append(websocket_connection)
self._websocket_target.append(websocket_connection_queue)

def disconnect(self, websocket_connection: simple_websocket.Server) -> None:
def disconnect(self, websocket_connection_queue: asyncio.Queue) -> None:
"""
Remove a websocket target connection to no longer receive messages.
Remove a websocket target connection queue to no longer receive messages.
Parameters
----------
websocket_connection : simple_websocket.Server
Websocket connection to be removed.
websocket_connection_queue : asyncio.Queue
Message Queue for a websocket connection to be removed.
"""
self._websocket_target.remove(websocket_connection)
self._websocket_target.remove(websocket_connection_queue)

def execute_pipeline(
self,
Expand Down
151 changes: 68 additions & 83 deletions src/safeds_runner/server/server.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
"""Module containing the server, endpoints and utility functions."""

import asyncio
import json
import logging
import sys

import flask.app
import flask_sock
import simple_websocket
from flask import Flask
from flask_cors import CORS
from flask_sock import Sock
import hypercorn.asyncio
import quart.app

from safeds_runner.server import messages
from safeds_runner.server.json_encoder import SafeDsEncoder
Expand All @@ -22,63 +19,27 @@
from safeds_runner.server.pipeline_manager import PipelineManager


def create_flask_app(testing: bool = False) -> flask.app.App:
def create_flask_app() -> quart.app.Quart:
"""
Create a flask app, that handles all requests.
Parameters
----------
testing : bool
Whether the app should run in a testing context.
Returns
-------
flask.app.App
Flask app.
"""
flask_app = Flask(__name__)
# Websocket Configuration
flask_app.config["SOCK_SERVER_OPTIONS"] = {"ping_interval": 25}
flask_app.config["TESTING"] = testing

# Allow access from VSCode extension
CORS(flask_app, resources={r"/*": {"origins": "vscode-webview://*"}})
return flask_app


def create_flask_websocket(flask_app: flask.app.App) -> flask_sock.Sock:
"""
Create a flask websocket extension.
Parameters
----------
flask_app: flask.app.App
Flask App Instance.
Create a quart app, that handles all requests.
Returns
-------
flask_sock.Sock
Websocket extension for the provided flask app.
quart.app.Quart
App.
"""
return Sock(flask_app)
return quart.app.Quart(__name__)


class SafeDsServer:
"""Server containing the flask app, websocket handler and endpoints."""

def __init__(self, app_pipeline_manager: PipelineManager) -> None:
"""
Create a new server object.
Parameters
----------
app_pipeline_manager : PipelineManager
Manager responsible for executing pipelines sent to this server.
"""
self.app_pipeline_manager = app_pipeline_manager
def __init__(self) -> None:
"""Create a new server object."""
self.app_pipeline_manager = PipelineManager()
self.app = create_flask_app()
self.sock = create_flask_websocket(self.app)
self.sock.route("/WSMain")(lambda ws: self._ws_main(ws, self.app_pipeline_manager))
self.app.config["pipeline_manager"] = self.app_pipeline_manager
self.app.websocket("/WSMain")(SafeDsServer.ws_main)

def listen(self, port: int) -> None:
"""
Expand All @@ -90,41 +51,56 @@ def listen(self, port: int) -> None:
Port to listen on
"""
logging.info("Starting Safe-DS Runner on port %s", str(port))
serve_config = hypercorn.config.Config()
# Only bind to host=127.0.0.1. Connections from other devices should not be accepted
from gevent.pywsgi import WSGIServer
serve_config.bind = f"127.0.0.1:{port}"
serve_config.websocket_ping_interval = 25.0
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(hypercorn.asyncio.serve(self.app, serve_config))
event_loop.run_forever() # pragma: no cover

WSGIServer(("127.0.0.1", port), self.app, spawn=8).serve_forever()
@staticmethod
async def ws_main() -> None:
"""Handle websocket requests to the WSMain endpoint and delegates with the required objects."""
await SafeDsServer._ws_main(quart.websocket, quart.current_app.config["pipeline_manager"])

@staticmethod
def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) -> None:
async def _ws_main(ws: quart.Websocket, pipeline_manager: PipelineManager) -> None:
"""
Handle websocket requests to the WSMain endpoint.
This function handles the bidirectional communication between the runner and the VS Code extension.
Parameters
----------
ws : simple_websocket.Server
Websocket Connection, provided by flask.
ws : quart.Websocket
Connection
pipeline_manager : PipelineManager
Manager used to execute pipelines on, and retrieve placeholders from
Pipeline Manager
"""
logging.debug("Request to WSRunProgram")
pipeline_manager.connect(ws)
output_queue: asyncio.Queue = asyncio.Queue()
pipeline_manager.connect(output_queue)
foreground_handler = asyncio.create_task(SafeDsServer._ws_main_foreground(ws, pipeline_manager, output_queue))
background_handler = asyncio.create_task(SafeDsServer._ws_main_background(ws, output_queue))
await asyncio.gather(foreground_handler, background_handler)

@staticmethod
async def _ws_main_foreground(
ws: quart.Websocket,
pipeline_manager: PipelineManager,
output_queue: asyncio.Queue,
) -> None:
while True:
# This would be a JSON message
received_message: str = ws.receive()
if received_message is None:
logging.debug("Received EOF, closing connection")
pipeline_manager.disconnect(ws)
ws.close()
return
received_message: str = await ws.receive()
logging.debug("Received Message: %s", received_message)
received_object, error_detail, error_short = parse_validate_message(received_message)
if received_object is None:
logging.error(error_detail)
pipeline_manager.disconnect(ws)
ws.close(message=error_short)
await output_queue.put(None)
pipeline_manager.disconnect(output_queue)
await ws.close(code=1000, reason=error_short)
return
match received_object.type:
case "shutdown":
Expand All @@ -135,8 +111,9 @@ def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) ->
program_data, invalid_message = messages.validate_program_message_data(received_object.data)
if program_data is None:
logging.error("Invalid message data specified in: %s (%s)", received_message, invalid_message)
pipeline_manager.disconnect(ws)
ws.close(None, invalid_message)
await output_queue.put(None)
pipeline_manager.disconnect(output_queue)
await ws.close(code=1000, reason=invalid_message)
return
# This should only be called from the extension as it is a security risk
pipeline_manager.execute_pipeline(program_data, received_object.id)
Expand All @@ -147,8 +124,9 @@ def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) ->
)
if placeholder_query_data is None:
logging.error("Invalid message data specified in: %s (%s)", received_message, invalid_message)
pipeline_manager.disconnect(ws)
ws.close(None, invalid_message)
await output_queue.put(None)
pipeline_manager.disconnect(output_queue)
await ws.close(code=1000, reason=invalid_message)
return
placeholder_type, placeholder_value = pipeline_manager.get_placeholder(
received_object.id,
Expand All @@ -157,8 +135,8 @@ def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) ->
# send back a value message
if placeholder_type is not None:
try:
broadcast_message(
[ws],
await send_message(
ws,
Message(
message_type_placeholder_value,
received_object.id,
Expand All @@ -171,8 +149,8 @@ def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) ->
)
except TypeError as _encoding_error:
# if the value can't be encoded send back that the value exists but is not displayable
broadcast_message(
[ws],
await send_message(
ws,
Message(
message_type_placeholder_value,
received_object.id,
Expand All @@ -186,8 +164,8 @@ def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) ->
else:
# Send back empty type / value, to communicate that no placeholder exists (yet)
# Use name from query to allow linking a response to a request on the peer
broadcast_message(
[ws],
await send_message(
ws,
Message(
message_type_placeholder_value,
received_object.id,
Expand All @@ -198,18 +176,25 @@ def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) ->
if received_object.type not in messages.message_types:
logging.warning("Invalid message type: %s", received_object.type)

@staticmethod
async def _ws_main_background(ws: quart.Websocket, output_queue: asyncio.Queue) -> None:
while True:
encoded_message = await output_queue.get()
if encoded_message is None:
return
await ws.send(encoded_message)


def broadcast_message(connections: list[simple_websocket.Server], message: Message) -> None:
async def send_message(connection: quart.Websocket, message: Message) -> None:
"""
Send any message to all the provided connections (to the VS Code extension).
Send a message to the provided websocket connection (to the VS Code extension).
Parameters
----------
connections : list[simple_websocket.Server]
List of Websocket connections that should receive the message.
connection : quart.Websocket
Connection that should receive the message.
message : Message
Object that will be sent.
"""
message_encoded = json.dumps(message.to_dict(), cls=SafeDsEncoder)
for connection in connections:
connection.send(message_encoded)
await connection.send(message_encoded)
Loading

0 comments on commit 5520b68

Please sign in to comment.