Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Replace flask with quart #43

Merged
merged 12 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
457 changes: 128 additions & 329 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ safe-ds-runner = "safeds_runner.main:main"
[tool.poetry.dependencies]
python = "^3.11,<3.12"
safe-ds = ">=0.17,<0.18"
flask = "^3.0.0"
flask-cors = "^4.0.0"
flask-sock = "^0.7.0"
gevent = "^23.9.1"
hypercorn = "^0.16.0"
quart = "^0.19.4"

[tool.poetry.dev-dependencies]
pytest = "^7.4.4"
pytest-cov = "^4.1.0"
pytest-timeout = "^2.2.0"
pytest-asyncio = "^0.23.3"
simple-websocket = "^1.0.0"

[tool.poetry.group.docs.dependencies]
mkdocs = "^1.4.3"
Expand Down
13 changes: 2 additions & 11 deletions src/safeds_runner/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging

from safeds_runner.server.pipeline_manager import PipelineManager
from safeds_runner.server.server import SafeDsServer


def start_server(port: int) -> None:
Expand All @@ -14,15 +14,6 @@ def start_server(port: int) -> None:
builtins.print = functools.partial(print, flush=True) # type: ignore[assignment]

logging.getLogger().setLevel(logging.DEBUG)
# Startup early, so our multiprocessing setup works
app_pipeline_manager = PipelineManager()
app_pipeline_manager.startup()
from gevent.monkey import patch_all

# Patch WebSockets to work in parallel
patch_all()

from safeds_runner.server.server import SafeDsServer

safeds_server = SafeDsServer(app_pipeline_manager) # pragma: no cover
safeds_server = SafeDsServer() # pragma: no cover
safeds_server.listen(port) # pragma: no cover
lars-reimann marked this conversation as resolved.
Show resolved Hide resolved
38 changes: 20 additions & 18 deletions src/safeds_runner/server/pipeline_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Module that contains the infrastructure for pipeline execution in child processes."""

import asyncio
import json
import logging
import multiprocessing
Expand All @@ -12,7 +13,6 @@
from pathlib import Path
from typing import Any

import simple_websocket
import stack_data

from safeds_runner.server.messages import (
Expand Down Expand Up @@ -41,7 +41,7 @@ class PipelineManager:
def __init__(self) -> None:
"""Create a new PipelineManager object, which is lazily started, when needed."""
self._placeholder_map: dict = {}
self._websocket_target: list[simple_websocket.Server] = []
self._websocket_target: list[asyncio.Queue] = []

@cached_property
def _multiprocessing_manager(self) -> SyncManager:
Expand All @@ -55,10 +55,7 @@ def _messages_queue(self) -> queue.Queue[Message]:

@cached_property
def _messages_queue_thread(self) -> threading.Thread:
return threading.Thread(
target=self._handle_queue_messages,
daemon=True,
)
return threading.Thread(target=self._handle_queue_messages, daemon=True, args=(asyncio.get_event_loop(),))

@cached_property
def _memoization_map(self) -> MemoizationMap:
Expand All @@ -79,43 +76,48 @@ def startup(self) -> None:
if not self._messages_queue_thread.is_alive():
self._messages_queue_thread.start()

def _handle_queue_messages(self) -> None:
def _handle_queue_messages(self, event_loop: asyncio.AbstractEventLoop) -> None:
"""
Relay messages from pipeline processes to the currently connected websocket endpoint.

Should be used in a dedicated thread.

Parameters
----------
event_loop : asyncio.AbstractEventLoop
Event Loop that handles websocket connections.
"""
try:
while self._messages_queue is not None:
message = self._messages_queue.get()
message_encoded = json.dumps(message.to_dict())
# only send messages to the same connection once
for connection in set(self._websocket_target):
connection.send(message_encoded)
asyncio.run_coroutine_threadsafe(connection.put(message_encoded), event_loop)
except BaseException as error: # noqa: BLE001 # pragma: no cover
logging.warning("Message queue terminated: %s", error.__repr__()) # pragma: no cover

def connect(self, websocket_connection: simple_websocket.Server) -> None:
def connect(self, websocket_connection_queue: asyncio.Queue) -> None:
"""
Add a websocket connection to relay event messages to, which are occurring during pipeline execution.
Add a websocket connection queue to relay event messages to, which are occurring during pipeline execution.

Parameters
----------
websocket_connection : simple_websocket.Server
New websocket connection.
websocket_connection_queue : asyncio.Queue
Message Queue for a websocket connection.
"""
self._websocket_target.append(websocket_connection)
self._websocket_target.append(websocket_connection_queue)

def disconnect(self, websocket_connection: simple_websocket.Server) -> None:
def disconnect(self, websocket_connection_queue: asyncio.Queue) -> None:
"""
Remove a websocket target connection to no longer receive messages.
Remove a websocket target connection queue to no longer receive messages.

Parameters
----------
websocket_connection : simple_websocket.Server
Websocket connection to be removed.
websocket_connection_queue : asyncio.Queue
Message Queue for a websocket connection to be removed.
"""
self._websocket_target.remove(websocket_connection)
self._websocket_target.remove(websocket_connection_queue)

def execute_pipeline(
self,
Expand Down
151 changes: 68 additions & 83 deletions src/safeds_runner/server/server.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
"""Module containing the server, endpoints and utility functions."""

import asyncio
import json
import logging
import sys

import flask.app
import flask_sock
import simple_websocket
from flask import Flask
from flask_cors import CORS
from flask_sock import Sock
import hypercorn.asyncio
import quart.app

from safeds_runner.server import messages
from safeds_runner.server.json_encoder import SafeDsEncoder
Expand All @@ -22,63 +19,27 @@
from safeds_runner.server.pipeline_manager import PipelineManager


def create_flask_app(testing: bool = False) -> flask.app.App:
def create_flask_app() -> quart.app.Quart:
"""
Create a flask app, that handles all requests.

Parameters
----------
testing : bool
Whether the app should run in a testing context.

Returns
-------
flask.app.App
Flask app.
"""
flask_app = Flask(__name__)
# Websocket Configuration
flask_app.config["SOCK_SERVER_OPTIONS"] = {"ping_interval": 25}
flask_app.config["TESTING"] = testing

# Allow access from VSCode extension
CORS(flask_app, resources={r"/*": {"origins": "vscode-webview://*"}})
return flask_app


def create_flask_websocket(flask_app: flask.app.App) -> flask_sock.Sock:
"""
Create a flask websocket extension.

Parameters
----------
flask_app: flask.app.App
Flask App Instance.
Create a quart app, that handles all requests.

Returns
-------
flask_sock.Sock
Websocket extension for the provided flask app.
quart.app.Quart
App.
"""
return Sock(flask_app)
return quart.app.Quart(__name__)


class SafeDsServer:
"""Server containing the flask app, websocket handler and endpoints."""

def __init__(self, app_pipeline_manager: PipelineManager) -> None:
"""
Create a new server object.

Parameters
----------
app_pipeline_manager : PipelineManager
Manager responsible for executing pipelines sent to this server.
"""
self.app_pipeline_manager = app_pipeline_manager
def __init__(self) -> None:
"""Create a new server object."""
self.app_pipeline_manager = PipelineManager()
self.app = create_flask_app()
self.sock = create_flask_websocket(self.app)
self.sock.route("/WSMain")(lambda ws: self._ws_main(ws, self.app_pipeline_manager))
self.app.config["pipeline_manager"] = self.app_pipeline_manager
self.app.websocket("/WSMain")(SafeDsServer.ws_main)

def listen(self, port: int) -> None:
"""
Expand All @@ -90,41 +51,56 @@ def listen(self, port: int) -> None:
Port to listen on
"""
logging.info("Starting Safe-DS Runner on port %s", str(port))
serve_config = hypercorn.config.Config()
# Only bind to host=127.0.0.1. Connections from other devices should not be accepted
from gevent.pywsgi import WSGIServer
serve_config.bind = f"127.0.0.1:{port}"
serve_config.websocket_ping_interval = 25.0
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(hypercorn.asyncio.serve(self.app, serve_config))
event_loop.run_forever() # pragma: no cover

WSGIServer(("127.0.0.1", port), self.app, spawn=8).serve_forever()
@staticmethod
async def ws_main() -> None:
"""Handle websocket requests to the WSMain endpoint and delegates with the required objects."""
await SafeDsServer._ws_main(quart.websocket, quart.current_app.config["pipeline_manager"])

@staticmethod
def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) -> None:
async def _ws_main(ws: quart.Websocket, pipeline_manager: PipelineManager) -> None:
"""
Handle websocket requests to the WSMain endpoint.

This function handles the bidirectional communication between the runner and the VS Code extension.

Parameters
----------
ws : simple_websocket.Server
Websocket Connection, provided by flask.
ws : quart.Websocket
Connection
pipeline_manager : PipelineManager
Manager used to execute pipelines on, and retrieve placeholders from
Pipeline Manager
"""
logging.debug("Request to WSRunProgram")
pipeline_manager.connect(ws)
output_queue: asyncio.Queue = asyncio.Queue()
pipeline_manager.connect(output_queue)
foreground_handler = asyncio.create_task(SafeDsServer._ws_main_foreground(ws, pipeline_manager, output_queue))
background_handler = asyncio.create_task(SafeDsServer._ws_main_background(ws, output_queue))
await asyncio.gather(foreground_handler, background_handler)

@staticmethod
async def _ws_main_foreground(
ws: quart.Websocket,
pipeline_manager: PipelineManager,
output_queue: asyncio.Queue,
) -> None:
while True:
# This would be a JSON message
received_message: str = ws.receive()
if received_message is None:
logging.debug("Received EOF, closing connection")
pipeline_manager.disconnect(ws)
ws.close()
return
received_message: str = await ws.receive()
logging.debug("Received Message: %s", received_message)
received_object, error_detail, error_short = parse_validate_message(received_message)
if received_object is None:
logging.error(error_detail)
pipeline_manager.disconnect(ws)
ws.close(message=error_short)
await output_queue.put(None)
pipeline_manager.disconnect(output_queue)
await ws.close(code=1000, reason=error_short)
return
match received_object.type:
case "shutdown":
Expand All @@ -135,8 +111,9 @@ def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) ->
program_data, invalid_message = messages.validate_program_message_data(received_object.data)
if program_data is None:
logging.error("Invalid message data specified in: %s (%s)", received_message, invalid_message)
pipeline_manager.disconnect(ws)
ws.close(None, invalid_message)
await output_queue.put(None)
pipeline_manager.disconnect(output_queue)
await ws.close(code=1000, reason=invalid_message)
return
# This should only be called from the extension as it is a security risk
pipeline_manager.execute_pipeline(program_data, received_object.id)
Expand All @@ -147,8 +124,9 @@ def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) ->
)
if placeholder_query_data is None:
logging.error("Invalid message data specified in: %s (%s)", received_message, invalid_message)
pipeline_manager.disconnect(ws)
ws.close(None, invalid_message)
await output_queue.put(None)
pipeline_manager.disconnect(output_queue)
await ws.close(code=1000, reason=invalid_message)
return
placeholder_type, placeholder_value = pipeline_manager.get_placeholder(
received_object.id,
Expand All @@ -157,8 +135,8 @@ def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) ->
# send back a value message
if placeholder_type is not None:
try:
broadcast_message(
[ws],
await send_message(
ws,
Message(
message_type_placeholder_value,
received_object.id,
Expand All @@ -171,8 +149,8 @@ def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) ->
)
except TypeError as _encoding_error:
# if the value can't be encoded send back that the value exists but is not displayable
broadcast_message(
[ws],
await send_message(
ws,
Message(
message_type_placeholder_value,
received_object.id,
Expand All @@ -186,8 +164,8 @@ def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) ->
else:
# Send back empty type / value, to communicate that no placeholder exists (yet)
# Use name from query to allow linking a response to a request on the peer
broadcast_message(
[ws],
await send_message(
ws,
Message(
message_type_placeholder_value,
received_object.id,
Expand All @@ -198,18 +176,25 @@ def _ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) ->
if received_object.type not in messages.message_types:
logging.warning("Invalid message type: %s", received_object.type)

@staticmethod
async def _ws_main_background(ws: quart.Websocket, output_queue: asyncio.Queue) -> None:
while True:
encoded_message = await output_queue.get()
if encoded_message is None:
return
await ws.send(encoded_message)


def broadcast_message(connections: list[simple_websocket.Server], message: Message) -> None:
async def send_message(connection: quart.Websocket, message: Message) -> None:
"""
Send any message to all the provided connections (to the VS Code extension).
Send a message to the provided websocket connection (to the VS Code extension).

Parameters
----------
connections : list[simple_websocket.Server]
List of Websocket connections that should receive the message.
connection : quart.Websocket
Connection that should receive the message.
message : Message
Object that will be sent.
"""
message_encoded = json.dumps(message.to_dict(), cls=SafeDsEncoder)
for connection in connections:
connection.send(message_encoded)
await connection.send(message_encoded)
Loading
Loading