diff --git a/pglookout/webserver.py b/pglookout/webserver.py index 0394ca6..a9c66c5 100644 --- a/pglookout/webserver.py +++ b/pglookout/webserver.py @@ -7,8 +7,12 @@ This file is under the Apache License, Version 2.0. See the file `LICENSE` for details. """ -from http.server import HTTPServer, SimpleHTTPRequestHandler -from logging import getLogger +from __future__ import annotations + +from http.server import BaseHTTPRequestHandler, HTTPServer, SimpleHTTPRequestHandler +from logging import getLogger, Logger +from pglookout.config import Config +from queue import Queue from socketserver import ThreadingMixIn from threading import Thread @@ -17,43 +21,60 @@ class ThreadedWebServer(ThreadingMixIn, HTTPServer): - cluster_state = None - log = None - cluster_monitor_check_queue = None - allow_reuse_address = True + allow_reuse_address: bool = True + + def __init__( + self, + address: str, + port: int, + RequestHandlerClass: type[BaseHTTPRequestHandler], + cluster_state: dict[str, int], + log: Logger, + cluster_monitor_check_queue: Queue[str], + ) -> None: + super().__init__((address, port), RequestHandlerClass) + self.cluster_state: dict[str, int] = cluster_state + self.log: Logger = log + self.cluster_monitor_check_queue: Queue[str] = cluster_monitor_check_queue class WebServer(Thread): - def __init__(self, config, cluster_state, cluster_monitor_check_queue): - Thread.__init__(self) - self.config = config - self.cluster_state = cluster_state - self.cluster_monitor_check_queue = cluster_monitor_check_queue - self.log = getLogger("WebServer") - self.address = self.config.get("http_address", "") - self.port = self.config.get("http_port", 15000) - self.server = None + def __init__(self, config: Config, cluster_state: dict[str, int], cluster_monitor_check_queue: Queue[str]) -> None: + super().__init__() + self.config: Config = config + self.cluster_state: dict[str, int] = cluster_state + self.cluster_monitor_check_queue: Queue[str] = cluster_monitor_check_queue + self.log: Logger = getLogger("WebServer") + self.address: str = self.config.get("http_address", "") + self.port: int = self.config.get("http_port", 15000) + self.server: ThreadedWebServer | None = None self.log.debug("WebServer initialized with address: %r port: %r", self.address, self.port) - self.is_initialized = threading.Event() + self.is_initialized: threading.Event = threading.Event() - def run(self): + def run(self) -> None: # We bind the port only when we start running - self.server = ThreadedWebServer((self.address, self.port), RequestHandler) - self.server.cluster_state = self.cluster_state - self.server.log = self.log - self.server.cluster_monitor_check_queue = self.cluster_monitor_check_queue + self.server = ThreadedWebServer( + address=self.address, + port=self.port, + RequestHandlerClass=RequestHandler, + cluster_state=self.cluster_state, + log=self.log, + cluster_monitor_check_queue=self.cluster_monitor_check_queue, + ) self.is_initialized.set() self.server.serve_forever() - def close(self): - if self.server: - self.log.debug("Closing WebServer") - self.server.shutdown() - self.log.debug("Closed WebServer") + def close(self) -> None: + if self.server is None: + return + + self.log.debug("Closing WebServer") + self.server.shutdown() + self.log.debug("Closed WebServer") class RequestHandler(SimpleHTTPRequestHandler): - def do_GET(self): + def do_GET(self) -> None: assert isinstance(self.server, ThreadedWebServer), f"server: {self.server!r}" self.server.log.debug("Got request: %r", self.path) if self.path.startswith("/state.json"): @@ -66,7 +87,7 @@ def do_GET(self): else: self.send_response(404) - def do_POST(self): + def do_POST(self) -> None: assert isinstance(self.server, ThreadedWebServer), f"server: {self.server!r}" self.server.log.debug("Got request: %r", self.path) if self.path.startswith("/check"): diff --git a/pyproject.toml b/pyproject.toml index 16ce87c..cb6440c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,14 +39,12 @@ exclude = [ 'pglookout/pgutil.py', 'pglookout/statsd.py', 'pglookout/version.py', - 'pglookout/webserver.py', # Tests. 'test/conftest.py', 'test/test_cluster_monitor.py', 'test/test_common.py', 'test/test_lookout.py', 'test/test_pgutil.py', - 'test/test_webserver.py', # Other. 'setup.py', 'version.py', diff --git a/test/test_webserver.py b/test/test_webserver.py index d81e71d..62c096a 100644 --- a/test/test_webserver.py +++ b/test/test_webserver.py @@ -6,6 +6,7 @@ This file is under the Apache License, Version 2.0. See the file `LICENSE` for details. """ +from pglookout.config import Config from pglookout.webserver import WebServer from queue import Queue @@ -13,8 +14,8 @@ import requests -def test_webserver(): - config = { +def test_webserver() -> None: + config: Config = { "http_port": random.randint(10000, 32000), } cluster_state = { @@ -22,12 +23,12 @@ def test_webserver(): } http_port = config["http_port"] base_url = f"http://127.0.0.1:{http_port}" - cluster_monitor_check_queue = Queue() + cluster_monitor_check_queue: Queue[str] = Queue() web = WebServer(config=config, cluster_state=cluster_state, cluster_monitor_check_queue=cluster_monitor_check_queue) try: web.start() - # wait for the thread to have started, else we're blocking forever as web.close can't shutdown the thread + # wait for the thread to have started, else we're blocking forever as web.close can't shut down the thread web.is_initialized.wait(timeout=30.0) result = requests.get(f"{base_url}/state.json", timeout=5).json()