Skip to content

Commit

Permalink
mypy: pglookout/webserver.py [BF-1560]
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Giffard committed Mar 22, 2023
1 parent 84737ad commit 52bb635
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 34 deletions.
77 changes: 49 additions & 28 deletions pglookout/webserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
This file is under the Apache License, Version 2.0.
See the file `LICENSE` for details.
"""
from http.server import HTTPServer, SimpleHTTPRequestHandler
from logging import getLogger
from __future__ import annotations

from http.server import BaseHTTPRequestHandler, HTTPServer, SimpleHTTPRequestHandler
from logging import getLogger, Logger
from pglookout.config import Config
from queue import Queue
from socketserver import ThreadingMixIn
from threading import Thread

Expand All @@ -17,43 +21,60 @@


class ThreadedWebServer(ThreadingMixIn, HTTPServer):
cluster_state = None
log = None
cluster_monitor_check_queue = None
allow_reuse_address = True
allow_reuse_address: bool = True

def __init__(
self,
address: str,
port: int,
RequestHandlerClass: type[BaseHTTPRequestHandler],
cluster_state: dict[str, int],
log: Logger,
cluster_monitor_check_queue: Queue[str],
) -> None:
super().__init__((address, port), RequestHandlerClass)
self.cluster_state: dict[str, int] = cluster_state
self.log: Logger = log
self.cluster_monitor_check_queue: Queue[str] = cluster_monitor_check_queue


class WebServer(Thread):
def __init__(self, config, cluster_state, cluster_monitor_check_queue):
Thread.__init__(self)
self.config = config
self.cluster_state = cluster_state
self.cluster_monitor_check_queue = cluster_monitor_check_queue
self.log = getLogger("WebServer")
self.address = self.config.get("http_address", "")
self.port = self.config.get("http_port", 15000)
self.server = None
def __init__(self, config: Config, cluster_state: dict[str, int], cluster_monitor_check_queue: Queue[str]) -> None:
super().__init__()
self.config: Config = config
self.cluster_state: dict[str, int] = cluster_state
self.cluster_monitor_check_queue: Queue[str] = cluster_monitor_check_queue
self.log: Logger = getLogger("WebServer")
self.address: str = self.config.get("http_address", "")
self.port: int = self.config.get("http_port", 15000)
self.server: ThreadedWebServer | None = None
self.log.debug("WebServer initialized with address: %r port: %r", self.address, self.port)
self.is_initialized = threading.Event()
self.is_initialized: threading.Event = threading.Event()

def run(self):
def run(self) -> None:
# We bind the port only when we start running
self.server = ThreadedWebServer((self.address, self.port), RequestHandler)
self.server.cluster_state = self.cluster_state
self.server.log = self.log
self.server.cluster_monitor_check_queue = self.cluster_monitor_check_queue
self.server = ThreadedWebServer(
address=self.address,
port=self.port,
RequestHandlerClass=RequestHandler,
cluster_state=self.cluster_state,
log=self.log,
cluster_monitor_check_queue=self.cluster_monitor_check_queue,
)
self.is_initialized.set()
self.server.serve_forever()

def close(self):
if self.server:
self.log.debug("Closing WebServer")
self.server.shutdown()
self.log.debug("Closed WebServer")
def close(self) -> None:
if self.server is None:
return

self.log.debug("Closing WebServer")
self.server.shutdown()
self.log.debug("Closed WebServer")


class RequestHandler(SimpleHTTPRequestHandler):
def do_GET(self):
def do_GET(self) -> None:
assert isinstance(self.server, ThreadedWebServer), f"server: {self.server!r}"
self.server.log.debug("Got request: %r", self.path)
if self.path.startswith("/state.json"):
Expand All @@ -66,7 +87,7 @@ def do_GET(self):
else:
self.send_response(404)

def do_POST(self):
def do_POST(self) -> None:
assert isinstance(self.server, ThreadedWebServer), f"server: {self.server!r}"
self.server.log.debug("Got request: %r", self.path)
if self.path.startswith("/check"):
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,12 @@ exclude = [
'pglookout/pgutil.py',
'pglookout/statsd.py',
'pglookout/version.py',
'pglookout/webserver.py',
# Tests.
'test/conftest.py',
'test/test_cluster_monitor.py',
'test/test_common.py',
'test/test_lookout.py',
'test/test_pgutil.py',
'test/test_webserver.py',
# Other.
'setup.py',
'version.py',
Expand Down
9 changes: 5 additions & 4 deletions test/test_webserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,29 @@
This file is under the Apache License, Version 2.0.
See the file `LICENSE` for details.
"""
from pglookout.config import Config
from pglookout.webserver import WebServer
from queue import Queue

import random
import requests


def test_webserver():
config = {
def test_webserver() -> None:
config: Config = {
"http_port": random.randint(10000, 32000),
}
cluster_state = {
"hello": 123,
}
http_port = config["http_port"]
base_url = f"http://127.0.0.1:{http_port}"
cluster_monitor_check_queue = Queue()
cluster_monitor_check_queue: Queue[str] = Queue()

web = WebServer(config=config, cluster_state=cluster_state, cluster_monitor_check_queue=cluster_monitor_check_queue)
try:
web.start()
# wait for the thread to have started, else we're blocking forever as web.close can't shutdown the thread
# wait for the thread to have started, else we're blocking forever as web.close can't shut down the thread
web.is_initialized.wait(timeout=30.0)

result = requests.get(f"{base_url}/state.json", timeout=5).json()
Expand Down

0 comments on commit 52bb635

Please sign in to comment.