HH-217737 fix fork workers race conditions
712u3 committed Aug 22, 2024
1 parent 456d2d7 commit 22341a3
Showing 27 changed files with 281 additions and 234 deletions.
53 changes: 25 additions & 28 deletions frontik/app.py
@@ -4,9 +4,7 @@
import multiprocessing
import os
import time
from collections.abc import Callable
from ctypes import c_bool, c_int
from threading import Lock
from typing import Optional, Union

from aiokafka import AIOKafkaProducer
@@ -35,7 +33,7 @@
router,
routers,
)
from frontik.service_discovery import UpstreamManager
from frontik.service_discovery import MasterServiceDiscovery, WorkerServiceDiscovery

app_logger = logging.getLogger('app_logger')
_server_tasks = set()
@@ -76,8 +74,7 @@ def __init__(self, app_module_name: Optional[str] = None) -> None:
self.json = frontik.producers.json_producer.JsonProducerFactory(self)

self.available_integrations: list[integrations.Integration] = []
self.http_client: Optional[AIOHttpClientWrapper] = None
self.http_client_factory: HttpClientFactory
self.http_client_factory: Optional[HttpClientFactory] = None

self.statsd_client: Union[StatsDClient, StatsDClientStub] = create_statsd_client(options, self)

@@ -96,37 +93,40 @@ def __init__(self, app_module_name: Optional[str] = None) -> None:
self.settings: dict = {}

self.asgi_app = FrontikAsgiApp()
self.service_discovery: Union[MasterServiceDiscovery, WorkerServiceDiscovery, None] = None

def __call__(self, tornado_request: httputil.HTTPServerRequest) -> None:
# to make this more ASGI-like, reimplement tornado.http1connection._server_request_loop and ._read_message
task = asyncio.create_task(serve_tornado_request(self, self.asgi_app, tornado_request))
_server_tasks.add(task)
task.add_done_callback(_server_tasks.discard)

def create_upstream_manager(
def make_service_discovery(
self,
upstreams: dict[str, Upstream],
upstreams_lock: Optional[Lock],
send_to_all_workers: Optional[Callable],
with_consul: bool,
) -> None:
self.upstream_manager = UpstreamManager(
upstreams,
self.statsd_client,
upstreams_lock,
send_to_all_workers,
with_consul,
self.app_name,
)

self.upstream_manager.send_updates() # initial full state sending
if with_consul:
self.service_discovery = MasterServiceDiscovery(self.statsd_client, self.app_name)
else:
self.service_discovery = WorkerServiceDiscovery(upstreams)

async def init(self) -> None:
self.available_integrations, integration_futures = integrations.load_integrations(self)
await asyncio.gather(*[future for future in integration_futures if future])

self.http_client = AIOHttpClientWrapper()
with_consul = self.worker_state.single_worker_mode and options.consul_enabled
self.make_service_discovery(self.worker_state.init_shared_data, with_consul)
if isinstance(self.service_discovery, MasterServiceDiscovery):
self.service_discovery.register_service()

self.http_client_factory = await self.make_http_client_factory()
self.asgi_app.http_client_factory = self.http_client_factory

if self.worker_state.single_worker_mode:
self.worker_state.master_done.value = True

async def make_http_client_factory(self) -> HttpClientFactory:
kafka_cluster = options.http_client_metrics_kafka_cluster
send_metrics_to_kafka = kafka_cluster and kafka_cluster in options.kafka_clusters

@@ -143,20 +143,17 @@ async def init(self) -> None:
self.get_kafka_producer(kafka_cluster) if send_metrics_to_kafka and kafka_cluster is not None else None
)

with_consul = self.worker_state.single_worker_mode and options.consul_enabled
self.create_upstream_manager({}, None, None, with_consul)
self.upstream_manager.register_service()

assert self.service_discovery is not None
request_balancer_builder = RequestBalancerBuilder(
self.upstream_manager.get_upstreams(),
self.service_discovery.get_upstreams_unsafe(), # very very bad
statsd_client=self.statsd_client,
kafka_producer=kafka_producer,
)
self.http_client_factory = HttpClientFactory(self.app_name, self.http_client, request_balancer_builder)
self.asgi_app.http_client_factory = self.http_client_factory
return HttpClientFactory(self.app_name, AIOHttpClientWrapper(), request_balancer_builder)

if self.worker_state.single_worker_mode:
self.worker_state.master_done.value = True
def deinit(self) -> None:
if isinstance(self.service_discovery, MasterServiceDiscovery):
self.service_discovery.deregister_service_and_close()

def application_config(self) -> DefaultConfig:
return FrontikApplication.DefaultConfig()
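Taken together, the app.py changes make the discovery role a one-time decision: a master (or single-worker) process with Consul enabled gets a MasterServiceDiscovery that registers and later deregisters the service, while forked workers get a WorkerServiceDiscovery that only wraps the upstream snapshot inherited via worker_state.init_shared_data. A minimal sketch of that split, using illustrative stand-in classes rather than frontik's real ones:

# Hedged sketch; MasterSD/WorkerSD are stand-ins for frontik's
# MasterServiceDiscovery/WorkerServiceDiscovery, not the real classes.
from typing import Union


class MasterSD:
    """Talks to Consul: registers the service and watches upstreams."""

    def register_service(self) -> None: ...

    def deregister_service_and_close(self) -> None: ...


class WorkerSD:
    """Passively holds the upstream snapshot inherited at fork time."""

    def __init__(self, upstreams: dict) -> None:
        self.upstreams = upstreams

    def get_upstreams_unsafe(self) -> dict:
        # "unsafe": callers get the live dict, not a copy
        return self.upstreams


def make_service_discovery(with_consul: bool, init_shared_data: dict) -> Union[MasterSD, WorkerSD]:
    return MasterSD() if with_consul else WorkerSD(init_shared_data)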
1 change: 1 addition & 0 deletions frontik/handler.py
@@ -169,6 +169,7 @@ def prepare(self) -> None:
self.xml_producer = self.application.xml.get_producer(self)
self.doc = self.xml_producer.doc

assert self.application.http_client_factory is not None
self._http_client: HttpClient = self.application.http_client_factory.get_http_client(
self.modify_http_client_request,
self.debug_mode.enabled,
108 changes: 56 additions & 52 deletions frontik/process.py
@@ -12,6 +12,7 @@
import sys
import time
from collections.abc import Callable
from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial
from multiprocessing.sharedctypes import Synchronized
@@ -41,16 +42,19 @@ class WorkerState:
is_master: bool = True
children: dict = field(default_factory=lambda: {}) # pid: worker_id
write_pipes: dict = field(default_factory=lambda: {}) # pid: write_pipe
resend_notification: Queue = field(default_factory=lambda: Queue(maxsize=1))
resend_dict: dict = field(default_factory=lambda: {}) # pid: flag
terminating: bool = False
single_worker_mode: bool = True
init_shared_data: dict = field(default_factory=lambda: {})


def fork_workers(
*,
worker_state: WorkerState,
num_workers: int,
master_function: Callable,
master_before_fork_action: Callable,
master_after_fork_action: Callable,
master_before_shutdown_action: Callable,
worker_function: Callable,
worker_listener_handler: Callable,
@@ -71,11 +75,13 @@ def master_sigterm_handler(signum, _frame):
signal.signal(signal.SIGTERM, master_sigterm_handler)
signal.signal(signal.SIGINT, master_sigterm_handler)

worker_function_wrapped = partial(_worker_function_wrapper, worker_function, worker_listener_handler)
shared_data, lock = master_before_fork_action()

worker_function_wrapped = partial(
_worker_function_wrapper, worker_function, shared_data, lock, worker_listener_handler
)
for worker_id in range(num_workers):
is_worker = _start_child(worker_id, worker_state, worker_function_wrapped)
if is_worker:
return
_start_child(worker_id, worker_state, worker_function_wrapped)

gc.enable()
timeout = time.time() + options.init_workers_timeout_sec
@@ -88,13 +94,15 @@ def master_sigterm_handler(signum, _frame):
f'{worker_state.init_workers_count_down.value} workers',
)
time.sleep(0.1)
_master_function_wrapper(worker_state, master_function)

__master_function_wrapper(worker_state, master_after_fork_action, shared_data, lock)
worker_state.master_done.value = True
_supervise_workers(worker_state, worker_function_wrapped, master_before_shutdown_action)
_supervise_workers(worker_state, worker_function_wrapped)


def _supervise_workers(
worker_state: WorkerState, worker_function: Callable, master_before_shutdown_action: Callable
worker_state: WorkerState,
worker_function: Callable,
) -> None:
while worker_state.children:
try:
@@ -116,16 +124,6 @@ def _supervise_workers(

if os.WIFSIGNALED(status):
log.warning('child %d (pid %d) killed by signal %d, restarting', worker_id, pid, os.WTERMSIG(status))

# TODO remove this block # noqa
master_before_shutdown_action()
for pid, worker_id in worker_state.children.items():
log.info('sending %s to child %d (pid %d)', signal.Signals(os.WTERMSIG(status)).name, worker_id, pid)
os.kill(pid, signal.SIGTERM)
log.info('all children terminated, exiting')
time.sleep(options.stop_timeout)
sys.exit(0)

elif os.WEXITSTATUS(status) != 0:
log.warning('child %d (pid %d) exited with status %d, restarting', worker_id, pid, os.WEXITSTATUS(status))
else:
@@ -136,30 +134,29 @@ def _worker_function_wrapper(worker_function, worker_listener_handler, read_fd,
log.info('server is shutting down, not restarting %d', worker_id)
continue

is_worker = _start_child(worker_id, worker_state, worker_function)
if is_worker:
return
worker_pid = _start_child(worker_id, worker_state, worker_function)
on_worker_restart(worker_state, worker_pid)

log.info('all children terminated, exiting')
sys.exit(0)
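
When a child dies, the supervisor forks a replacement and calls on_worker_restart (defined at the end of this file), which flags the new pid in resend_dict and pokes resend_notification. Because that queue is created with maxsize=1 and writers use put_nowait under contextlib.suppress(Full), it acts as a level-triggered wakeup latch: any burst of notifications between two consumer wakeups collapses into a single one. A minimal sketch of the latch pattern (handle_pending_work is a hypothetical stand-in for the resend pass):

import contextlib
import threading
from queue import Full, Queue

wakeup: Queue = Queue(maxsize=1)  # a wakeup latch, not a message queue


def handle_pending_work() -> None:
    ...  # hypothetical: re-send the latest snapshot to flagged workers


def notify() -> None:
    # extra notifications while the latch is already set are dropped
    with contextlib.suppress(Full):
        wakeup.put_nowait(True)


def consumer() -> None:
    while True:
        wakeup.get()           # block until at least one notify() happened
        handle_pending_work()  # drain everything that piled up since then


threading.Thread(target=consumer, daemon=True).start()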


# starts a child process; returns the child's pid (only the parent returns from this call)
def _start_child(worker_id: int, worker_state: WorkerState, worker_function: Callable) -> bool:
def _start_child(worker_id: int, worker_state: WorkerState, worker_function: Callable) -> int:
# cannot use multiprocessing.Pipe here: we need to set the nonblocking flag and connect the fd to asyncio
read_fd, write_fd = os.pipe()
os.set_blocking(read_fd, False)
os.set_blocking(write_fd, False)

prc = multiprocessing.Process(target=worker_function, args=(read_fd, write_fd, worker_state, worker_id))
prc.start()
pid = prc.pid
pid: int = prc.pid # type: ignore

os.close(read_fd)
worker_state.children[pid] = worker_id
_set_pipe_size(write_fd, worker_id)
worker_state.write_pipes[pid] = os.fdopen(write_fd, 'wb')
log.info('started child %d, pid=%d', worker_id, pid)
return False
return pid


def _set_pipe_size(fd: int, worker_id: int) -> None:
@@ -178,7 +175,9 @@ def _errno_from_exception(e: BaseException) -> Optional[int]:
return None


def _worker_function_wrapper(worker_function, worker_listener_handler, read_fd, write_fd, worker_state, worker_id):
def _worker_function_wrapper(
worker_function, shared_data, lock, worker_listener_handler, read_fd, write_fd, worker_state, worker_id
):
os.close(write_fd)
_set_pipe_size(read_fd, worker_id)
gc.enable()
@@ -192,6 +191,10 @@ def _worker_function_wrapper(worker_function, worker_listener_handler, read_fd,

task = loop.create_task(_worker_listener(read_fd, worker_listener_handler))
LISTENER_TASK.add(task)

if lock:
with lock:
worker_state.init_shared_data = deepcopy(shared_data)
worker_function()
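
This is the core of the race fix: shared state is built once by master_before_fork_action before any fork, and each worker copies it under the inherited lock before calling worker_function, instead of starting empty and racing with master-side updates arriving over the pipe. A compressed sketch of that contract (build_initial_upstreams and serve are hypothetical; the real actions also wire Consul, statsd and the resend thread):

import multiprocessing
from copy import deepcopy
from threading import Lock


def build_initial_upstreams() -> dict:
    return {}  # hypothetical: the master would fill this from Consul


def serve(snapshot: dict) -> None:
    ...  # hypothetical worker loop


def master_before_fork_action() -> tuple[dict, Lock]:
    # the master fills shared state once, before any child exists
    return {'upstreams': build_initial_upstreams()}, Lock()


def worker_entrypoint(shared_data: dict, lock: Lock) -> None:
    # mirrors _worker_function_wrapper: snapshot the inherited state first
    with lock:
        snapshot = deepcopy(shared_data)
    serve(snapshot)


def fork_workers(num_workers: int) -> None:
    shared_data, lock = master_before_fork_action()
    for _ in range(num_workers):
        multiprocessing.Process(target=worker_entrypoint, args=(shared_data, lock)).start()
    # only now may the master mutate shared_data and push updates
    # (master_after_fork_action in the real code)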


@@ -216,41 +219,38 @@ async def _worker_listener(read_fd: int, worker_listener_handler: Callable) -> None:
log.exception('failed to fetch data from master %s', e)
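
Only the tail of _worker_listener is visible in this hunk. For orientation, a hedged sketch of the underlying pattern, reading an inherited nonblocking pipe fd from asyncio (the framing and handler names are assumptions, not frontik's actual protocol):

import asyncio
import os


async def listen(read_fd: int, handler) -> None:
    loop = asyncio.get_running_loop()
    reader = asyncio.StreamReader()
    protocol = asyncio.StreamReaderProtocol(reader)
    # attach the inherited, already-nonblocking pipe fd to the running loop
    await loop.connect_read_pipe(lambda: protocol, os.fdopen(read_fd, 'rb'))
    while True:
        chunk = await reader.read(65536)
        if not chunk:      # master closed its write end
            break
        handler(chunk)     # assumption: unpickle and apply the upstream update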


def _master_function_wrapper(worker_state: WorkerState, master_function: Callable) -> None:
data_for_share: dict = {}
lock = Lock()

resend_notification: Queue = Queue(maxsize=1)
def __master_function_wrapper(
worker_state: WorkerState, master_after_fork_action: Callable, shared_data: dict, lock: Lock
) -> None:
if not lock:
master_after_fork_action(None)
return

resend_thread = Thread(
target=_resend,
args=(worker_state, resend_notification, worker_state.resend_dict, lock, data_for_share),
target=__resend,
args=(worker_state, worker_state.resend_notification, shared_data, lock),
daemon=True,
)
resend_thread.start()

send_to_all_workers = partial(_send_to_all, worker_state, resend_notification, worker_state.resend_dict)
master_function_thread = Thread(
target=master_function,
args=(data_for_share, lock, send_to_all_workers),
daemon=True,
)
master_function_thread.start()
update_shared_data_hook = partial(__send_to_all, worker_state, worker_state.resend_notification)
master_after_fork_action(update_shared_data_hook)


def _resend(
def __resend(
worker_state: WorkerState,
resend_notification: Queue,
resend_dict: dict[int, bool],
shared_data: dict,
lock: Lock,
data_for_share: dict,
) -> None:
resend_dict = worker_state.resend_dict

while True:
resend_notification.get()
time.sleep(1.0)

with lock:
data = pickle.dumps(list(data_for_share.values()))
data = pickle.dumps(list(shared_data.values()))
clients = list(resend_dict.keys())
if log.isEnabledFor(logging.DEBUG):
client_ids = ','.join(map(str, clients))
@@ -264,23 +264,21 @@ def _resend(
continue

# write twice to make sure the client's reading pattern picks up the update
_send_update(resend_notification, resend_dict, worker_id, pipe, data)
_send_update(resend_notification, resend_dict, worker_id, pipe, data)
__send_update(resend_notification, resend_dict, worker_id, pipe, data)
__send_update(resend_notification, resend_dict, worker_id, pipe, data)


def _send_to_all(
def __send_to_all(
worker_state: WorkerState,
resend_notification: Queue,
resend_dict: dict[int, bool],
data_raw: Any,
data: bytes,
) -> None:
data = pickle.dumps(data_raw)
log.debug('sending data to all workers length: %d', len(data))
for worker_pid, pipe in worker_state.write_pipes.items():
_send_update(resend_notification, resend_dict, worker_pid, pipe, data)
__send_update(resend_notification, worker_state.resend_dict, worker_pid, pipe, data)


def _send_update(
def __send_update(
resend_notification: Queue,
resend_dict: dict[int, bool],
worker_pid: int,
@@ -301,3 +299,9 @@
resend_notification.put_nowait(True)
except Exception as e:
log.exception('client %s pipe write failed %s', worker_pid, e)


def on_worker_restart(worker_state: WorkerState, worker_pid: int) -> None:
worker_state.resend_dict[worker_pid] = True
with contextlib.suppress(Full):
worker_state.resend_notification.put_nowait(True)
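
The middle of __send_update is elided in this diff, so the write path below is an assumption; only the failure handling (flagging resend_dict and poking the notification latch) is visible above. A plausible shape of the whole function under that assumption:

import contextlib
from queue import Full, Queue


def send_update(resend_notification: Queue, resend_dict: dict, worker_pid: int, pipe, data: bytes) -> None:
    # assumption: the real __send_update writes the pickled payload to the
    # worker's nonblocking pipe; a full pipe buffer raises BlockingIOError
    try:
        pipe.write(data)
        pipe.flush()
        resend_dict.pop(worker_pid, None)  # assumption: clear the flag on success
    except BlockingIOError:
        # buffer full: remember this worker and wake the resend thread
        resend_dict[worker_pid] = True
        with contextlib.suppress(Full):
            resend_notification.put_nowait(True)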