HH-217737 fix fork workers race conditions #727

Merged: 1 commit, Aug 27, 2024
55 changes: 21 additions & 34 deletions frontik/app.py
@@ -4,16 +4,14 @@
import multiprocessing
import os
import time
from collections.abc import Callable
from ctypes import c_bool, c_int
from threading import Lock
from typing import Optional, Union

from aiokafka import AIOKafkaProducer
from fastapi import FastAPI, HTTPException
from http_client import AIOHttpClientWrapper, HttpClientFactory
from http_client import options as http_client_options
from http_client.balancing import RequestBalancerBuilder, Upstream
from http_client.balancing import RequestBalancerBuilder
from lxml import etree
from tornado import httputil

@@ -35,7 +33,7 @@
router,
routers,
)
from frontik.service_discovery import UpstreamManager
from frontik.service_discovery import MasterServiceDiscovery, ServiceDiscovery, WorkerServiceDiscovery

app_logger = logging.getLogger('app_logger')
_server_tasks = set()
@@ -76,7 +74,6 @@ def __init__(self, app_module_name: Optional[str] = None) -> None:
self.json = frontik.producers.json_producer.JsonProducerFactory(self)

self.available_integrations: list[integrations.Integration] = []
self.http_client: Optional[AIOHttpClientWrapper] = None
self.http_client_factory: HttpClientFactory

self.statsd_client: Union[StatsDClient, StatsDClientStub] = create_statsd_client(options, self)
@@ -96,37 +93,35 @@ def __init__(self, app_module_name: Optional[str] = None) -> None:
self.settings: dict = {}

self.asgi_app = FrontikAsgiApp()
self.service_discovery: ServiceDiscovery

def __call__(self, tornado_request: httputil.HTTPServerRequest) -> None:
# for make it more asgi, reimplement tornado.http1connection._server_request_loop and ._read_message
task = asyncio.create_task(serve_tornado_request(self, self.asgi_app, tornado_request))
_server_tasks.add(task)
task.add_done_callback(_server_tasks.discard)

def create_upstream_manager(
self,
upstreams: dict[str, Upstream],
upstreams_lock: Optional[Lock],
send_to_all_workers: Optional[Callable],
with_consul: bool,
) -> None:
self.upstream_manager = UpstreamManager(
upstreams,
self.statsd_client,
upstreams_lock,
send_to_all_workers,
with_consul,
self.app_name,
)

self.upstream_manager.send_updates() # initial full state sending
def make_service_discovery(self) -> ServiceDiscovery:
if self.worker_state.is_master and options.consul_enabled:
return MasterServiceDiscovery(self.statsd_client, self.app_name)
else:
return WorkerServiceDiscovery(self.worker_state.initial_shared_data)

async def init(self) -> None:
async def install_integrations(self) -> None:
self.available_integrations, integration_futures = integrations.load_integrations(self)
await asyncio.gather(*[future for future in integration_futures if future])

self.http_client = AIOHttpClientWrapper()
self.service_discovery = self.make_service_discovery()
self.http_client_factory = self.make_http_client_factory()
self.asgi_app.http_client_factory = self.http_client_factory

async def init(self) -> None:
await self.install_integrations()

if self.worker_state.is_master:
self.worker_state.master_done.value = True

def make_http_client_factory(self) -> HttpClientFactory:
kafka_cluster = options.http_client_metrics_kafka_cluster
send_metrics_to_kafka = kafka_cluster and kafka_cluster in options.kafka_clusters

@@ -143,20 +138,12 @@ async def init(self) -> None:
self.get_kafka_producer(kafka_cluster) if send_metrics_to_kafka and kafka_cluster is not None else None
)

with_consul = self.worker_state.single_worker_mode and options.consul_enabled
self.create_upstream_manager({}, None, None, with_consul)
self.upstream_manager.register_service()

request_balancer_builder = RequestBalancerBuilder(
self.upstream_manager.get_upstreams(),
upstreams=self.service_discovery.get_upstreams_unsafe(),
statsd_client=self.statsd_client,
kafka_producer=kafka_producer,
)
self.http_client_factory = HttpClientFactory(self.app_name, self.http_client, request_balancer_builder)
self.asgi_app.http_client_factory = self.http_client_factory

if self.worker_state.single_worker_mode:
self.worker_state.master_done.value = True
return HttpClientFactory(self.app_name, AIOHttpClientWrapper(), request_balancer_builder)

def application_config(self) -> DefaultConfig:
return FrontikApplication.DefaultConfig()
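Note on the app.py changes above: the eager UpstreamManager wiring is replaced by a ServiceDiscovery object chosen per process. The master (with Consul enabled) builds a MasterServiceDiscovery, while forked workers get a WorkerServiceDiscovery over the snapshot copied into worker_state.initial_shared_data at fork time. The sketch below only illustrates that selection rule; the stub classes are hypothetical stand-ins, not frontik's implementation.

```python
from dataclasses import dataclass, field


@dataclass
class WorkerState:  # simplified stand-in for frontik.process.WorkerState
    is_master: bool = True
    initial_shared_data: dict = field(default_factory=dict)


class MasterServiceDiscovery:
    """Hypothetical stub: the real class presumably registers the service in Consul
    and watches upstreams, pushing updates down to the workers."""

    def __init__(self, statsd_client, app_name: str) -> None:
        self.statsd_client = statsd_client
        self.app_name = app_name


class WorkerServiceDiscovery:
    """Hypothetical stub: a read-only view over the upstreams copied in before fork."""

    def __init__(self, upstreams: dict) -> None:
        self.upstreams = upstreams


def make_service_discovery(worker_state: WorkerState, consul_enabled: bool, statsd_client, app_name: str):
    # Mirrors the branch added in FrontikApplication.make_service_discovery()
    if worker_state.is_master and consul_enabled:
        return MasterServiceDiscovery(statsd_client, app_name)
    return WorkerServiceDiscovery(worker_state.initial_shared_data)
```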
1 change: 1 addition & 0 deletions frontik/options.py
@@ -17,6 +17,7 @@ class Options:
xheaders: bool = False
validate_request_id: bool = False
xsrf_cookies: bool = False
max_body_size: int = 1_000_000_000
openapi_enabled: bool = False

config: Optional[str] = None
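The new max_body_size option (default 1_000_000_000 bytes) presumably caps the request body accepted by the HTTP server. A hedged sketch, assuming the option is forwarded to Tornado's HTTPServer (which does accept a max_body_size argument); frontik's actual server setup may wire it differently:

```python
from tornado.httpserver import HTTPServer

from frontik.options import options  # assumption: the shared Options instance lives here


def make_http_server(request_callback) -> HTTPServer:
    # Bodies larger than max_body_size are rejected by Tornado before they
    # reach the application.
    return HTTPServer(
        request_callback,
        xheaders=options.xheaders,
        max_body_size=options.max_body_size,
    )
```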
118 changes: 62 additions & 56 deletions frontik/process.py
@@ -12,6 +12,8 @@
import sys
import time
from collections.abc import Callable
from contextlib import suppress
from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial
from multiprocessing.sharedctypes import Synchronized
@@ -41,22 +43,23 @@ class WorkerState:
is_master: bool = True
children: dict = field(default_factory=lambda: {}) # pid: worker_id
write_pipes: dict = field(default_factory=lambda: {}) # pid: write_pipe
resend_notification: Queue = field(default_factory=lambda: Queue(maxsize=1))
resend_dict: dict = field(default_factory=lambda: {}) # pid: flag
terminating: bool = False
single_worker_mode: bool = True
initial_shared_data: dict = field(default_factory=lambda: {})


def fork_workers(
*,
worker_state: WorkerState,
num_workers: int,
master_function: Callable,
master_before_fork_action: Callable,
master_after_fork_action: Callable,
master_before_shutdown_action: Callable,
worker_function: Callable,
worker_listener_handler: Callable,
) -> None:
log.info('starting %d processes', num_workers)
worker_state.single_worker_mode = False

def master_sigterm_handler(signum, _frame):
if not worker_state.is_master:
@@ -65,17 +68,17 @@ def master_sigterm_handler(signum, _frame):
worker_state.terminating = True
master_before_shutdown_action()
for pid, worker_id in worker_state.children.items():
log.info('sending %s to child %d (pid %d)', signal.Signals(signum).name, worker_id, pid)
log.info('sending %s to child %d (pid %d)', signal.SIGTERM.name, worker_id, pid)
os.kill(pid, signal.SIGTERM)

signal.signal(signal.SIGTERM, master_sigterm_handler)
signal.signal(signal.SIGINT, master_sigterm_handler)

shared_data, lock = master_before_fork_action()

worker_function_wrapped = partial(_worker_function_wrapper, worker_function, worker_listener_handler)
for worker_id in range(num_workers):
is_worker = _start_child(worker_id, worker_state, worker_function_wrapped)
if is_worker:
return
_start_child(worker_id, worker_state, shared_data, lock, worker_function_wrapped)

gc.enable()
timeout = time.time() + options.init_workers_timeout_sec
@@ -88,13 +91,17 @@ def master_sigterm_handler(signum, _frame):
f'{worker_state.init_workers_count_down.value} workers',
)
time.sleep(0.1)
_master_function_wrapper(worker_state, master_function)

__master_function_wrapper(worker_state, master_after_fork_action, shared_data, lock)
worker_state.master_done.value = True
_supervise_workers(worker_state, worker_function_wrapped, master_before_shutdown_action)
_supervise_workers(worker_state, shared_data, lock, worker_function_wrapped)


def _supervise_workers(
worker_state: WorkerState, worker_function: Callable, master_before_shutdown_action: Callable
worker_state: WorkerState,
shared_data: dict,
lock: Lock,
worker_function: Callable,
) -> None:
while worker_state.children:
try:
@@ -116,16 +123,6 @@ def _supervise_workers(

if os.WIFSIGNALED(status):
log.warning('child %d (pid %d) killed by signal %d, restarting', worker_id, pid, os.WTERMSIG(status))

# TODO remove this block # noqa
master_before_shutdown_action()
for pid, worker_id in worker_state.children.items():
log.info('sending %s to child %d (pid %d)', signal.Signals(os.WTERMSIG(status)).name, worker_id, pid)
os.kill(pid, signal.SIGTERM)
log.info('all children terminated, exiting')
time.sleep(options.stop_timeout)
sys.exit(0)

elif os.WEXITSTATUS(status) != 0:
log.warning('child %d (pid %d) exited with status %d, restarting', worker_id, pid, os.WEXITSTATUS(status))
else:
@@ -136,30 +133,35 @@ def _supervise_workers(
log.info('server is shutting down, not restarting %d', worker_id)
continue

is_worker = _start_child(worker_id, worker_state, worker_function)
if is_worker:
return
worker_pid = _start_child(worker_id, worker_state, shared_data, lock, worker_function)
on_worker_restart(worker_state, worker_pid)

log.info('all children terminated, exiting')
sys.exit(0)


# returns True inside child process, otherwise False
def _start_child(worker_id: int, worker_state: WorkerState, worker_function: Callable) -> bool:
def _start_child(
worker_id: int, worker_state: WorkerState, shared_data: dict, lock: Optional[Lock], worker_function: Callable
) -> int:
# it cannot be multiprocessing.pipe because we need to set nonblock flag and connect to asyncio
read_fd, write_fd = os.pipe()
os.set_blocking(read_fd, False)
os.set_blocking(write_fd, False)

if lock is not None:
with lock:
worker_state.initial_shared_data = deepcopy(shared_data)

prc = multiprocessing.Process(target=worker_function, args=(read_fd, write_fd, worker_state, worker_id))
prc.start()
pid = prc.pid
pid: int = prc.pid # type: ignore

os.close(read_fd)
worker_state.children[pid] = worker_id
_set_pipe_size(write_fd, worker_id)
worker_state.write_pipes[pid] = os.fdopen(write_fd, 'wb')
log.info('started child %d, pid=%d', worker_id, pid)
return False
return pid


def _set_pipe_size(fd: int, worker_id: int) -> None:
@@ -184,14 +186,17 @@ def _worker_function_wrapper(worker_function, worker_listener_handler, read_fd,
gc.enable()
worker_state.is_master = False

loop = asyncio.get_event_loop()
if loop.is_running():
loop.stop()
with suppress(Exception):
loop = asyncio.get_event_loop()
if loop.is_running():
loop.stop()

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

task = loop.create_task(_worker_listener(read_fd, worker_listener_handler))
LISTENER_TASK.add(task)

worker_function()


@@ -216,41 +221,38 @@ async def _worker_listener(read_fd: int, worker_listener_handler: Callable) -> None:
log.exception('failed to fetch data from master %s', e)


def _master_function_wrapper(worker_state: WorkerState, master_function: Callable) -> None:
data_for_share: dict = {}
lock = Lock()

resend_notification: Queue = Queue(maxsize=1)
def __master_function_wrapper(
worker_state: WorkerState, master_after_fork_action: Callable, shared_data: dict, lock: Lock
) -> None:
if not lock:
master_after_fork_action(None)
return

resend_thread = Thread(
target=_resend,
args=(worker_state, resend_notification, worker_state.resend_dict, lock, data_for_share),
target=__resend,
args=(worker_state, worker_state.resend_notification, shared_data, lock),
daemon=True,
)
resend_thread.start()

send_to_all_workers = partial(_send_to_all, worker_state, resend_notification, worker_state.resend_dict)
master_function_thread = Thread(
target=master_function,
args=(data_for_share, lock, send_to_all_workers),
daemon=True,
)
master_function_thread.start()
update_shared_data_hook = partial(__send_to_all, worker_state, worker_state.resend_notification)
master_after_fork_action(update_shared_data_hook)


def _resend(
def __resend(
worker_state: WorkerState,
resend_notification: Queue,
resend_dict: dict[int, bool],
shared_data: dict,
lock: Lock,
data_for_share: dict,
) -> None:
resend_dict = worker_state.resend_dict

while True:
resend_notification.get()
time.sleep(1.0)

with lock:
data = pickle.dumps(list(data_for_share.values()))
data = pickle.dumps(list(shared_data.values()))
clients = list(resend_dict.keys())
if log.isEnabledFor(logging.DEBUG):
client_ids = ','.join(map(str, clients))
@@ -264,23 +266,21 @@ def _resend(
continue

# writing 2 times to ensure fix of client reading pattern
_send_update(resend_notification, resend_dict, worker_id, pipe, data)
_send_update(resend_notification, resend_dict, worker_id, pipe, data)
__send_update(resend_notification, resend_dict, worker_id, pipe, data)
__send_update(resend_notification, resend_dict, worker_id, pipe, data)


def _send_to_all(
def __send_to_all(
worker_state: WorkerState,
resend_notification: Queue,
resend_dict: dict[int, bool],
data_raw: Any,
data: bytes,
) -> None:
data = pickle.dumps(data_raw)
log.debug('sending data to all workers length: %d', len(data))
for worker_pid, pipe in worker_state.write_pipes.items():
_send_update(resend_notification, resend_dict, worker_pid, pipe, data)
__send_update(resend_notification, worker_state.resend_dict, worker_pid, pipe, data)


def _send_update(
def __send_update(
resend_notification: Queue,
resend_dict: dict[int, bool],
worker_pid: int,
@@ -301,3 +301,9 @@ def _send_update(
resend_notification.put_nowait(True)
except Exception as e:
log.exception('client %s pipe write failed %s', worker_pid, e)


def on_worker_restart(worker_state: WorkerState, worker_pid: int) -> None:
worker_state.resend_dict[worker_pid] = True
with contextlib.suppress(Full):
worker_state.resend_notification.put_nowait(True)
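Taken together, the process.py changes rework the fork_workers() contract: the caller now supplies master_before_fork_action (runs once before any fork and returns the shared-data dict plus the lock guarding it), master_after_fork_action (runs after all workers report ready and receives the hook that pushes pickled updates to every worker pipe, or None when there is no lock), and master_before_shutdown_action. Each fork deep-copies the shared data into worker_state.initial_shared_data under that lock, so a late-starting worker begins from a consistent snapshot instead of racing the master's pipe updates. A sketch of that wiring follows; the callback bodies are hypothetical, only the keyword-only signature comes from the diff.

```python
import pickle
from threading import Lock
from typing import Callable, Optional

from frontik.process import WorkerState, fork_workers


def master_before_fork_action() -> tuple[dict, Lock]:
    # Runs in the master before the first fork; whatever ends up in this dict is
    # deep-copied into each child's worker_state.initial_shared_data.
    shared_data: dict = {}
    return shared_data, Lock()


def master_after_fork_action(update_shared_data_hook: Optional[Callable]) -> None:
    # Called once all workers are up; the hook (when not None) writes the given
    # bytes to every worker pipe, same path the resend thread uses.
    if update_shared_data_hook is not None:
        update_shared_data_hook(pickle.dumps([]))  # e.g. an initial upstream snapshot


def master_before_shutdown_action() -> None:
    pass  # e.g. deregister the service before SIGTERM is propagated to children


def worker_function() -> None:
    pass  # start the event loop / HTTP server in the child process


def worker_listener_handler(data: bytes) -> None:
    pass  # apply upstream updates pushed from the master


def start(worker_state: WorkerState, num_workers: int) -> None:
    fork_workers(
        worker_state=worker_state,
        num_workers=num_workers,
        master_before_fork_action=master_before_fork_action,
        master_after_fork_action=master_after_fork_action,
        master_before_shutdown_action=master_before_shutdown_action,
        worker_function=worker_function,
        worker_listener_handler=worker_listener_handler,
    )
```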