From 3a32d1ba16688c38c483b11ac9d66679c3c84d04 Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 9 Apr 2024 00:12:15 +0100 Subject: [PATCH] Implement hot reloading with websockets --- README.rst | 8 +--- pyproject.toml | 6 ++- sphinx_autobuild/__init__.py | 2 +- sphinx_autobuild/__main__.py | 29 ++++++++----- sphinx_autobuild/_hacks.py | 28 ------------ sphinx_autobuild/middleware.py | 44 +++++++++++++++++++ sphinx_autobuild/server.py | 78 ++++++++++++++++++++++++++++++++++ 7 files changed, 148 insertions(+), 47 deletions(-) delete mode 100644 sphinx_autobuild/_hacks.py create mode 100644 sphinx_autobuild/middleware.py create mode 100644 sphinx_autobuild/server.py diff --git a/README.rst b/README.rst index 8d4da31..f39ae65 100644 --- a/README.rst +++ b/README.rst @@ -10,7 +10,7 @@ sphinx-autobuild :target: https://opensource.org/licenses/MIT :alt: MIT -Rebuild Sphinx documentation on changes, with live-reload in the browser. +Rebuild Sphinx documentation on changes, with hot reloading in the browser. .. image:: ./docs/_static/demo.png :align: center @@ -167,16 +167,12 @@ __ https://github.com/sphinx-doc/sphinx-autobuild/issues/34 Acknowledgements ================ -This project stands on the shoulders of giants like -Sphinx_, LiveReload_ and python-livereload_, +This project stands on the shoulders of giants, without whom this project would not be possible. Many thanks to everyone who has `contributed code`_ as well as participated in `discussions on the issue tracker`_. This project is better thanks to your contribution. -.. _Sphinx: https://sphinx-doc.org/ -.. _LiveReload: https://livereload.com/ -.. _python-livereload: https://github.com/lepture/python-livereload .. _contributed code: https://github.com/sphinx-doc/sphinx-autobuild/graphs/contributors .. _discussions on the issue tracker: https://github.com/sphinx-doc/sphinx-autobuild/issues diff --git a/pyproject.toml b/pyproject.toml index 5123a04..0d71bcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "flit_core.buildapi" # project metadata [project] name = "sphinx-autobuild" -description = "Rebuild Sphinx documentation on changes, with live-reload in the browser." +description = "Rebuild Sphinx documentation on changes, with hot reloading in the browser." readme = "README.rst" urls.Changelog = "https://github.com/sphinx-doc/sphinx-autobuild/blob/main/NEWS.rst" urls.Documentation = "https://github.com/sphinx-doc/sphinx-autobuild#readme" @@ -43,7 +43,9 @@ classifiers = [ ] dependencies = [ "sphinx", - "livereload", + "starlette>=0.35", + "uvicorn>=0.25", + "websockets>=11.0", "colorama", ] dynamic = ["version"] diff --git a/sphinx_autobuild/__init__.py b/sphinx_autobuild/__init__.py index 26a2dd6..92078cf 100644 --- a/sphinx_autobuild/__init__.py +++ b/sphinx_autobuild/__init__.py @@ -1,3 +1,3 @@ -"""Rebuild Sphinx documentation on changes, with live-reload in the browser.""" +"""Rebuild Sphinx documentation on changes, with hot reloading in the browser.""" __version__ = "2024.02.04" diff --git a/sphinx_autobuild/__main__.py b/sphinx_autobuild/__main__.py index cde9872..600799d 100644 --- a/sphinx_autobuild/__main__.py +++ b/sphinx_autobuild/__main__.py @@ -1,21 +1,25 @@ """Entrypoint for ``python -m sphinx_autobuild``.""" -from sphinx_autobuild import _hacks # isort:skip # noqa - import argparse import os import shlex import sys import colorama -from livereload import Server +import uvicorn # This isn't public API, but there aren't many better options from sphinx.cmd.build import get_parser as sphinx_get_parser +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.routing import Mount, WebSocketRoute +from starlette.staticfiles import StaticFiles from sphinx_autobuild import __version__ from sphinx_autobuild.build import Builder from sphinx_autobuild.filter import IgnoreFilter +from sphinx_autobuild.middleware import JavascriptInjectorMiddleware +from sphinx_autobuild.server import RebuildServer from sphinx_autobuild.utils import find_free_port, open_browser, show @@ -33,7 +37,6 @@ def main(): host_name = args.host port_num = args.port or find_free_port() url_host = f"{host_name}:{port_num}" - server = Server() pre_build_commands = list(map(shlex.split, args.pre_build)) @@ -43,15 +46,21 @@ def main(): pre_build_commands=pre_build_commands, ) + watch_dirs = [src_dir] + args.additional_watched_dirs ignore_handler = IgnoreFilter( [p for p in args.ignore + [out_dir, args.warnings_file, args.doctree_dir] if p], args.re_ignore, ) - server.watch(src_dir, builder, ignore=ignore_handler) - for dirpath in args.additional_watched_dirs: - dirpath = os.path.realpath(dirpath) - server.watch(dirpath, builder, ignore=ignore_handler) - server.watch(out_dir, ignore=ignore_handler) + watcher = RebuildServer(watch_dirs, ignore_handler, change_callback=builder) + + app = Starlette( + routes=[ + WebSocketRoute("/websocket-reload", watcher, name="reload"), + Mount("/", app=StaticFiles(directory=out_dir, html=True), name="static"), + ], + middleware=[Middleware(JavascriptInjectorMiddleware, ws_url=url_host)], + lifespan=watcher.lifespan, + ) if not args.no_initial_build: builder(rebuild=False) @@ -60,7 +69,7 @@ def main(): open_browser(url_host, args.delay) try: - server.serve(port=port_num, host=host_name, root=out_dir) + uvicorn.run(app, host=host_name, port=port_num, log_level="warning") except KeyboardInterrupt: show(context="Server ceasing operations. Cheerio!") diff --git a/sphinx_autobuild/_hacks.py b/sphinx_autobuild/_hacks.py deleted file mode 100644 index 2287088..0000000 --- a/sphinx_autobuild/_hacks.py +++ /dev/null @@ -1,28 +0,0 @@ -"""This file contains hacks needed to make things work. Ideally, this file is empty.""" - -from pathlib import PurePosixPath -from urllib.parse import urlparse - -import livereload.server as server -from tornado.web import OutputTransform - - -# Why do we do this? -# See https://github.com/sphinx-doc/sphinx-autobuild/issues/71#issuecomment-681854580 -class _FixedLiveScriptInjector(server.LiveScriptInjector): - def __init__(self, request): - # NOTE: Using super() here causes an infinite cycle, due to - # ConfiguredTransform not declaring an __init__. - OutputTransform.__init__(self, request) - - # Determine if this is an HTML page - path = PurePosixPath(urlparse(request.uri).path) - self.should_modify_request = path.suffix in ["", ".html"] - - def transform_first_chunk(self, status_code, headers, chunk, finishing): - if not self.should_modify_request: - return status_code, headers, chunk - return super().transform_first_chunk(status_code, headers, chunk, finishing) - - -server.LiveScriptInjector = _FixedLiveScriptInjector diff --git a/sphinx_autobuild/middleware.py b/sphinx_autobuild/middleware.py new file mode 100644 index 0000000..d39587a --- /dev/null +++ b/sphinx_autobuild/middleware.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from starlette.datastructures import MutableHeaders +from starlette.types import ASGIApp, Message, Receive, Scope, Send + + +def web_socket_script(ws_url: str) -> str: + # language=HTML + return f""" + +""" + + +class JavascriptInjectorMiddleware: + def __init__(self, app: ASGIApp, ws_url: str) -> None: + self.app = app + self.script = web_socket_script(ws_url).encode("utf-8") + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + add_script = False + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + async def send_wrapper(message: Message) -> None: + nonlocal add_script + if message["type"] == "http.response.start": + headers = MutableHeaders(scope=message) + if headers.get("Content-Type", "").startswith("text/html"): + add_script = True + if "Content-Length" in headers: + length = int(headers["Content-Length"]) + len(self.script) + headers["Content-Length"] = str(length) + elif message["type"] == "http.response.body": + request_complete = not message.get("more_body", False) + if add_script and request_complete: + message["body"] += self.script + await send(message) + + await self.app(scope, receive, send_wrapper) + return diff --git a/sphinx_autobuild/server.py b/sphinx_autobuild/server.py new file mode 100644 index 0000000..1b7673a --- /dev/null +++ b/sphinx_autobuild/server.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import asyncio +import os +from contextlib import AbstractAsyncContextManager, asynccontextmanager + +import watchfiles +from starlette.types import Receive, Scope, Send +from starlette.websockets import WebSocket + +TYPE_CHECKING = False +if TYPE_CHECKING: + from collections.abc import Callable + + from sphinx_autobuild.filter import IgnoreFilter + + +class RebuildServer: + def __init__( + self, + paths: list[os.PathLike[str]], + ignore_filter: IgnoreFilter, + change_callback: Callable[[], None], + ) -> None: + self.paths = [os.path.realpath(path, strict=True) for path in paths] + self.ignore = ignore_filter + self.change_callback = change_callback + self.flag = asyncio.Event() + self.should_exit = asyncio.Event() + + @asynccontextmanager + async def lifespan(self, _app) -> AbstractAsyncContextManager[None]: + task = asyncio.create_task(self.main()) + yield + self.should_exit.set() + await task + return + + async def main(self) -> None: + tasks = ( + asyncio.create_task(self.watch()), + asyncio.create_task(self.should_exit.wait()), + ) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + [task.cancel() for task in pending] + [task.result() for task in done] + + async def watch(self) -> None: + async for _changes in watchfiles.awatch( + *self.paths, + watch_filter=lambda _, path: not self.ignore(path), + ): + self.change_callback() + self.flag.set() + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "websocket" + ws = WebSocket(scope, receive, send) + await ws.accept() + + tasks = ( + asyncio.create_task(self.watch_reloads(ws)), + asyncio.create_task(self.wait_client_disconnect(ws)), + ) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + [task.cancel() for task in pending] + [task.result() for task in done] + + async def watch_reloads(self, ws: WebSocket) -> None: + while True: + await self.flag.wait() + self.flag.clear() + await ws.send_text("refresh") + + @staticmethod + async def wait_client_disconnect(ws: WebSocket) -> None: + async for _ in ws.iter_text(): + pass