From 0e395f73dcae4c31465426aa9411274a9bb36a23 Mon Sep 17 00:00:00 2001 From: Will Chen Date: Sat, 12 Oct 2024 22:13:06 -0700 Subject: [PATCH 1/9] Remove unused _requests_in_flight code --- mesop/server/server.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/mesop/server/server.py b/mesop/server/server.py index f37d0a94..bf02402c 100644 --- a/mesop/server/server.py +++ b/mesop/server/server.py @@ -50,13 +50,6 @@ STREAM_END = "data: \n\n" -def is_processing_request(): - return _requests_in_flight > 0 - - -_requests_in_flight = 0 - - def configure_flask_app( *, prod_mode: bool = True, exceptions_to_propagate: Sequence[type] = () ) -> Flask: @@ -271,16 +264,6 @@ def ui_stream() -> Response: response = make_sse_response(stream_with_context(generate_data(ui_request))) return response - @flask_app.before_request - def before_request(): - global _requests_in_flight - _requests_in_flight += 1 - - @flask_app.teardown_request - def teardown(error=None): - global _requests_in_flight - _requests_in_flight -= 1 - @flask_app.teardown_request def teardown_clear_stale_state_sessions(error=None): runtime().context().clear_stale_state_sessions() From b3360af87cf4b5ed23f40cb43a410b470a9f6733 Mon Sep 17 00:00:00 2001 From: Will Chen Date: Sat, 12 Oct 2024 22:16:38 -0700 Subject: [PATCH 2/9] Extract server_utils.py --- mesop/server/server.py | 149 +++-------------------------------- mesop/server/server_utils.py | 144 +++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 136 deletions(-) create mode 100644 mesop/server/server_utils.py diff --git a/mesop/server/server.py b/mesop/server/server.py index bf02402c..d2f56e16 100644 --- a/mesop/server/server.py +++ b/mesop/server/server.py @@ -1,11 +1,8 @@ import base64 import inspect import json -import os -import secrets import time -import urllib.parse as urlparse -from typing import Any, Generator, Iterable, Sequence +from typing import Generator, Sequence from urllib import request as urllib_request from urllib.error import URLError @@ -17,38 +14,22 @@ from mesop.events import LoadEvent from mesop.exceptions import format_traceback from mesop.runtime import runtime -from mesop.server.config import app_config from mesop.server.constants import WEB_COMPONENTS_PATH_SEGMENT +from mesop.server.server_utils import ( + AI_SERVICE_BASE_URL, + EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED, + MESOP_CONCURRENT_UPDATES_ENABLED, + STREAM_END, + check_editor_access, + create_update_state_event, + is_same_site, + make_sse_response, + serialize, + sse_request, +) from mesop.utils.url_utils import remove_url_query_param from mesop.warn import warn -AI_SERVICE_BASE_URL = os.environ.get( - "MESOP_AI_SERVICE_BASE_URL", "http://localhost:43234" -) - -MESOP_CONCURRENT_UPDATES_ENABLED = ( - os.environ.get("MESOP_CONCURRENT_UPDATES_ENABLED", "false").lower() == "true" -) - -EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED = ( - os.environ.get("MESOP_EXPERIMENTAL_EDITOR_TOOLBAR", "false").lower() == "true" -) - -if MESOP_CONCURRENT_UPDATES_ENABLED: - print("Experiment enabled: MESOP_CONCURRENT_UPDATES_ENABLED") - -if EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED: - print("Experiment enabled: EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED") - -LOCALHOSTS = ( - # For IPv4 localhost - "127.0.0.1", - # For IPv6 localhost - "::1", -) - -STREAM_END = "data: \n\n" - def configure_flask_app( *, prod_mode: bool = True, exceptions_to_propagate: Sequence[type] = () @@ -402,107 +383,3 @@ def hot_reload() -> Response: return response return flask_app - - -def check_editor_access(): - if not EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED: - abort(403) # Throws a Forbidden Error - # Prevent accidental usages of editor mode outside of - # one's local computer - if request.remote_addr not in LOCALHOSTS: - abort(403) # Throws a Forbidden Error - # Visual editor should only be enabled in debug mode. - if not runtime().debug_mode: - abort(403) # Throws a Forbidden Error - - -def serialize(response: pb.UiResponse) -> str: - encoded = base64.b64encode(response.SerializeToString()).decode("utf-8") - return f"data: {encoded}\n\n" - - -def generate_state_token(): - """Generates a state token used to cache and look up Mesop state.""" - return secrets.token_urlsafe(16) - - -def create_update_state_event(diff: bool = False) -> str: - """Creates a state event to send to the client. - - Args: - diff: If true, sends diffs instead of the full state objects - - Returns: - serialized `pb.UiResponse` - """ - - state_token = "" - - # If enabled, we will save the user's state on the server, so that the client does not - # need to send the full state back on the next user event request. - if app_config.state_session_enabled: - state_token = generate_state_token() - runtime().context().save_state_to_session(state_token) - - update_state_event = pb.UpdateStateEvent( - state_token=state_token, - diff_states=runtime().context().diff_state() if diff else None, - full_states=runtime().context().serialize_state() if not diff else None, - ) - - return serialize(pb.UiResponse(update_state_event=update_state_event)) - - -def is_same_site(url1: str | None, url2: str | None): - """ - Determine if two URLs are the same site. - """ - # If either URL is false-y, they are not the same site - # (because we need a real URL to have an actual site) - if not url1 or not url2: - return False - try: - p1, p2 = urlparse.urlparse(url1), urlparse.urlparse(url2) - return p1.hostname == p2.hostname - except ValueError: - return False - - -SSE_DATA_PREFIX = "data: " - - -def sse_request( - url: str, data: dict[str, Any] -) -> Generator[dict[str, Any], None, None]: - """ - Make an SSE request and yield JSON parsed events. - """ - headers = { - "Content-Type": "application/json", - "Accept": "text/event-stream", - } - encoded_data = json.dumps(data).encode("utf-8") - req = urllib_request.Request( - url, data=encoded_data, headers=headers, method="POST" - ) - - with urllib_request.urlopen(req) as response: - for line in response: - if line.strip(): - decoded_line = line.decode("utf-8").strip() - if decoded_line.startswith(SSE_DATA_PREFIX): - event_data = json.loads(decoded_line[len(SSE_DATA_PREFIX) :]) - yield event_data - - -def make_sse_response( - response: Iterable[bytes] | bytes | Iterable[str] | str | None = None, -): - return Response( - response, - content_type="text/event-stream", - # "X-Accel-Buffering" impacts SSE responses due to response buffering (i.e. - # individual events may get batched together instead of being sent right away). - # See https://nginx.org/en/docs/http/ngx_http_proxy_module.html - headers={"X-Accel-Buffering": "no"}, - ) diff --git a/mesop/server/server_utils.py b/mesop/server/server_utils.py new file mode 100644 index 00000000..e2c49a28 --- /dev/null +++ b/mesop/server/server_utils.py @@ -0,0 +1,144 @@ +import base64 +import json +import os +import secrets +import urllib.parse as urlparse +from typing import Any, Generator, Iterable +from urllib import request as urllib_request + +from flask import Response, abort, request + +import mesop.protos.ui_pb2 as pb +from mesop.runtime import runtime +from mesop.server.config import app_config + +AI_SERVICE_BASE_URL = os.environ.get( + "MESOP_AI_SERVICE_BASE_URL", "http://localhost:43234" +) + +MESOP_CONCURRENT_UPDATES_ENABLED = ( + os.environ.get("MESOP_CONCURRENT_UPDATES_ENABLED", "false").lower() == "true" +) + +EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED = ( + os.environ.get("MESOP_EXPERIMENTAL_EDITOR_TOOLBAR", "false").lower() == "true" +) + +if MESOP_CONCURRENT_UPDATES_ENABLED: + print("Experiment enabled: MESOP_CONCURRENT_UPDATES_ENABLED") + +if EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED: + print("Experiment enabled: EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED") + +LOCALHOSTS = ( + # For IPv4 localhost + "127.0.0.1", + # For IPv6 localhost + "::1", +) + +STREAM_END = "data: \n\n" + + +def check_editor_access(): + if not EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED: + abort(403) # Throws a Forbidden Error + # Prevent accidental usages of editor mode outside of + # one's local computer + if request.remote_addr not in LOCALHOSTS: + abort(403) # Throws a Forbidden Error + # Visual editor should only be enabled in debug mode. + if not runtime().debug_mode: + abort(403) # Throws a Forbidden Error + + +def serialize(response: pb.UiResponse) -> str: + encoded = base64.b64encode(response.SerializeToString()).decode("utf-8") + return f"data: {encoded}\n\n" + + +def generate_state_token(): + """Generates a state token used to cache and look up Mesop state.""" + return secrets.token_urlsafe(16) + + +def create_update_state_event(diff: bool = False) -> str: + """Creates a state event to send to the client. + + Args: + diff: If true, sends diffs instead of the full state objects + + Returns: + serialized `pb.UiResponse` + """ + + state_token = "" + + # If enabled, we will save the user's state on the server, so that the client does not + # need to send the full state back on the next user event request. + if app_config.state_session_enabled: + state_token = generate_state_token() + runtime().context().save_state_to_session(state_token) + + update_state_event = pb.UpdateStateEvent( + state_token=state_token, + diff_states=runtime().context().diff_state() if diff else None, + full_states=runtime().context().serialize_state() if not diff else None, + ) + + return serialize(pb.UiResponse(update_state_event=update_state_event)) + + +def is_same_site(url1: str | None, url2: str | None): + """ + Determine if two URLs are the same site. + """ + # If either URL is false-y, they are not the same site + # (because we need a real URL to have an actual site) + if not url1 or not url2: + return False + try: + p1, p2 = urlparse.urlparse(url1), urlparse.urlparse(url2) + return p1.hostname == p2.hostname + except ValueError: + return False + + +SSE_DATA_PREFIX = "data: " + + +def sse_request( + url: str, data: dict[str, Any] +) -> Generator[dict[str, Any], None, None]: + """ + Make an SSE request and yield JSON parsed events. + """ + headers = { + "Content-Type": "application/json", + "Accept": "text/event-stream", + } + encoded_data = json.dumps(data).encode("utf-8") + req = urllib_request.Request( + url, data=encoded_data, headers=headers, method="POST" + ) + + with urllib_request.urlopen(req) as response: + for line in response: + if line.strip(): + decoded_line = line.decode("utf-8").strip() + if decoded_line.startswith(SSE_DATA_PREFIX): + event_data = json.loads(decoded_line[len(SSE_DATA_PREFIX) :]) + yield event_data + + +def make_sse_response( + response: Iterable[bytes] | bytes | Iterable[str] | str | None = None, +): + return Response( + response, + content_type="text/event-stream", + # "X-Accel-Buffering" impacts SSE responses due to response buffering (i.e. + # individual events may get batched together instead of being sent right away). + # See https://nginx.org/en/docs/http/ngx_http_proxy_module.html + headers={"X-Accel-Buffering": "no"}, + ) From 857e67f8788f7cc99b1363dff74f9330d7783407 Mon Sep 17 00:00:00 2001 From: Will Chen Date: Sat, 12 Oct 2024 22:21:13 -0700 Subject: [PATCH 3/9] Extract debug routes --- mesop/server/server.py | 142 +-------------------------- mesop/server/server_debug_routes.py | 146 ++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 140 deletions(-) create mode 100644 mesop/server/server_debug_routes.py diff --git a/mesop/server/server.py b/mesop/server/server.py index d2f56e16..12ba673c 100644 --- a/mesop/server/server.py +++ b/mesop/server/server.py @@ -1,10 +1,5 @@ import base64 -import inspect -import json -import time from typing import Generator, Sequence -from urllib import request as urllib_request -from urllib.error import URLError from flask import Flask, Response, abort, request, stream_with_context @@ -15,17 +10,15 @@ from mesop.exceptions import format_traceback from mesop.runtime import runtime from mesop.server.constants import WEB_COMPONENTS_PATH_SEGMENT +from mesop.server.server_debug_routes import configure_debug_routes from mesop.server.server_utils import ( - AI_SERVICE_BASE_URL, EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED, MESOP_CONCURRENT_UPDATES_ENABLED, STREAM_END, - check_editor_access, create_update_state_event, is_same_site, make_sse_response, serialize, - sse_request, ) from mesop.utils.url_utils import remove_url_query_param from mesop.warn import warn @@ -250,136 +243,5 @@ def teardown_clear_stale_state_sessions(error=None): runtime().context().clear_stale_state_sessions() if not prod_mode: - - @flask_app.route("/__editor__/commit", methods=["POST"]) - def page_commit() -> Response: - check_editor_access() - - try: - data = request.get_json() - except json.JSONDecodeError: - return Response("Invalid JSON format", status=400) - code = data.get("code") - path = data.get("path") - page_config = runtime().get_page_config(path=path) - assert page_config - module = inspect.getmodule(page_config.page_fn) - assert module - module_file = module.__file__ - assert module_file - module_file_path = module.__file__ - assert module_file_path - with open(module_file_path, "w") as file: - file.write(code) - - response_data = {"message": "Page commit successful"} - return Response( - json.dumps(response_data), status=200, mimetype="application/json" - ) - - @flask_app.route("/__editor__/save-interaction", methods=["POST"]) - def save_interaction() -> Response | dict[str, str]: - check_editor_access() - - data = request.get_json() - if not data: - return Response("Invalid JSON data", status=400) - - try: - req = urllib_request.Request( - AI_SERVICE_BASE_URL + "/save-interaction", - data=json.dumps(data).encode("utf-8"), - headers={"Content-Type": "application/json"}, - ) - with urllib_request.urlopen(req) as response: - if response.status == 200: - folder = json.loads(response.read().decode("utf-8"))["folder"] - return {"folder": folder} - else: - print(f"Error from AI service: {response.read().decode('utf-8')}") - return Response( - f"Error from AI service: {response.read().decode('utf-8')}", - status=500, - ) - except URLError as e: - return Response( - f"Error making request to AI service: {e!s}", status=500 - ) - - @flask_app.route("/__editor__/generate", methods=["POST"]) - def page_generate(): - check_editor_access() - - try: - data = request.get_json() - except json.JSONDecodeError: - return Response("Invalid JSON format", status=400) - if not data: - return Response("Invalid JSON data", status=400) - - prompt = data.get("prompt") - if not prompt: - return Response("Missing 'prompt' in JSON data", status=400) - - path = data.get("path") - page_config = runtime().get_page_config(path=path) - - line_number = data.get("lineNumber") - assert page_config - module = inspect.getmodule(page_config.page_fn) - if module is None: - return Response("Could not retrieve module source code.", status=500) - module_file = module.__file__ - assert module_file - with open(module_file) as file: - source_code = file.read() - print(f"Source code of module {module.__name__}:") - - def generate(): - try: - for event in sse_request( - AI_SERVICE_BASE_URL + "/adjust-mesop-app", - {"prompt": prompt, "code": source_code, "lineNumber": line_number}, - ): - if event.get("type") == "end": - sse_data = { - "type": "end", - "prompt": prompt, - "path": path, - "beforeCode": source_code, - "afterCode": event["code"], - "diff": event["diff"], - "message": "Prompt processed successfully", - } - yield f"data: {json.dumps(sse_data)}\n\n" - break - elif event.get("type") == "progress": - sse_data = {"data": event["data"], "type": "progress"} - yield f"data: {json.dumps(sse_data)}\n\n" - elif event.get("type") == "error": - sse_data = {"error": event["error"], "type": "error"} - yield f"data: {json.dumps(sse_data)}\n\n" - break - else: - raise Exception(f"Unknown event type: {event}") - except Exception as e: - sse_data = { - "error": "Could not connect to AI service: " + str(e), - "type": "error", - } - yield f"data: {json.dumps(sse_data)}\n\n" - - return make_sse_response(generate()) - - @flask_app.route("/__hot-reload__") - def hot_reload() -> Response: - counter = int(request.args["counter"]) - while True: - if counter < runtime().hot_reload_counter: - break - # Sleep a short duration but not too short that we hog up excessive CPU. - time.sleep(0.1) - response = Response(str(runtime().hot_reload_counter), status=200) - return response - + configure_debug_routes(flask_app) return flask_app diff --git a/mesop/server/server_debug_routes.py b/mesop/server/server_debug_routes.py new file mode 100644 index 00000000..e4179403 --- /dev/null +++ b/mesop/server/server_debug_routes.py @@ -0,0 +1,146 @@ +import inspect +import json +import time +from urllib import request as urllib_request +from urllib.error import URLError + +from flask import Flask, Response, request + +from mesop.runtime import runtime +from mesop.server.server_utils import ( + AI_SERVICE_BASE_URL, + check_editor_access, + make_sse_response, + sse_request, +) + + +def configure_debug_routes(flask_app: Flask): + @flask_app.route("/__editor__/commit", methods=["POST"]) + def page_commit() -> Response: + check_editor_access() + + try: + data = request.get_json() + except json.JSONDecodeError: + return Response("Invalid JSON format", status=400) + code = data.get("code") + path = data.get("path") + page_config = runtime().get_page_config(path=path) + assert page_config + module = inspect.getmodule(page_config.page_fn) + assert module + module_file = module.__file__ + assert module_file + module_file_path = module.__file__ + assert module_file_path + with open(module_file_path, "w") as file: + file.write(code) + + response_data = {"message": "Page commit successful"} + return Response( + json.dumps(response_data), status=200, mimetype="application/json" + ) + + @flask_app.route("/__editor__/save-interaction", methods=["POST"]) + def save_interaction() -> Response | dict[str, str]: + check_editor_access() + + data = request.get_json() + if not data: + return Response("Invalid JSON data", status=400) + + try: + req = urllib_request.Request( + AI_SERVICE_BASE_URL + "/save-interaction", + data=json.dumps(data).encode("utf-8"), + headers={"Content-Type": "application/json"}, + ) + with urllib_request.urlopen(req) as response: + if response.status == 200: + folder = json.loads(response.read().decode("utf-8"))["folder"] + return {"folder": folder} + else: + print(f"Error from AI service: {response.read().decode('utf-8')}") + return Response( + f"Error from AI service: {response.read().decode('utf-8')}", + status=500, + ) + except URLError as e: + return Response(f"Error making request to AI service: {e!s}", status=500) + + @flask_app.route("/__editor__/generate", methods=["POST"]) + def page_generate(): + check_editor_access() + + try: + data = request.get_json() + except json.JSONDecodeError: + return Response("Invalid JSON format", status=400) + if not data: + return Response("Invalid JSON data", status=400) + + prompt = data.get("prompt") + if not prompt: + return Response("Missing 'prompt' in JSON data", status=400) + + path = data.get("path") + page_config = runtime().get_page_config(path=path) + + line_number = data.get("lineNumber") + assert page_config + module = inspect.getmodule(page_config.page_fn) + if module is None: + return Response("Could not retrieve module source code.", status=500) + module_file = module.__file__ + assert module_file + with open(module_file) as file: + source_code = file.read() + print(f"Source code of module {module.__name__}:") + + def generate(): + try: + for event in sse_request( + AI_SERVICE_BASE_URL + "/adjust-mesop-app", + {"prompt": prompt, "code": source_code, "lineNumber": line_number}, + ): + if event.get("type") == "end": + sse_data = { + "type": "end", + "prompt": prompt, + "path": path, + "beforeCode": source_code, + "afterCode": event["code"], + "diff": event["diff"], + "message": "Prompt processed successfully", + } + yield f"data: {json.dumps(sse_data)}\n\n" + break + elif event.get("type") == "progress": + sse_data = {"data": event["data"], "type": "progress"} + yield f"data: {json.dumps(sse_data)}\n\n" + elif event.get("type") == "error": + sse_data = {"error": event["error"], "type": "error"} + yield f"data: {json.dumps(sse_data)}\n\n" + break + else: + raise Exception(f"Unknown event type: {event}") + except Exception as e: + sse_data = { + "error": "Could not connect to AI service: " + str(e), + "type": "error", + } + yield f"data: {json.dumps(sse_data)}\n\n" + + return make_sse_response(generate()) + + @flask_app.route("/__hot-reload__") + def hot_reload() -> Response: + counter = int(request.args["counter"]) + while True: + if counter < runtime().hot_reload_counter: + break + # Sleep a short duration but not too short that we hog up excessive CPU. + time.sleep(0.1) + response = Response(str(runtime().hot_reload_counter), status=200) + return response From 037f52c26823fb9280a387490fdd868654cd2dd3 Mon Sep 17 00:00:00 2001 From: Will Chen Date: Sat, 12 Oct 2024 22:31:24 -0700 Subject: [PATCH 4/9] Add flask-socketio dep --- build_defs/defaults.bzl | 4 ++++ build_defs/requirements.txt | 2 ++ build_defs/requirements_lock.txt | 35 ++++++++++++++++++++++++++++---- 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/build_defs/defaults.bzl b/build_defs/defaults.bzl index 275ba894..a18860b0 100644 --- a/build_defs/defaults.bzl +++ b/build_defs/defaults.bzl @@ -89,6 +89,10 @@ THIRD_PARTY_PY_FLASK = [ requirement("flask"), ] +THIRD_PARTY_PY_FLASK_SOCKETIO = [ + requirement("flask-socketio"), +] + THIRD_PARTY_PY_MATPLOTLIB = [ requirement("matplotlib"), ] diff --git a/build_defs/requirements.txt b/build_defs/requirements.txt index 40489e51..0e1849c7 100644 --- a/build_defs/requirements.txt +++ b/build_defs/requirements.txt @@ -8,6 +8,8 @@ python-dotenv # Optional (lazily-loaded) deps: sqlalchemy +flask-socketio + # greenlet is needed for SQL Alchemy depending on the architecture, but because of how # Bazel works using requirements_lock.txt, it does seem to able to install the # architecture specific requirements (in this case caught on Github CI). diff --git a/build_defs/requirements_lock.txt b/build_defs/requirements_lock.txt index c1c9513f..00bb1eac 100644 --- a/build_defs/requirements_lock.txt +++ b/build_defs/requirements_lock.txt @@ -16,6 +16,10 @@ babel==2.15.0 \ --hash=sha256:08706bdad8d0a3413266ab61bd6c34d0c28d6e1e7badf40a2cebe67644e2e1fb \ --hash=sha256:8daf0e265d05768bc6c7a314cf1321e9a123afc328cc635c18622a2f30a04413 # via mkdocs-material +bidict==0.23.1 \ + --hash=sha256:03069d763bc387bbd20e7d49914e75fc4132a41937fa3405417e1a5a2d006d71 \ + --hash=sha256:5dae8d4d79b552a71cbabc7deb25dfe8ce710b17ff41711e13010ead2abfc3e5 + # via python-socketio blinker==1.8.2 \ --hash=sha256:1779309f71bf239144b9399d06ae925637cf6634cf6bd131104184531bf67c01 \ --hash=sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83 @@ -290,6 +294,12 @@ firebase-admin==6.5.0 \ flask==3.0.3 \ --hash=sha256:34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3 \ --hash=sha256:ceb27b0af3823ea2737928a4d99d125a06175b8512c445cbd9a9ce200ef76842 + # via + # -r build_defs/requirements.txt + # flask-socketio +flask-socketio==5.4.1 \ + --hash=sha256:2e9b8864a5be37ca54f6c76a4d06b1ac5e0df61fde12d03afc81ab4057e1eb86 \ + --hash=sha256:895da879d162781b9193cbb8fe8f3cf25b263ff242980d5c5e6c16d3c03930d2 # via -r build_defs/requirements.txt fonttools==4.53.0 \ --hash=sha256:099634631b9dd271d4a835d2b2a9e042ccc94ecdf7e2dd9f7f34f7daf333358d \ @@ -580,6 +590,10 @@ grpcio-status==1.62.2 \ --hash=sha256:206ddf0eb36bc99b033f03b2c8e95d319f0044defae9b41ae21408e7e0cda48f \ --hash=sha256:62e1bfcb02025a1cd73732a2d33672d3e9d0df4d21c12c51e0bbcaf09bab742a # via google-api-core +h11==0.14.0 \ + --hash=sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d \ + --hash=sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761 + # via wsproto httplib2==0.22.0 \ --hash=sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc \ --hash=sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81 @@ -716,7 +730,6 @@ markdown==3.6 \ --hash=sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f \ --hash=sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224 # via - # -r build_defs/requirements.txt # mkdocs # mkdocs-autorefs # mkdocs-material @@ -1231,9 +1244,7 @@ pydantic-core==2.18.4 \ pygments==2.18.0 \ --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a - # via - # -r build_defs/requirements.txt - # mkdocs-material + # via mkdocs-material pyjwt[crypto]==2.8.0 \ --hash=sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de \ --hash=sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320 @@ -1265,6 +1276,14 @@ python-dotenv==1.0.1 \ --hash=sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca \ --hash=sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a # via -r build_defs/requirements.txt +python-engineio==4.9.1 \ + --hash=sha256:7631cf5563086076611e494c643b3fa93dd3a854634b5488be0bba0ef9b99709 \ + --hash=sha256:f995e702b21f6b9ebde4e2000cd2ad0112ba0e5116ec8d22fe3515e76ba9dddd + # via python-socketio +python-socketio==5.11.4 \ + --hash=sha256:42efaa3e3e0b166fc72a527488a13caaac2cefc76174252486503bd496284945 \ + --hash=sha256:8b0b8ff2964b2957c865835e936310190639c00310a47d77321a594d1665355e + # via flask-socketio pytz==2024.1 \ --hash=sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812 \ --hash=sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319 @@ -1423,6 +1442,10 @@ rsa==4.9 \ --hash=sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7 \ --hash=sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21 # via google-auth +simple-websocket==1.1.0 \ + --hash=sha256:4af6069630a38ed6c561010f0e11a5bc0d4ca569b36306eb257cd9a192497c8c \ + --hash=sha256:7939234e7aa067c534abdab3a9ed933ec9ce4691b0713c78acb195560aa52ae4 + # via python-engineio six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 @@ -1543,3 +1566,7 @@ werkzeug==3.0.3 \ --hash=sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18 \ --hash=sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8 # via flask +wsproto==1.2.0 \ + --hash=sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065 \ + --hash=sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736 + # via simple-websocket From dfea01b30a5017ee72d8ac59d0a737ac3e6d8f7d Mon Sep 17 00:00:00 2001 From: Will Chen Date: Sat, 12 Oct 2024 23:10:09 -0700 Subject: [PATCH 5/9] Working websockets prototype --- build_defs/defaults.bzl | 4 + mesop/cli/cli.py | 10 +- mesop/server/BUILD | 11 +- mesop/server/server.py | 73 +++++- mesop/web/src/services/BUILD | 4 +- mesop/web/src/services/channel.ts | 372 +++++++++++++++++++++--------- package.json | 1 + yarn.lock | 31 +++ 8 files changed, 383 insertions(+), 123 deletions(-) diff --git a/build_defs/defaults.bzl b/build_defs/defaults.bzl index a18860b0..0e96539d 100644 --- a/build_defs/defaults.bzl +++ b/build_defs/defaults.bzl @@ -77,6 +77,10 @@ THIRD_PARTY_JS_HIGHLIGHTJS = [ "@npm//highlight.js", ] +THIRD_PARTY_JS_SOCKETIO_CLIENT = [ + "@npm//socket.io-client", +] + THIRD_PARTY_PY_ABSL_PY = [ requirement("absl-py"), ] diff --git a/mesop/cli/cli.py b/mesop/cli/cli.py index 1f8f1298..8fb591a1 100644 --- a/mesop/cli/cli.py +++ b/mesop/cli/cli.py @@ -153,7 +153,15 @@ def main(argv: Sequence[str]): log_startup(port=port()) logging.getLogger("werkzeug").setLevel(logging.WARN) - flask_app.run(host=get_public_host(), port=port(), use_reloader=False) + # flask_app.run(host=get_public_host(), port=port(), use_reloader=False) + socketio = flask_app.socketio # type: ignore + socketio.run( + flask_app, + host=get_public_host(), + port=port(), + use_reloader=False, + allow_unsafe_werkzeug=True, + ) if __name__ == "__main__": diff --git a/mesop/server/BUILD b/mesop/server/BUILD index 570f2086..bcd68fdb 100644 --- a/mesop/server/BUILD +++ b/mesop/server/BUILD @@ -4,6 +4,7 @@ load( "THIRD_PARTY_PY_DOTENV", "THIRD_PARTY_PY_FIREBASE_ADMIN", "THIRD_PARTY_PY_FLASK", + "THIRD_PARTY_PY_FLASK_SOCKETIO", "THIRD_PARTY_PY_GREENLET", "THIRD_PARTY_PY_MSGPACK", "THIRD_PARTY_PY_PYTEST", @@ -37,16 +38,18 @@ py_library( "//mesop/utils", "//mesop/warn", ] + THIRD_PARTY_PY_ABSL_PY + - THIRD_PARTY_PY_FLASK, + THIRD_PARTY_PY_FLASK + + THIRD_PARTY_PY_FLASK_SOCKETIO, ) py_library( name = "state_sessions", srcs = STATE_SESSIONS_SRCS, deps = [ - "//mesop/dataclass_utils", - "//mesop/exceptions", - ] + THIRD_PARTY_PY_MSGPACK + THIRD_PARTY_PY_DOTENV, + "//mesop/dataclass_utils", + "//mesop/exceptions", + ] + THIRD_PARTY_PY_MSGPACK + + THIRD_PARTY_PY_DOTENV, ) py_test( diff --git a/mesop/server/server.py b/mesop/server/server.py index 12ba673c..de7bf48e 100644 --- a/mesop/server/server.py +++ b/mesop/server/server.py @@ -25,7 +25,11 @@ def configure_flask_app( - *, prod_mode: bool = True, exceptions_to_propagate: Sequence[type] = () + *, + prod_mode: bool = True, + exceptions_to_propagate: Sequence[type] = (), + # TODO: plumb this from an env var + is_websockets_enabled=True, ) -> Flask: flask_app = Flask(__name__) @@ -244,4 +248,71 @@ def teardown_clear_stale_state_sessions(error=None): if not prod_mode: configure_debug_routes(flask_app) + ### WebSocket Event Handlers ### + if is_websockets_enabled: + from flask_socketio import SocketIO, emit + + socketio = SocketIO( + # TODO: DISABLE CORS!! + flask_app, + cors_allowed_origins="*", + ) # Adjust CORS as needed + + # @socketio.on("connect", namespace="/__ui__") + # def handle_connect(): + # # Handle new WebSocket connection + # print("Client connected via WebSocket") + # emit("response", {"message": "Connected to /__ui__ WebSocket"}) + + # @socketio.on("disconnect", namespace="/__ui__") + # def handle_disconnect(): + # # Handle WebSocket disconnection + # print("Client disconnected from WebSocket") + + @socketio.on("message", namespace="/__ui__") + def handle_message(message): + """ + Handle incoming messages from the client via WebSocket. + Expecting messages to be serialized UiRequest. + """ + try: + if not message: + emit("error", {"error": "Missing request payload"}) + return + + ui_request = pb.UiRequest() + ui_request.ParseFromString(base64.urlsafe_b64decode(message)) + + # Generate the response data + generator = generate_data(ui_request) + for data_chunk in generator: + if data_chunk == STREAM_END: + break + emit("response", {"data": data_chunk}) + + # Optionally, signal the end of the stream + emit("end", {"message": "Stream ended"}) + + except Exception as e: + error_response = pb.ServerError( + exception=str(e), traceback=format_traceback() + ) + if not runtime().debug_mode: + error_response.ClearField("traceback") + if "Mesop Internal Error:" in error_response.exception: + error_response.exception = ( + "Sorry, there was an internal error with Mesop." + ) + if "Mesop Developer Error:" in error_response.exception: + error_response.exception = ( + "Sorry, there was an error. Please contact the developer." + ) + ui_response = pb.UiResponse(error=error_response) + emit("error", {"error": serialize(ui_response)}) + + ### End of WebSocket Event Handlers ### + + # Replace the default Flask run method with SocketIO's run method + flask_app.socketio = socketio + return flask_app diff --git a/mesop/web/src/services/BUILD b/mesop/web/src/services/BUILD index 873e6761..b512cb49 100644 --- a/mesop/web/src/services/BUILD +++ b/mesop/web/src/services/BUILD @@ -1,4 +1,4 @@ -load("//build_defs:defaults.bzl", "ANGULAR_CORE_DEPS", "ANGULAR_MATERIAL_TS_DEPS", "ng_module") +load("//build_defs:defaults.bzl", "ANGULAR_CORE_DEPS", "ANGULAR_MATERIAL_TS_DEPS", "THIRD_PARTY_JS_SOCKETIO_CLIENT", "ng_module") package( default_visibility = ["//build_defs:mesop_internal"], @@ -13,5 +13,5 @@ ng_module( "//mesop/protos:ui_jspb_proto", "//mesop/web/src/dev_tools/services", "//mesop/web/src/utils", - ] + ANGULAR_CORE_DEPS + ANGULAR_MATERIAL_TS_DEPS, + ] + ANGULAR_CORE_DEPS + ANGULAR_MATERIAL_TS_DEPS + THIRD_PARTY_JS_SOCKETIO_CLIENT, ) diff --git a/mesop/web/src/services/channel.ts b/mesop/web/src/services/channel.ts index 0efd5ec2..e37d30de 100644 --- a/mesop/web/src/services/channel.ts +++ b/mesop/web/src/services/channel.ts @@ -20,6 +20,7 @@ import {getViewportSize} from '../utils/viewport_size'; import {ThemeService} from './theme_service'; import {getQueryParams} from '../utils/query_params'; import {ExperimentService} from './experiment_service'; +import {io, Socket} from 'socket.io-client'; // Import Socket.IO client // Pick 500ms as the minimum duration before showing a progress/busy indicator // for the channel. @@ -42,6 +43,11 @@ export enum ChannelStatus { CLOSED = 'CLOSED', } +export enum ConnectionType { + SSE = 'SSE', + WEBSOCKET = 'WEBSOCKET', +} + @Injectable({ providedIn: 'root', }) @@ -50,6 +56,7 @@ export class Channel { private isWaiting = false; private isWaitingTimeout: number | undefined; private eventSource!: SSE; + private socket!: Socket; private initParams!: InitParams; private states: States = new States(); private stateToken = ''; @@ -59,6 +66,7 @@ export class Channel { private queuedEvents: (() => void)[] = []; private hotReloadBackoffCounter = 0; private hotReloadCounter = 0; + private connectionType: ConnectionType = ConnectionType.SSE; // Default to SSE // Client-side state private overridedTitle = ''; @@ -86,7 +94,8 @@ export class Channel { /** * Return true if the channel has been doing work - * triggered by a user that's been taking a while. */ + * triggered by a user that's been taking a while. + */ isBusy(): boolean { return this.isWaiting && !this.isHotReloading(); } @@ -99,7 +108,29 @@ export class Channel { return this.componentConfigs; } - init(initParams: InitParams, request: UiRequest) { + /** + * Initialize the channel with the given parameters and UI request. + * Supports both SSE and WebSocket connections based on the `useWebSocket` flag. + * @param initParams Initialization parameters. + * @param request Initial UI request. + * @param useWebSocket Whether to use WebSocket for the connection. + */ + init(initParams: InitParams, request: UiRequest, useWebSocket = true) { + this.connectionType = useWebSocket + ? ConnectionType.WEBSOCKET + : ConnectionType.SSE; + + if (this.connectionType === ConnectionType.SSE) { + this.initSSE(initParams, request); + } else { + this.initWebSocket(initParams, request); + } + } + + /** + * Initialize Server-Sent Events (SSE) connection. + */ + private initSSE(initParams: InitParams, request: UiRequest) { this.eventSource = new SSE('/__ui__', { payload: generatePayloadString(request), }); @@ -114,7 +145,6 @@ export class Channel { this.initParams = initParams; this.eventSource.addEventListener('message', (e) => { - // Looks like Angular has a bug where it's not intercepting EventSource onmessage. zone.run(() => { const data = (e as any).data; if (data === '') { @@ -133,129 +163,218 @@ export class Channel { const array = toUint8Array(atob(data)); const uiResponse = UiResponse.deserializeBinary(array); - console.debug('Server event: ', uiResponse.toObject()); - switch (uiResponse.getTypeCase()) { - case UiResponse.TypeCase.UPDATE_STATE_EVENT: { - this.stateToken = uiResponse - .getUpdateStateEvent()! - .getStateToken()!; - switch (uiResponse.getUpdateStateEvent()!.getTypeCase()) { - case UpdateStateEvent.TypeCase.FULL_STATES: { - this.states = uiResponse - .getUpdateStateEvent()! - .getFullStates()!; - break; - } - case UpdateStateEvent.TypeCase.DIFF_STATES: { - const states = uiResponse - .getUpdateStateEvent()! - .getDiffStates()!; - - const numDiffStates = states.getStatesList().length; - const numStates = this.states.getStatesList().length; - - if (numDiffStates !== numStates) { - throw Error( - `Number of diffs (${numDiffStates}) doesn't equal the number of states (${numStates}))`, - ); - } - - // `this.states` should be populated at this point since the first update - // from the server should be the full state. - for (let i = 0; i < numDiffStates; ++i) { - const state = applyStateDiff( - this.states.getStatesList()[i].getData() as string, - states.getStatesList()[i].getData() as string, - ); - this.states.getStatesList()[i].setData(state); - } - break; - } - case UpdateStateEvent.TypeCase.TYPE_NOT_SET: - throw new Error('No state event data set'); - } + console.debug('Server event (SSE): ', uiResponse.toObject()); + this.handleUiResponse(uiResponse, onRender, onError, onCommand); + }); + }); + + this.eventSource.addEventListener('error', (e) => { + zone.run(() => { + console.error('SSE connection error:', e); + this.status = ChannelStatus.CLOSED; + this.eventSource.close(); + }); + }); + } + + /** + * Initialize WebSocket connection using Socket.IO. + */ + private initWebSocket(initParams: InitParams, request: UiRequest) { + this.socket = io('/__ui__', { + transports: ['websocket'], + reconnectionAttempts: 5, // Adjust as needed + // You can pass additional options here + }); + + this.status = ChannelStatus.OPEN; + this.isWaitingTimeout = setTimeout(() => { + this.isWaiting = true; + }, WAIT_TIMEOUT_MS); + + this.logger.log({type: 'StreamStart'}); + + const {zone, onRender, onError, onCommand} = initParams; + this.initParams = initParams; + + this.socket.on('connect', () => { + // Send the initial UiRequest upon connection + const payload = generatePayloadString(request); + this.socket.emit('message', payload); + }); + + this.socket.on('response', (data: any) => { + zone.run(() => { + if (data === '') { + this.socket.disconnect(); + this.status = ChannelStatus.CLOSED; + clearTimeout(this.isWaitingTimeout); + this.isWaiting = false; + this._isHotReloading = false; + this.logger.log({type: 'StreamEnd'}); + if (this.queuedEvents.length) { + const queuedEvent = this.queuedEvents.shift()!; + queuedEvent(); + } + return; + } + const prefix = 'data: '; + + const array = toUint8Array(atob(data.data.slice(prefix.length))); + const uiResponse = UiResponse.deserializeBinary(array); + console.debug('Server event (WebSocket): ', uiResponse.toObject()); + this.handleUiResponse(uiResponse, onRender, onError, onCommand); + }); + }); + + this.socket.on('error', (error: any) => { + zone.run(() => { + console.error('WebSocket error:', error); + this.status = ChannelStatus.CLOSED; + }); + }); + + this.socket.on('disconnect', (reason: string) => { + zone.run(() => { + this.status = ChannelStatus.CLOSED; + clearTimeout(this.isWaitingTimeout); + this.isWaiting = false; + this._isHotReloading = false; + }); + }); + } + + /** + * Handle UiResponse from the server. + */ + private handleUiResponse( + uiResponse: UiResponse, + onRender: InitParams['onRender'], + onError: InitParams['onError'], + onCommand: InitParams['onCommand'], + ) { + switch (uiResponse.getTypeCase()) { + case UiResponse.TypeCase.UPDATE_STATE_EVENT: { + this.stateToken = uiResponse.getUpdateStateEvent()!.getStateToken()!; + switch (uiResponse.getUpdateStateEvent()!.getTypeCase()) { + case UpdateStateEvent.TypeCase.FULL_STATES: { + this.states = uiResponse.getUpdateStateEvent()!.getFullStates()!; break; } - case UiResponse.TypeCase.RENDER: { - const rootComponent = uiResponse.getRender()!.getRootComponent()!; - const componentDiff = uiResponse.getRender()!.getComponentDiff()!; + case UpdateStateEvent.TypeCase.DIFF_STATES: { + const states = uiResponse.getUpdateStateEvent()!.getDiffStates()!; - for (const command of uiResponse.getRender()!.getCommandsList()) { - onCommand(command); - } + const numDiffStates = states.getStatesList().length; + const numStates = this.states.getStatesList().length; - const title = - this.overridedTitle || uiResponse.getRender()!.getTitle(); - if (title) { - this.title.setTitle(title); + if (numDiffStates !== numStates) { + throw Error( + `Number of diffs (${numDiffStates}) doesn't equal the number of states (${numStates}))`, + ); } - if ( - componentDiff !== undefined && - this.rootComponent !== undefined - ) { - // Angular does not update the UI if we apply the diff on the root - // component instance which is why we create copy of the root component - // first. - const rootComponentToUpdate = ComponentProto.deserializeBinary( - this.rootComponent.serializeBinary(), + // `this.states` should be populated at this point since the first update + // from the server should be the full state. + for (let i = 0; i < numDiffStates; ++i) { + const state = applyStateDiff( + this.states.getStatesList()[i].getData() as string, + states.getStatesList()[i].getData() as string, ); - applyComponentDiff(rootComponentToUpdate, componentDiff); - this.rootComponent = rootComponentToUpdate; - } else { - this.rootComponent = rootComponent; - } - const experimentSettings = uiResponse - .getRender()! - .getExperimentSettings(); - if (experimentSettings) { - this.experimentService.concurrentUpdatesEnabled = - experimentSettings.getConcurrentUpdatesEnabled() ?? false; - this.experimentService.experimentalEditorToolbarEnabled = - experimentSettings.getExperimentalEditorToolbarEnabled() ?? - false; + this.states.getStatesList()[i].setData(state); } - - this.componentConfigs = uiResponse - .getRender()! - .getComponentConfigsList(); - onRender( - this.rootComponent, - this.componentConfigs, - uiResponse.getRender()!.getJsModulesList(), - ); - this.logger.log({ - type: 'RenderLog', - states: this.states, - rootComponent: this.rootComponent, - }); break; } - case UiResponse.TypeCase.ERROR: - if ( - uiResponse.getError()?.getException() === - 'Token not found in state session backend.' - ) { - this.queuedEvents.unshift(() => { - console.warn( - 'Token not found in state session backend. Retrying user event.', - ); - request.getUserEvent()!.clearStateToken(); - request.getUserEvent()!.setStates(this.states); - this.init(this.initParams, request); - }); - } else { - onError(uiResponse.getError()!); - console.log('error', uiResponse.getError()); - } - break; - case UiResponse.TypeCase.TYPE_NOT_SET: - throw new Error(`Unhandled case for server event: ${uiResponse}`); + case UpdateStateEvent.TypeCase.TYPE_NOT_SET: + throw new Error('No state event data set'); } - }); - }); + break; + } + case UiResponse.TypeCase.RENDER: { + const rootComponent = uiResponse.getRender()!.getRootComponent()!; + const componentDiff = uiResponse.getRender()!.getComponentDiff()!; + + for (const command of uiResponse.getRender()!.getCommandsList()) { + onCommand(command); + } + + const title = this.overridedTitle || uiResponse.getRender()!.getTitle(); + if (title) { + this.title.setTitle(title); + } + + if (componentDiff !== undefined && this.rootComponent !== undefined) { + // Angular does not update the UI if we apply the diff on the root + // component instance which is why we create copy of the root component + // first. + const rootComponentToUpdate = ComponentProto.deserializeBinary( + this.rootComponent.serializeBinary(), + ); + applyComponentDiff(rootComponentToUpdate, componentDiff); + this.rootComponent = rootComponentToUpdate; + } else { + this.rootComponent = rootComponent; + } + const experimentSettings = uiResponse + .getRender()! + .getExperimentSettings(); + if (experimentSettings) { + this.experimentService.concurrentUpdatesEnabled = + experimentSettings.getConcurrentUpdatesEnabled() ?? false; + this.experimentService.experimentalEditorToolbarEnabled = + experimentSettings.getExperimentalEditorToolbarEnabled() ?? false; + } + + this.componentConfigs = uiResponse + .getRender()! + .getComponentConfigsList(); + onRender( + this.rootComponent, + this.componentConfigs, + uiResponse.getRender()!.getJsModulesList(), + ); + this.logger.log({ + type: 'RenderLog', + states: this.states, + rootComponent: this.rootComponent, + }); + break; + } + case UiResponse.TypeCase.ERROR: + if ( + uiResponse.getError()?.getException() === + 'Token not found in state session backend.' + ) { + this.queuedEvents.unshift(() => { + console.warn( + 'Token not found in state session backend. Retrying user event.', + ); + // Assuming you have access to the original request here + // If not, you might need to store it elsewhere + const request = new UiRequest(); + const userEvent = new UserEvent(); + userEvent.clearStateToken(); + userEvent.setStates(this.states); + request.setUserEvent(userEvent); + this.init( + this.initParams, + request, + this.connectionType === ConnectionType.WEBSOCKET, + ); + }); + } else { + onError(uiResponse.getError()!); + console.log('error', uiResponse.getError()); + } + break; + case UiResponse.TypeCase.TYPE_NOT_SET: + throw new Error(`Unhandled case for server event: ${uiResponse}`); + } } + /** + * Dispatch a user event to the server. + * Supports both SSE and WebSocket based on the current connection type. + */ dispatch(userEvent: UserEvent) { // Every user event should have an event handler, // except for the ones below: @@ -283,7 +402,11 @@ export class Channel { const request = new UiRequest(); request.setUserEvent(userEvent); - this.init(this.initParams, request); + this.init( + this.initParams, + request, + this.connectionType === ConnectionType.WEBSOCKET, + ); }; this.logger.log({type: 'UserEventLog', userEvent: userEvent}); @@ -315,6 +438,9 @@ export class Channel { } } + /** + * Check for hot reload by polling the server. + */ checkForHotReload() { const pollHotReloadEndpoint = async () => { try { @@ -354,6 +480,9 @@ export class Channel { this.overridedTitle = newTitle; } + /** + * Trigger a hot reload by sending a navigation event to the server. + */ hotReload() { // Only hot reload if there's no request in-flight. // Most likely the in-flight request will receive the updated UI. @@ -372,10 +501,17 @@ export class Channel { userEvent.setThemeSettings(this.themeService.getThemeSettings()); userEvent.setQueryParamsList(getQueryParams()); request.setUserEvent(userEvent); - this.init(this.initParams, request); + this.init( + this.initParams, + request, + this.connectionType === ConnectionType.WEBSOCKET, + ); } } +/** + * Generate a URL-safe base64 payload string from the UiRequest. + */ function generatePayloadString(request: UiRequest): string { request.setPath(window.location.pathname); const array = request.serializeBinary(); @@ -386,6 +522,9 @@ function generatePayloadString(request: UiRequest): string { return byteString; } +/** + * Convert a Uint8Array to a string. + */ function fromUint8Array(array: Uint8Array): string { // Chunk this to avoid RangeError: Maximum call stack size exceeded let result = ''; @@ -399,6 +538,9 @@ function fromUint8Array(array: Uint8Array): string { return result; } +/** + * Convert a string to a Uint8Array. + */ function toUint8Array(byteString: string): Uint8Array { const byteArray = new Uint8Array(byteString.length); for (let i = 0; i < byteString.length; i++) { diff --git a/package.json b/package.json index 0637b110..3382cfbb 100644 --- a/package.json +++ b/package.json @@ -67,6 +67,7 @@ "highlightjs": "^9.16.2", "rxjs": "^6.6.7", "rxjs-tslint-rules": "^4.34.8", + "socket.io-client": "^4.8.0", "tslib": "^2.3.1", "zone.js": "~0.14.0" }, diff --git a/yarn.lock b/yarn.lock index 3c78a442..c5a314d6 100644 --- a/yarn.lock +++ b/yarn.lock @@ -7943,6 +7943,17 @@ engine.io-client@~6.5.2: ws "~8.11.0" xmlhttprequest-ssl "~2.0.0" +engine.io-client@~6.6.1: + version "6.6.1" + resolved "https://registry.yarnpkg.com/engine.io-client/-/engine.io-client-6.6.1.tgz#28a9cc4e90d448e1d0ba9369ad08a7af82f9956a" + integrity sha512-aYuoak7I+R83M/BBPIOs2to51BmFIpC1wZe6zZzMrT2llVsHy5cvcmdsJgP2Qz6smHu+sD9oexiSUAVd8OfBPw== + dependencies: + "@socket.io/component-emitter" "~3.1.0" + debug "~4.3.1" + engine.io-parser "~5.2.1" + ws "~8.17.1" + xmlhttprequest-ssl "~2.1.1" + engine.io-parser@~2.1.0, engine.io-parser@~2.1.1: version "2.1.3" resolved "https://registry.yarnpkg.com/engine.io-parser/-/engine.io-parser-2.1.3.tgz#757ab970fbf2dfb32c7b74b033216d5739ef79a6" @@ -15226,6 +15237,16 @@ socket.io-client@^4.4.1: engine.io-client "~6.5.2" socket.io-parser "~4.2.4" +socket.io-client@^4.8.0: + version "4.8.0" + resolved "https://registry.yarnpkg.com/socket.io-client/-/socket.io-client-4.8.0.tgz#2ea0302d0032d23422bd2860f78127a800cad6a2" + integrity sha512-C0jdhD5yQahMws9alf/yvtsMGTaIDBnZ8Rb5HU56svyq0l5LIrGzIDZZD5pHQlmzxLuU91Gz+VpQMKgCTNYtkw== + dependencies: + "@socket.io/component-emitter" "~3.1.0" + debug "~4.3.2" + engine.io-client "~6.6.1" + socket.io-parser "~4.2.4" + socket.io-parser@~3.2.0: version "3.2.0" resolved "https://registry.yarnpkg.com/socket.io-parser/-/socket.io-parser-3.2.0.tgz#e7c6228b6aa1f814e6148aea325b51aa9499e077" @@ -17174,6 +17195,11 @@ ws@~8.11.0: resolved "https://registry.yarnpkg.com/ws/-/ws-8.11.0.tgz#6a0d36b8edfd9f96d8b25683db2f8d7de6e8e143" integrity sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg== +ws@~8.17.1: + version "8.17.1" + resolved "https://registry.yarnpkg.com/ws/-/ws-8.17.1.tgz#9293da530bb548febc95371d90f9c878727d919b" + integrity sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ== + x-is-string@^0.1.0: version "0.1.0" resolved "https://registry.yarnpkg.com/x-is-string/-/x-is-string-0.1.0.tgz#474b50865af3a49a9c4657f05acd145458f77d82" @@ -17217,6 +17243,11 @@ xmlhttprequest-ssl@~2.0.0: resolved "https://registry.yarnpkg.com/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.0.0.tgz#91360c86b914e67f44dce769180027c0da618c67" integrity sha512-QKxVRxiRACQcVuQEYFsI1hhkrMlrXHPegbbd1yn9UHOmRxY+si12nQYzri3vbzt8VdTTRviqcKxcyllFas5z2A== +xmlhttprequest-ssl@~2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.1.1.tgz#0d045c3b2babad8e7db1af5af093f5d0d60df99a" + integrity sha512-ptjR8YSJIXoA3Mbv5po7RtSYHO6mZr8s7i5VGmEk7QY2pQWyT1o0N+W1gKbOyJPUCGXGnuw0wqe8f0L6Y0ny7g== + xtend@^4.0.0, xtend@^4.0.1: version "4.0.2" resolved "https://registry.yarnpkg.com/xtend/-/xtend-4.0.2.tgz#bb72779f5fa465186b1f438f674fa347fdb5db54" From c2f78327134ff1b8e84ce2c404360a7ad37e1b88 Mon Sep 17 00:00:00 2001 From: Will Chen Date: Mon, 14 Oct 2024 13:43:28 -0700 Subject: [PATCH 6/9] up --- .github/workflows/ci.yml | 8 ++ .vscode/settings.json | 3 +- mesop/cli/cli.py | 21 +-- mesop/components/input/e2e/input_test.ts | 3 + mesop/examples/e2e/chat_test.ts | 1 + mesop/protos/ui.proto | 1 + mesop/server/server.py | 7 +- mesop/server/server_utils.py | 14 +- mesop/server/wsgi_app.py | 10 ++ mesop/web/src/services/channel.ts | 143 +++++++------------ mesop/web/src/services/experiment_service.ts | 1 + scripts/cli_prod.sh | 1 + 12 files changed, 106 insertions(+), 107 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1df5fc54..9e4b9568 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,6 +72,14 @@ jobs: name: playwright-report-with-concurrent-updates-enabled path: playwright-report-with-concurrent-updates-enabled/ retention-days: 30 + - name: Run playwright test with websockets enabled + run: MESOP_WEBSOCKETES_ENABLED=true PLAYWRIGHT_HTML_OUTPUT_DIR=playwright-report-with-websockets-enabled yarn playwright test + - uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # v3.1.3 + if: always() + with: + name: playwright-report-with-websockets-enabled + path: playwright-report-with-websockets-enabled/ + retention-days: 30 - name: Run playwright test with memory state session run: MESOP_STATE_SESSION_BACKEND=memory PLAYWRIGHT_HTML_OUTPUT_DIR=playwright-report-with-memory-state-session yarn playwright test - uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # v3.1.3 diff --git a/.vscode/settings.json b/.vscode/settings.json index a19739ba..46baa90e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -23,5 +23,6 @@ "bazel-out/**": true, "bazel-mesop/**": true }, - "python.analysis.extraPaths": ["./bazel-bin"] + "python.analysis.extraPaths": ["./bazel-bin"], + "typescript.tsdk": "node_modules/typescript/lib" } diff --git a/mesop/cli/cli.py b/mesop/cli/cli.py index 8fb591a1..9d061516 100644 --- a/mesop/cli/cli.py +++ b/mesop/cli/cli.py @@ -22,6 +22,7 @@ from mesop.server.flags import port from mesop.server.logging import log_startup from mesop.server.server import configure_flask_app +from mesop.server.server_utils import MESOP_WEBSOCKETS_ENABLED from mesop.server.static_file_serving import configure_static_file_serving from mesop.utils.host_util import get_public_host from mesop.utils.runfiles import get_runfile_location @@ -153,15 +154,17 @@ def main(argv: Sequence[str]): log_startup(port=port()) logging.getLogger("werkzeug").setLevel(logging.WARN) - # flask_app.run(host=get_public_host(), port=port(), use_reloader=False) - socketio = flask_app.socketio # type: ignore - socketio.run( - flask_app, - host=get_public_host(), - port=port(), - use_reloader=False, - allow_unsafe_werkzeug=True, - ) + if MESOP_WEBSOCKETS_ENABLED: + socketio = flask_app.socketio # type: ignore + socketio.run( + flask_app, + host=get_public_host(), + port=port(), + use_reloader=False, + allow_unsafe_werkzeug=True, + ) + else: + flask_app.run(host=get_public_host(), port=port(), use_reloader=False) if __name__ == "__main__": diff --git a/mesop/components/input/e2e/input_test.ts b/mesop/components/input/e2e/input_test.ts index 3e01e4d1..e9552517 100644 --- a/mesop/components/input/e2e/input_test.ts +++ b/mesop/components/input/e2e/input_test.ts @@ -19,6 +19,7 @@ test('test input on_blur works', async ({page}) => { // Fill in input and then click button and make sure values match await page.getByLabel('Input').click(); await page.getByLabel('Input').fill('abc'); + await page.getByLabel('Input').blur(); await page.getByRole('button', {name: 'button'}).click(); await expect(page.getByText('Input: abc')).toBeVisible(); await expect( @@ -28,6 +29,7 @@ test('test input on_blur works', async ({page}) => { // Same with textarea: await page.getByLabel('Regular textarea').click(); await page.getByLabel('Regular textarea').fill('123'); + await page.getByLabel('Regular textarea').blur(); await page.getByRole('button', {name: 'button'}).click(); await expect(page.getByText('Input: 123')).toBeVisible(); await expect( @@ -37,6 +39,7 @@ test('test input on_blur works', async ({page}) => { // Same with native textarea: await page.getByRole('textbox').nth(2).click(); await page.getByRole('textbox').nth(2).fill('second_textarea'); + await page.getByRole('textbox').nth(2).blur(); await page.getByRole('button', {name: 'button'}).click(); await expect(page.getByText('Input: second_textarea')).toBeVisible(); await expect( diff --git a/mesop/examples/e2e/chat_test.ts b/mesop/examples/e2e/chat_test.ts index 5563a057..772c9925 100644 --- a/mesop/examples/e2e/chat_test.ts +++ b/mesop/examples/e2e/chat_test.ts @@ -8,6 +8,7 @@ test('Chat UI can send messages and display responses', async ({page}) => { // Test that we can send a message. await page.locator('//input').fill('Lorem ipsum'); + await page.locator('//input').blur(); // Need to wait for the input state to be saved before clicking. await page.waitForTimeout(2000); await page.getByRole('button').filter({hasText: 'send'}).click(); diff --git a/mesop/protos/ui.proto b/mesop/protos/ui.proto index d0e33a69..3ec822d6 100644 --- a/mesop/protos/ui.proto +++ b/mesop/protos/ui.proto @@ -179,6 +179,7 @@ message RenderEvent { message ExperimentSettings { optional bool experimental_editor_toolbar_enabled = 1; optional bool concurrent_updates_enabled = 2; + optional bool websockets_enabled = 3; } // UI response event for updating state. diff --git a/mesop/server/server.py b/mesop/server/server.py index de7bf48e..33be6dba 100644 --- a/mesop/server/server.py +++ b/mesop/server/server.py @@ -14,6 +14,7 @@ from mesop.server.server_utils import ( EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED, MESOP_CONCURRENT_UPDATES_ENABLED, + MESOP_WEBSOCKETS_ENABLED, STREAM_END, create_update_state_event, is_same_site, @@ -73,6 +74,7 @@ def render_loop( for js_module in js_modules ], experiment_settings=pb.ExperimentSettings( + websockets_enabled=MESOP_WEBSOCKETS_ENABLED, concurrent_updates_enabled=MESOP_CONCURRENT_UPDATES_ENABLED, experimental_editor_toolbar_enabled=EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED, ) @@ -286,12 +288,9 @@ def handle_message(message): # Generate the response data generator = generate_data(ui_request) for data_chunk in generator: + emit("response", {"data": data_chunk}) if data_chunk == STREAM_END: break - emit("response", {"data": data_chunk}) - - # Optionally, signal the end of the stream - emit("end", {"message": "Stream ended"}) except Exception as e: error_response = pb.ServerError( diff --git a/mesop/server/server_utils.py b/mesop/server/server_utils.py index e2c49a28..8ef5184b 100644 --- a/mesop/server/server_utils.py +++ b/mesop/server/server_utils.py @@ -16,17 +16,25 @@ "MESOP_AI_SERVICE_BASE_URL", "http://localhost:43234" ) +MESOP_WEBSOCKETS_ENABLED = ( + os.environ.get("MESOP_WEBSOCKETS_ENABLED", "false").lower() == "true" +) + MESOP_CONCURRENT_UPDATES_ENABLED = ( os.environ.get("MESOP_CONCURRENT_UPDATES_ENABLED", "false").lower() == "true" ) +if MESOP_WEBSOCKETS_ENABLED: + print("Experiment enabled: MESOP_WEBSOCKETS_ENABLED") + print("Auto-enabling MESOP_CONCURRENT_UPDATES_ENABLED") + MESOP_CONCURRENT_UPDATES_ENABLED = True +elif MESOP_CONCURRENT_UPDATES_ENABLED: + print("Experiment enabled: MESOP_CONCURRENT_UPDATES_ENABLED") + EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED = ( os.environ.get("MESOP_EXPERIMENTAL_EDITOR_TOOLBAR", "false").lower() == "true" ) -if MESOP_CONCURRENT_UPDATES_ENABLED: - print("Experiment enabled: MESOP_CONCURRENT_UPDATES_ENABLED") - if EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED: print("Experiment enabled: EXPERIMENTAL_EDITOR_TOOLBAR_ENABLED") diff --git a/mesop/server/wsgi_app.py b/mesop/server/wsgi_app.py index 884ae10a..a6e40017 100644 --- a/mesop/server/wsgi_app.py +++ b/mesop/server/wsgi_app.py @@ -9,6 +9,7 @@ from mesop.server.flags import port from mesop.server.logging import log_startup from mesop.server.server import configure_flask_app +from mesop.server.server_utils import MESOP_WEBSOCKETS_ENABLED from mesop.server.static_file_serving import configure_static_file_serving from mesop.utils.host_util import get_local_host @@ -21,6 +22,15 @@ def __init__(self, flask_app: Flask): def run(self): log_startup(port=port()) + if MESOP_WEBSOCKETS_ENABLED: + socketio = self._flask_app.socketio # type: ignore + socketio.run( + self._flask_app, + host=get_local_host(), + port=port(), + use_reloader=False, + allow_unsafe_werkzeug=True, + ) self._flask_app.run(host=get_local_host(), port=port(), use_reloader=False) diff --git a/mesop/web/src/services/channel.ts b/mesop/web/src/services/channel.ts index e37d30de..bd2d8b90 100644 --- a/mesop/web/src/services/channel.ts +++ b/mesop/web/src/services/channel.ts @@ -43,11 +43,6 @@ export enum ChannelStatus { CLOSED = 'CLOSED', } -export enum ConnectionType { - SSE = 'SSE', - WEBSOCKET = 'WEBSOCKET', -} - @Injectable({ providedIn: 'root', }) @@ -56,7 +51,7 @@ export class Channel { private isWaiting = false; private isWaitingTimeout: number | undefined; private eventSource!: SSE; - private socket!: Socket; + private socket: Socket | undefined; private initParams!: InitParams; private states: States = new States(); private stateToken = ''; @@ -66,7 +61,6 @@ export class Channel { private queuedEvents: (() => void)[] = []; private hotReloadBackoffCounter = 0; private hotReloadCounter = 0; - private connectionType: ConnectionType = ConnectionType.SSE; // Default to SSE // Client-side state private overridedTitle = ''; @@ -97,6 +91,13 @@ export class Channel { * triggered by a user that's been taking a while. */ isBusy(): boolean { + if (this.experimentService.websocketsEnabled) { + // When WebSockets are enabled, we disable the busy indicator + // because it's possible for the server to push new data + // at any point. Apps should use their own loading indicators + // instead. + return false; + } return this.isWaiting && !this.isHotReloading(); } @@ -108,28 +109,16 @@ export class Channel { return this.componentConfigs; } - /** - * Initialize the channel with the given parameters and UI request. - * Supports both SSE and WebSocket connections based on the `useWebSocket` flag. - * @param initParams Initialization parameters. - * @param request Initial UI request. - * @param useWebSocket Whether to use WebSocket for the connection. - */ - init(initParams: InitParams, request: UiRequest, useWebSocket = true) { - this.connectionType = useWebSocket - ? ConnectionType.WEBSOCKET - : ConnectionType.SSE; + init(initParams: InitParams, request: UiRequest) { + console.debug('sending UI request', request); - if (this.connectionType === ConnectionType.SSE) { - this.initSSE(initParams, request); - } else { + if (this.experimentService.websocketsEnabled) { this.initWebSocket(initParams, request); + } else { + this.initSSE(initParams, request); } } - /** - * Initialize Server-Sent Events (SSE) connection. - */ private initSSE(initParams: InitParams, request: UiRequest) { this.eventSource = new SSE('/__ui__', { payload: generatePayloadString(request), @@ -145,6 +134,7 @@ export class Channel { this.initParams = initParams; this.eventSource.addEventListener('message', (e) => { + // Looks like Angular has a bug where it's not intercepting EventSource onmessage. zone.run(() => { const data = (e as any).data; if (data === '') { @@ -164,15 +154,13 @@ export class Channel { const array = toUint8Array(atob(data)); const uiResponse = UiResponse.deserializeBinary(array); console.debug('Server event (SSE): ', uiResponse.toObject()); - this.handleUiResponse(uiResponse, onRender, onError, onCommand); - }); - }); - - this.eventSource.addEventListener('error', (e) => { - zone.run(() => { - console.error('SSE connection error:', e); - this.status = ChannelStatus.CLOSED; - this.eventSource.close(); + this.handleUiResponse( + request, + uiResponse, + onRender, + onError, + onCommand, + ); }); }); } @@ -181,10 +169,15 @@ export class Channel { * Initialize WebSocket connection using Socket.IO. */ private initWebSocket(initParams: InitParams, request: UiRequest) { + if (this.socket) { + this.status = ChannelStatus.OPEN; + const payload = generatePayloadString(request); + this.socket.emit('message', payload); + return; + } this.socket = io('/__ui__', { transports: ['websocket'], - reconnectionAttempts: 5, // Adjust as needed - // You can pass additional options here + reconnectionAttempts: 3, }); this.status = ChannelStatus.OPEN; @@ -200,17 +193,16 @@ export class Channel { this.socket.on('connect', () => { // Send the initial UiRequest upon connection const payload = generatePayloadString(request); - this.socket.emit('message', payload); + this.socket!.emit('message', payload); }); this.socket.on('response', (data: any) => { + const prefix = 'data: '; + const payloadData = (data.data.slice(prefix.length) as string).trimEnd(); zone.run(() => { - if (data === '') { - this.socket.disconnect(); - this.status = ChannelStatus.CLOSED; - clearTimeout(this.isWaitingTimeout); - this.isWaiting = false; + if (payloadData === '') { this._isHotReloading = false; + this.status = ChannelStatus.CLOSED; this.logger.log({type: 'StreamEnd'}); if (this.queuedEvents.length) { const queuedEvent = this.queuedEvents.shift()!; @@ -218,17 +210,23 @@ export class Channel { } return; } - const prefix = 'data: '; - const array = toUint8Array(atob(data.data.slice(prefix.length))); + const array = toUint8Array(atob(payloadData)); const uiResponse = UiResponse.deserializeBinary(array); console.debug('Server event (WebSocket): ', uiResponse.toObject()); - this.handleUiResponse(uiResponse, onRender, onError, onCommand); + this.handleUiResponse( + request, + uiResponse, + onRender, + onError, + onCommand, + ); }); }); this.socket.on('error', (error: any) => { zone.run(() => { + this.socket = undefined; console.error('WebSocket error:', error); this.status = ChannelStatus.CLOSED; }); @@ -236,9 +234,8 @@ export class Channel { this.socket.on('disconnect', (reason: string) => { zone.run(() => { + this.socket = undefined; this.status = ChannelStatus.CLOSED; - clearTimeout(this.isWaitingTimeout); - this.isWaiting = false; this._isHotReloading = false; }); }); @@ -248,6 +245,7 @@ export class Channel { * Handle UiResponse from the server. */ private handleUiResponse( + request: UiRequest, uiResponse: UiResponse, onRender: InitParams['onRender'], onError: InitParams['onError'], @@ -318,6 +316,8 @@ export class Channel { .getRender()! .getExperimentSettings(); if (experimentSettings) { + this.experimentService.websocketsEnabled = + experimentSettings.getWebsocketsEnabled() ?? false; this.experimentService.concurrentUpdatesEnabled = experimentSettings.getConcurrentUpdatesEnabled() ?? false; this.experimentService.experimentalEditorToolbarEnabled = @@ -348,22 +348,13 @@ export class Channel { console.warn( 'Token not found in state session backend. Retrying user event.', ); - // Assuming you have access to the original request here - // If not, you might need to store it elsewhere - const request = new UiRequest(); - const userEvent = new UserEvent(); - userEvent.clearStateToken(); - userEvent.setStates(this.states); - request.setUserEvent(userEvent); - this.init( - this.initParams, - request, - this.connectionType === ConnectionType.WEBSOCKET, - ); + request.getUserEvent()!.clearStateToken(); + request.getUserEvent()!.setStates(this.states); + this.init(this.initParams, request); }); } else { onError(uiResponse.getError()!); - console.log('error', uiResponse.getError()); + console.error('error', uiResponse.getError()); } break; case UiResponse.TypeCase.TYPE_NOT_SET: @@ -371,10 +362,6 @@ export class Channel { } } - /** - * Dispatch a user event to the server. - * Supports both SSE and WebSocket based on the current connection type. - */ dispatch(userEvent: UserEvent) { // Every user event should have an event handler, // except for the ones below: @@ -402,14 +389,9 @@ export class Channel { const request = new UiRequest(); request.setUserEvent(userEvent); - this.init( - this.initParams, - request, - this.connectionType === ConnectionType.WEBSOCKET, - ); + this.init(this.initParams, request); }; this.logger.log({type: 'UserEventLog', userEvent: userEvent}); - if (this.status === ChannelStatus.CLOSED) { initUserEvent(); } else { @@ -433,14 +415,11 @@ export class Channel { )[0]; initUserEvent(); } - }, 1000); + }, 500); } } } - /** - * Check for hot reload by polling the server. - */ checkForHotReload() { const pollHotReloadEndpoint = async () => { try { @@ -480,9 +459,6 @@ export class Channel { this.overridedTitle = newTitle; } - /** - * Trigger a hot reload by sending a navigation event to the server. - */ hotReload() { // Only hot reload if there's no request in-flight. // Most likely the in-flight request will receive the updated UI. @@ -501,17 +477,10 @@ export class Channel { userEvent.setThemeSettings(this.themeService.getThemeSettings()); userEvent.setQueryParamsList(getQueryParams()); request.setUserEvent(userEvent); - this.init( - this.initParams, - request, - this.connectionType === ConnectionType.WEBSOCKET, - ); + this.init(this.initParams, request); } } -/** - * Generate a URL-safe base64 payload string from the UiRequest. - */ function generatePayloadString(request: UiRequest): string { request.setPath(window.location.pathname); const array = request.serializeBinary(); @@ -522,9 +491,6 @@ function generatePayloadString(request: UiRequest): string { return byteString; } -/** - * Convert a Uint8Array to a string. - */ function fromUint8Array(array: Uint8Array): string { // Chunk this to avoid RangeError: Maximum call stack size exceeded let result = ''; @@ -538,9 +504,6 @@ function fromUint8Array(array: Uint8Array): string { return result; } -/** - * Convert a string to a Uint8Array. - */ function toUint8Array(byteString: string): Uint8Array { const byteArray = new Uint8Array(byteString.length); for (let i = 0; i < byteString.length; i++) { diff --git a/mesop/web/src/services/experiment_service.ts b/mesop/web/src/services/experiment_service.ts index 57b1bf88..1c67afe1 100644 --- a/mesop/web/src/services/experiment_service.ts +++ b/mesop/web/src/services/experiment_service.ts @@ -6,4 +6,5 @@ import {Injectable} from '@angular/core'; export class ExperimentService { concurrentUpdatesEnabled = false; experimentalEditorToolbarEnabled = false; + websocketsEnabled = false; } diff --git a/scripts/cli_prod.sh b/scripts/cli_prod.sh index a9c5f3ec..ad4fff57 100755 --- a/scripts/cli_prod.sh +++ b/scripts/cli_prod.sh @@ -1,2 +1,3 @@ (lsof -t -i:32123 | xargs kill) || true && \ +MESOP_CONCURRENT_UPDATES_ENABLED=true \ bazel run //mesop/cli -- --path=mesop/mesop/example_index.py --prod From 808115aebda457bf3d9aad7bf31b066cde8f7d9d Mon Sep 17 00:00:00 2001 From: Will Chen Date: Mon, 14 Oct 2024 14:00:23 -0700 Subject: [PATCH 7/9] up2 --- mesop/server/server.py | 65 ++++++------------------------- mesop/web/src/services/channel.ts | 24 +++--------- 2 files changed, 17 insertions(+), 72 deletions(-) diff --git a/mesop/server/server.py b/mesop/server/server.py index 33be6dba..c6c60fb5 100644 --- a/mesop/server/server.py +++ b/mesop/server/server.py @@ -250,68 +250,25 @@ def teardown_clear_stale_state_sessions(error=None): if not prod_mode: configure_debug_routes(flask_app) - ### WebSocket Event Handlers ### + if is_websockets_enabled: from flask_socketio import SocketIO, emit - socketio = SocketIO( - # TODO: DISABLE CORS!! - flask_app, - cors_allowed_origins="*", - ) # Adjust CORS as needed - - # @socketio.on("connect", namespace="/__ui__") - # def handle_connect(): - # # Handle new WebSocket connection - # print("Client connected via WebSocket") - # emit("response", {"message": "Connected to /__ui__ WebSocket"}) - - # @socketio.on("disconnect", namespace="/__ui__") - # def handle_disconnect(): - # # Handle WebSocket disconnection - # print("Client disconnected from WebSocket") + socketio = SocketIO(flask_app) @socketio.on("message", namespace="/__ui__") def handle_message(message): - """ - Handle incoming messages from the client via WebSocket. - Expecting messages to be serialized UiRequest. - """ - try: - if not message: - emit("error", {"error": "Missing request payload"}) - return - - ui_request = pb.UiRequest() - ui_request.ParseFromString(base64.urlsafe_b64decode(message)) + if not message: + emit("error", {"error": "Missing request payload"}) + return - # Generate the response data - generator = generate_data(ui_request) - for data_chunk in generator: - emit("response", {"data": data_chunk}) - if data_chunk == STREAM_END: - break - - except Exception as e: - error_response = pb.ServerError( - exception=str(e), traceback=format_traceback() - ) - if not runtime().debug_mode: - error_response.ClearField("traceback") - if "Mesop Internal Error:" in error_response.exception: - error_response.exception = ( - "Sorry, there was an internal error with Mesop." - ) - if "Mesop Developer Error:" in error_response.exception: - error_response.exception = ( - "Sorry, there was an error. Please contact the developer." - ) - ui_response = pb.UiResponse(error=error_response) - emit("error", {"error": serialize(ui_response)}) + ui_request = pb.UiRequest() + ui_request.ParseFromString(base64.urlsafe_b64decode(message)) - ### End of WebSocket Event Handlers ### + generator = generate_data(ui_request) + for data_chunk in generator: + emit("response", {"data": data_chunk}) - # Replace the default Flask run method with SocketIO's run method - flask_app.socketio = socketio + flask_app.socketio = socketio # type: ignore return flask_app diff --git a/mesop/web/src/services/channel.ts b/mesop/web/src/services/channel.ts index bd2d8b90..65b175e4 100644 --- a/mesop/web/src/services/channel.ts +++ b/mesop/web/src/services/channel.ts @@ -154,13 +154,7 @@ export class Channel { const array = toUint8Array(atob(data)); const uiResponse = UiResponse.deserializeBinary(array); console.debug('Server event (SSE): ', uiResponse.toObject()); - this.handleUiResponse( - request, - uiResponse, - onRender, - onError, - onCommand, - ); + this.handleUiResponse(request, uiResponse, initParams); }); }); } @@ -214,13 +208,7 @@ export class Channel { const array = toUint8Array(atob(payloadData)); const uiResponse = UiResponse.deserializeBinary(array); console.debug('Server event (WebSocket): ', uiResponse.toObject()); - this.handleUiResponse( - request, - uiResponse, - onRender, - onError, - onCommand, - ); + this.handleUiResponse(request, uiResponse, initParams); }); }); @@ -235,8 +223,8 @@ export class Channel { this.socket.on('disconnect', (reason: string) => { zone.run(() => { this.socket = undefined; + console.error('WebSocket disconnected:', reason); this.status = ChannelStatus.CLOSED; - this._isHotReloading = false; }); }); } @@ -247,10 +235,9 @@ export class Channel { private handleUiResponse( request: UiRequest, uiResponse: UiResponse, - onRender: InitParams['onRender'], - onError: InitParams['onError'], - onCommand: InitParams['onCommand'], + initParams: InitParams, ) { + const {onRender, onError, onCommand} = initParams; switch (uiResponse.getTypeCase()) { case UiResponse.TypeCase.UPDATE_STATE_EVENT: { this.stateToken = uiResponse.getUpdateStateEvent()!.getStateToken()!; @@ -392,6 +379,7 @@ export class Channel { this.init(this.initParams, request); }; this.logger.log({type: 'UserEventLog', userEvent: userEvent}); + if (this.status === ChannelStatus.CLOSED) { initUserEvent(); } else { From 24cfdf1a832b755b2750e13343f8824e65f9140f Mon Sep 17 00:00:00 2001 From: Will Chen Date: Mon, 14 Oct 2024 14:59:39 -0700 Subject: [PATCH 8/9] minor tweaks --- mesop/server/server.py | 16 +++++++--------- mesop/server/wsgi_app.py | 5 ++++- mesop/web/src/services/channel.ts | 31 ++++++++++++++----------------- 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/mesop/server/server.py b/mesop/server/server.py index c6c60fb5..e00ea530 100644 --- a/mesop/server/server.py +++ b/mesop/server/server.py @@ -24,13 +24,11 @@ from mesop.utils.url_utils import remove_url_query_param from mesop.warn import warn +UI_PATH = "/__ui__" + def configure_flask_app( - *, - prod_mode: bool = True, - exceptions_to_propagate: Sequence[type] = (), - # TODO: plumb this from an env var - is_websockets_enabled=True, + *, prod_mode: bool = True, exceptions_to_propagate: Sequence[type] = () ) -> Flask: flask_app = Flask(__name__) @@ -224,7 +222,7 @@ def generate_data(ui_request: pb.UiRequest) -> Generator[str, None, None]: error=pb.ServerError(exception=str(e), traceback=format_traceback()) ) - @flask_app.route("/__ui__", methods=["POST"]) + @flask_app.route(UI_PATH, methods=["POST"]) def ui_stream() -> Response: # Prevent CSRF by checking the request site matches the site # of the URL root (where the Flask app is being served from) @@ -234,7 +232,7 @@ def ui_stream() -> Response: if not runtime().debug_mode and not is_same_site( request.headers.get("Origin"), request.url_root ): - abort(403, "Rejecting cross-site POST request to /__ui__") + abort(403, "Rejecting cross-site POST request to " + UI_PATH) data = request.data if not data: raise Exception("Missing request payload") @@ -251,12 +249,12 @@ def teardown_clear_stale_state_sessions(error=None): if not prod_mode: configure_debug_routes(flask_app) - if is_websockets_enabled: + if MESOP_WEBSOCKETS_ENABLED: from flask_socketio import SocketIO, emit socketio = SocketIO(flask_app) - @socketio.on("message", namespace="/__ui__") + @socketio.on("message", namespace=UI_PATH) def handle_message(message): if not message: emit("error", {"error": "Missing request payload"}) diff --git a/mesop/server/wsgi_app.py b/mesop/server/wsgi_app.py index a6e40017..dc5fc1f4 100644 --- a/mesop/server/wsgi_app.py +++ b/mesop/server/wsgi_app.py @@ -31,7 +31,10 @@ def run(self): use_reloader=False, allow_unsafe_werkzeug=True, ) - self._flask_app.run(host=get_local_host(), port=port(), use_reloader=False) + else: + self._flask_app.run( + host=get_local_host(), port=port(), use_reloader=False + ) def create_app( diff --git a/mesop/web/src/services/channel.ts b/mesop/web/src/services/channel.ts index 65b175e4..64d915d5 100644 --- a/mesop/web/src/services/channel.ts +++ b/mesop/web/src/services/channel.ts @@ -20,7 +20,9 @@ import {getViewportSize} from '../utils/viewport_size'; import {ThemeService} from './theme_service'; import {getQueryParams} from '../utils/query_params'; import {ExperimentService} from './experiment_service'; -import {io, Socket} from 'socket.io-client'; // Import Socket.IO client +import {io, Socket} from 'socket.io-client'; + +const STREAM_END = ''; // Pick 500ms as the minimum duration before showing a progress/busy indicator // for the channel. @@ -137,17 +139,14 @@ export class Channel { // Looks like Angular has a bug where it's not intercepting EventSource onmessage. zone.run(() => { const data = (e as any).data; - if (data === '') { + if (data === STREAM_END) { this.eventSource.close(); this.status = ChannelStatus.CLOSED; clearTimeout(this.isWaitingTimeout); this.isWaiting = false; this._isHotReloading = false; this.logger.log({type: 'StreamEnd'}); - if (this.queuedEvents.length) { - const queuedEvent = this.queuedEvents.shift()!; - queuedEvent(); - } + this.dequeueEvent(); return; } @@ -159,9 +158,6 @@ export class Channel { }); } - /** - * Initialize WebSocket connection using Socket.IO. - */ private initWebSocket(initParams: InitParams, request: UiRequest) { if (this.socket) { this.status = ChannelStatus.OPEN; @@ -194,14 +190,11 @@ export class Channel { const prefix = 'data: '; const payloadData = (data.data.slice(prefix.length) as string).trimEnd(); zone.run(() => { - if (payloadData === '') { + if (payloadData === STREAM_END) { this._isHotReloading = false; this.status = ChannelStatus.CLOSED; this.logger.log({type: 'StreamEnd'}); - if (this.queuedEvents.length) { - const queuedEvent = this.queuedEvents.shift()!; - queuedEvent(); - } + this.dequeueEvent(); return; } @@ -229,9 +222,6 @@ export class Channel { }); } - /** - * Handle UiResponse from the server. - */ private handleUiResponse( request: UiRequest, uiResponse: UiResponse, @@ -408,6 +398,13 @@ export class Channel { } } + private dequeueEvent() { + if (this.queuedEvents.length) { + const queuedEvent = this.queuedEvents.shift()!; + queuedEvent(); + } + } + checkForHotReload() { const pollHotReloadEndpoint = async () => { try { From 6b3fe5ca5180ae87dd592639975e294a6d018839 Mon Sep 17 00:00:00 2001 From: Will Chen Date: Tue, 15 Oct 2024 21:37:07 -0700 Subject: [PATCH 9/9] update ci --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9e4b9568..3519c0bb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -73,7 +73,7 @@ jobs: path: playwright-report-with-concurrent-updates-enabled/ retention-days: 30 - name: Run playwright test with websockets enabled - run: MESOP_WEBSOCKETES_ENABLED=true PLAYWRIGHT_HTML_OUTPUT_DIR=playwright-report-with-websockets-enabled yarn playwright test + run: MESOP_WEBSOCKETS_ENABLED=true PLAYWRIGHT_HTML_OUTPUT_DIR=playwright-report-with-websockets-enabled yarn playwright test - uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # v3.1.3 if: always() with: