diff --git a/frontik/app.py b/frontik/app.py index 51bc95f49..f3a3500db 100644 --- a/frontik/app.py +++ b/frontik/app.py @@ -4,16 +4,14 @@ import multiprocessing import os import time -from collections.abc import Callable from ctypes import c_bool, c_int -from threading import Lock from typing import Optional, Union from aiokafka import AIOKafkaProducer from fastapi import FastAPI, HTTPException from http_client import AIOHttpClientWrapper, HttpClientFactory from http_client import options as http_client_options -from http_client.balancing import RequestBalancerBuilder, Upstream +from http_client.balancing import RequestBalancerBuilder from lxml import etree from tornado import httputil @@ -35,7 +33,7 @@ router, routers, ) -from frontik.service_discovery import UpstreamManager +from frontik.service_discovery import MasterServiceDiscovery, ServiceDiscovery, WorkerServiceDiscovery app_logger = logging.getLogger('app_logger') _server_tasks = set() @@ -76,7 +74,6 @@ def __init__(self, app_module_name: Optional[str] = None) -> None: self.json = frontik.producers.json_producer.JsonProducerFactory(self) self.available_integrations: list[integrations.Integration] = [] - self.http_client: Optional[AIOHttpClientWrapper] = None self.http_client_factory: HttpClientFactory self.statsd_client: Union[StatsDClient, StatsDClientStub] = create_statsd_client(options, self) @@ -96,6 +93,7 @@ def __init__(self, app_module_name: Optional[str] = None) -> None: self.settings: dict = {} self.asgi_app = FrontikAsgiApp() + self.service_discovery: ServiceDiscovery def __call__(self, tornado_request: httputil.HTTPServerRequest) -> None: # for make it more asgi, reimplement tornado.http1connection._server_request_loop and ._read_message @@ -103,30 +101,27 @@ def __call__(self, tornado_request: httputil.HTTPServerRequest) -> None: _server_tasks.add(task) task.add_done_callback(_server_tasks.discard) - def create_upstream_manager( - self, - upstreams: dict[str, Upstream], - upstreams_lock: Optional[Lock], - send_to_all_workers: Optional[Callable], - with_consul: bool, - ) -> None: - self.upstream_manager = UpstreamManager( - upstreams, - self.statsd_client, - upstreams_lock, - send_to_all_workers, - with_consul, - self.app_name, - ) - - self.upstream_manager.send_updates() # initial full state sending + def make_service_discovery(self) -> ServiceDiscovery: + if self.worker_state.is_master and options.consul_enabled: + return MasterServiceDiscovery(self.statsd_client, self.app_name) + else: + return WorkerServiceDiscovery(self.worker_state.initial_shared_data) - async def init(self) -> None: + async def install_integrations(self) -> None: self.available_integrations, integration_futures = integrations.load_integrations(self) await asyncio.gather(*[future for future in integration_futures if future]) - self.http_client = AIOHttpClientWrapper() + self.service_discovery = self.make_service_discovery() + self.http_client_factory = self.make_http_client_factory() + self.asgi_app.http_client_factory = self.http_client_factory + + async def init(self) -> None: + await self.install_integrations() + if self.worker_state.is_master: + self.worker_state.master_done.value = True + + def make_http_client_factory(self) -> HttpClientFactory: kafka_cluster = options.http_client_metrics_kafka_cluster send_metrics_to_kafka = kafka_cluster and kafka_cluster in options.kafka_clusters @@ -143,20 +138,12 @@ async def init(self) -> None: self.get_kafka_producer(kafka_cluster) if send_metrics_to_kafka and kafka_cluster is not None else None ) - 
with_consul = self.worker_state.single_worker_mode and options.consul_enabled - self.create_upstream_manager({}, None, None, with_consul) - self.upstream_manager.register_service() - request_balancer_builder = RequestBalancerBuilder( - self.upstream_manager.get_upstreams(), + upstreams=self.service_discovery.get_upstreams_unsafe(), statsd_client=self.statsd_client, kafka_producer=kafka_producer, ) - self.http_client_factory = HttpClientFactory(self.app_name, self.http_client, request_balancer_builder) - self.asgi_app.http_client_factory = self.http_client_factory - - if self.worker_state.single_worker_mode: - self.worker_state.master_done.value = True + return HttpClientFactory(self.app_name, AIOHttpClientWrapper(), request_balancer_builder) def application_config(self) -> DefaultConfig: return FrontikApplication.DefaultConfig() diff --git a/frontik/options.py b/frontik/options.py index 47030134c..e1f526d5f 100644 --- a/frontik/options.py +++ b/frontik/options.py @@ -17,6 +17,7 @@ class Options: xheaders: bool = False validate_request_id: bool = False xsrf_cookies: bool = False + max_body_size: int = 1_000_000_000 openapi_enabled: bool = False config: Optional[str] = None diff --git a/frontik/process.py b/frontik/process.py index 6c3448713..05bdd0d4d 100644 --- a/frontik/process.py +++ b/frontik/process.py @@ -12,6 +12,8 @@ import sys import time from collections.abc import Callable +from contextlib import suppress +from copy import deepcopy from dataclasses import dataclass, field from functools import partial from multiprocessing.sharedctypes import Synchronized @@ -41,22 +43,23 @@ class WorkerState: is_master: bool = True children: dict = field(default_factory=lambda: {}) # pid: worker_id write_pipes: dict = field(default_factory=lambda: {}) # pid: write_pipe + resend_notification: Queue = field(default_factory=lambda: Queue(maxsize=1)) resend_dict: dict = field(default_factory=lambda: {}) # pid: flag terminating: bool = False - single_worker_mode: bool = True + initial_shared_data: dict = field(default_factory=lambda: {}) def fork_workers( *, worker_state: WorkerState, num_workers: int, - master_function: Callable, + master_before_fork_action: Callable, + master_after_fork_action: Callable, master_before_shutdown_action: Callable, worker_function: Callable, worker_listener_handler: Callable, ) -> None: log.info('starting %d processes', num_workers) - worker_state.single_worker_mode = False def master_sigterm_handler(signum, _frame): if not worker_state.is_master: @@ -65,17 +68,17 @@ def master_sigterm_handler(signum, _frame): worker_state.terminating = True master_before_shutdown_action() for pid, worker_id in worker_state.children.items(): - log.info('sending %s to child %d (pid %d)', signal.Signals(signum).name, worker_id, pid) + log.info('sending %s to child %d (pid %d)', signal.SIGTERM.name, worker_id, pid) os.kill(pid, signal.SIGTERM) signal.signal(signal.SIGTERM, master_sigterm_handler) signal.signal(signal.SIGINT, master_sigterm_handler) + shared_data, lock = master_before_fork_action() + worker_function_wrapped = partial(_worker_function_wrapper, worker_function, worker_listener_handler) for worker_id in range(num_workers): - is_worker = _start_child(worker_id, worker_state, worker_function_wrapped) - if is_worker: - return + _start_child(worker_id, worker_state, shared_data, lock, worker_function_wrapped) gc.enable() timeout = time.time() + options.init_workers_timeout_sec @@ -88,13 +91,17 @@ def master_sigterm_handler(signum, _frame): 
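
A minimal sketch of the new master/worker split shown above, using stand-in names rather than frontik's real classes: the master builds the shared upstream data before forking, each child is seeded with a deep copy of that snapshot (which is what `WorkerState.initial_shared_data` now carries across `_start_child`), and the application then picks its service-discovery source by process role, as `make_service_discovery()` does.

from copy import deepcopy
from dataclasses import dataclass, field
from threading import Lock


@dataclass
class ToyWorkerState:
    # illustrative stand-in for frontik.process.WorkerState
    is_master: bool = True
    initial_shared_data: dict = field(default_factory=dict)


def master_before_fork_action() -> tuple[dict, Lock]:
    # runs once in the master, before any fork: build the authoritative state
    return {'my-upstream': ['10.0.0.1:80', '10.0.0.2:80']}, Lock()


def start_child(worker_state: ToyWorkerState, shared_data: dict, lock: Lock) -> ToyWorkerState:
    # each child starts from an isolated snapshot taken under the lock,
    # mirroring the deepcopy into WorkerState.initial_shared_data in _start_child()
    with lock:
        worker_state.initial_shared_data = deepcopy(shared_data)
    worker_state.is_master = False
    return worker_state


def pick_upstream_source(worker_state: ToyWorkerState, consul_enabled: bool) -> dict:
    # mirrors make_service_discovery(): the master (with Consul) owns a fresh,
    # live-updated dict, while a worker starts from the pre-fork snapshot
    if worker_state.is_master and consul_enabled:
        return {}
    return worker_state.initial_shared_data


if __name__ == '__main__':
    shared, lock = master_before_fork_action()
    child = start_child(ToyWorkerState(), shared, lock)
    assert pick_upstream_source(child, consul_enabled=True) == shared
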
f'{worker_state.init_workers_count_down.value} workers', ) time.sleep(0.1) - _master_function_wrapper(worker_state, master_function) + + __master_function_wrapper(worker_state, master_after_fork_action, shared_data, lock) worker_state.master_done.value = True - _supervise_workers(worker_state, worker_function_wrapped, master_before_shutdown_action) + _supervise_workers(worker_state, shared_data, lock, worker_function_wrapped) def _supervise_workers( - worker_state: WorkerState, worker_function: Callable, master_before_shutdown_action: Callable + worker_state: WorkerState, + shared_data: dict, + lock: Lock, + worker_function: Callable, ) -> None: while worker_state.children: try: @@ -116,16 +123,6 @@ def _supervise_workers( if os.WIFSIGNALED(status): log.warning('child %d (pid %d) killed by signal %d, restarting', worker_id, pid, os.WTERMSIG(status)) - - # TODO remove this block # noqa - master_before_shutdown_action() - for pid, worker_id in worker_state.children.items(): - log.info('sending %s to child %d (pid %d)', signal.Signals(os.WTERMSIG(status)).name, worker_id, pid) - os.kill(pid, signal.SIGTERM) - log.info('all children terminated, exiting') - time.sleep(options.stop_timeout) - sys.exit(0) - elif os.WEXITSTATUS(status) != 0: log.warning('child %d (pid %d) exited with status %d, restarting', worker_id, pid, os.WEXITSTATUS(status)) else: @@ -136,30 +133,35 @@ def _supervise_workers( log.info('server is shutting down, not restarting %d', worker_id) continue - is_worker = _start_child(worker_id, worker_state, worker_function) - if is_worker: - return + worker_pid = _start_child(worker_id, worker_state, shared_data, lock, worker_function) + on_worker_restart(worker_state, worker_pid) + log.info('all children terminated, exiting') sys.exit(0) -# returns True inside child process, otherwise False -def _start_child(worker_id: int, worker_state: WorkerState, worker_function: Callable) -> bool: +def _start_child( + worker_id: int, worker_state: WorkerState, shared_data: dict, lock: Optional[Lock], worker_function: Callable +) -> int: # it cannot be multiprocessing.pipe because we need to set nonblock flag and connect to asyncio read_fd, write_fd = os.pipe() os.set_blocking(read_fd, False) os.set_blocking(write_fd, False) + if lock is not None: + with lock: + worker_state.initial_shared_data = deepcopy(shared_data) + prc = multiprocessing.Process(target=worker_function, args=(read_fd, write_fd, worker_state, worker_id)) prc.start() - pid = prc.pid + pid: int = prc.pid # type: ignore os.close(read_fd) worker_state.children[pid] = worker_id _set_pipe_size(write_fd, worker_id) worker_state.write_pipes[pid] = os.fdopen(write_fd, 'wb') log.info('started child %d, pid=%d', worker_id, pid) - return False + return pid def _set_pipe_size(fd: int, worker_id: int) -> None: @@ -184,14 +186,17 @@ def _worker_function_wrapper(worker_function, worker_listener_handler, read_fd, gc.enable() worker_state.is_master = False - loop = asyncio.get_event_loop() - if loop.is_running(): - loop.stop() + with suppress(Exception): + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.stop() + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) task = loop.create_task(_worker_listener(read_fd, worker_listener_handler)) LISTENER_TASK.add(task) + worker_function() @@ -216,41 +221,38 @@ async def _worker_listener(read_fd: int, worker_listener_handler: Callable) -> N log.exception('failed to fetch data from master %s', e) -def _master_function_wrapper(worker_state: WorkerState, master_function: Callable) 
-> None: - data_for_share: dict = {} - lock = Lock() - - resend_notification: Queue = Queue(maxsize=1) +def __master_function_wrapper( + worker_state: WorkerState, master_after_fork_action: Callable, shared_data: dict, lock: Lock +) -> None: + if not lock: + master_after_fork_action(None) + return resend_thread = Thread( - target=_resend, - args=(worker_state, resend_notification, worker_state.resend_dict, lock, data_for_share), + target=__resend, + args=(worker_state, worker_state.resend_notification, shared_data, lock), daemon=True, ) resend_thread.start() - send_to_all_workers = partial(_send_to_all, worker_state, resend_notification, worker_state.resend_dict) - master_function_thread = Thread( - target=master_function, - args=(data_for_share, lock, send_to_all_workers), - daemon=True, - ) - master_function_thread.start() + update_shared_data_hook = partial(__send_to_all, worker_state, worker_state.resend_notification) + master_after_fork_action(update_shared_data_hook) -def _resend( +def __resend( worker_state: WorkerState, resend_notification: Queue, - resend_dict: dict[int, bool], + shared_data: dict, lock: Lock, - data_for_share: dict, ) -> None: + resend_dict = worker_state.resend_dict + while True: resend_notification.get() time.sleep(1.0) with lock: - data = pickle.dumps(list(data_for_share.values())) + data = pickle.dumps(list(shared_data.values())) clients = list(resend_dict.keys()) if log.isEnabledFor(logging.DEBUG): client_ids = ','.join(map(str, clients)) @@ -264,23 +266,21 @@ def _resend( continue # writing 2 times to ensure fix of client reading pattern - _send_update(resend_notification, resend_dict, worker_id, pipe, data) - _send_update(resend_notification, resend_dict, worker_id, pipe, data) + __send_update(resend_notification, resend_dict, worker_id, pipe, data) + __send_update(resend_notification, resend_dict, worker_id, pipe, data) -def _send_to_all( +def __send_to_all( worker_state: WorkerState, resend_notification: Queue, - resend_dict: dict[int, bool], - data_raw: Any, + data: bytes, ) -> None: - data = pickle.dumps(data_raw) log.debug('sending data to all workers length: %d', len(data)) for worker_pid, pipe in worker_state.write_pipes.items(): - _send_update(resend_notification, resend_dict, worker_pid, pipe, data) + __send_update(resend_notification, worker_state.resend_dict, worker_pid, pipe, data) -def _send_update( +def __send_update( resend_notification: Queue, resend_dict: dict[int, bool], worker_pid: int, @@ -301,3 +301,9 @@ def _send_update( resend_notification.put_nowait(True) except Exception as e: log.exception('client %s pipe write failed %s', worker_pid, e) + + +def on_worker_restart(worker_state: WorkerState, worker_pid: int) -> None: + worker_state.resend_dict[worker_pid] = True + with contextlib.suppress(Full): + worker_state.resend_notification.put_nowait(True) diff --git a/frontik/server.py b/frontik/server.py index d7a11b3af..911c65c62 100644 --- a/frontik/server.py +++ b/frontik/server.py @@ -54,8 +54,9 @@ def main(config_file: Optional[str] = None) -> None: fork_workers( worker_state=app.worker_state, num_workers=options.workers, - master_function=partial(_multi_worker_master_function, app), - master_before_shutdown_action=lambda: app.upstream_manager.deregister_service_and_close(), # noqa PLW0108 + master_before_fork_action=partial(_master_before_fork_action, app), + master_after_fork_action=partial(_master_after_fork_action, app), + master_before_shutdown_action=partial(_master_before_shutdown_action, app), 
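# --- illustrative aside (hypothetical names, not part of this patch) ---------
# The resend machinery above (__send_to_all / __resend / on_worker_restart)
# boils down to: try a nonblocking pipe write per worker, flag the worker pid
# in resend_dict when the write fails, and wake a daemon thread through a
# bounded queue so the full pickled state is retried a little later.
import contextlib
import pickle
import time
from queue import Full, Queue
from threading import Lock


def send_update(resend_notification: Queue, resend_dict: dict, pid: int, write, data: bytes) -> None:
    # 'write' stands for something like pipe.write on a nonblocking fd
    try:
        write(data)
        resend_dict.pop(pid, None)            # this worker is up to date again
    except Exception:
        resend_dict[pid] = True               # remember who missed the update
        with contextlib.suppress(Full):
            resend_notification.put_nowait(True)


def resend_loop(resend_notification: Queue, resend_dict: dict, shared_data: dict, lock: Lock, pipes: dict) -> None:
    # intended to run in a daemon Thread, as __resend does
    while True:
        resend_notification.get()             # block until some write failed
        time.sleep(1.0)                       # let workers drain their pipes first
        with lock:
            data = pickle.dumps(list(shared_data.values()))
            stale = list(resend_dict.keys())
        for pid in stale:
            if (write := pipes.get(pid)) is not None:
                send_update(resend_notification, resend_dict, pid, write, data)
# -----------------------------------------------------------------------------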
worker_function=partial(_run_worker, app), worker_listener_handler=partial(_worker_listener_handler, app), ) @@ -68,30 +69,39 @@ def main(config_file: Optional[str] = None) -> None: sys.exit(1) -def _multi_worker_master_function( +def _master_before_fork_action(app: FrontikApplication) -> tuple[dict, Optional[Lock]]: + async def async_actions() -> None: + await app.install_integrations() + if (local_before_fork_action := getattr(app, 'before_fork_action', None)) is not None: + await local_before_fork_action() + + asyncio.run(async_actions()) + return app.service_discovery.get_upstreams_with_lock() + + +def _master_after_fork_action( app: FrontikApplication, - upstreams: dict[str, Upstream], - upstreams_lock: Lock, - send_to_all_workers: Callable, + update_shared_data_hook: Optional[Callable], ) -> None: - app.create_upstream_manager(upstreams, upstreams_lock, send_to_all_workers, with_consul=options.consul_enabled) - app.upstream_manager.register_service() + if update_shared_data_hook is None: + return + + app.service_discovery.set_update_shared_data_hook(update_shared_data_hook) + app.service_discovery.send_updates() # send in case there were updates between worker creation and this point + app.service_discovery.register_service() + + +def _master_before_shutdown_action(app: FrontikApplication) -> None: + asyncio.run(_deinit_app(app)) def _worker_listener_handler(app: FrontikApplication, data: list[Upstream]) -> None: - app.upstream_manager.update_upstreams(data) + app.service_discovery.update_upstreams(data) def _run_worker(app: FrontikApplication) -> None: MDC.init('worker') - try: - import uvloop - except ImportError: - log.info('There is no installed uvloop; use asyncio event loop') - else: - uvloop.install() - loop = asyncio.get_event_loop() executor = ThreadPoolExecutor(options.common_executor_pool_size) loop.set_default_executor(executor) @@ -111,7 +121,7 @@ def run_server(app: FrontikApplication) -> None: """Starts Frontik server for an application""" loop = asyncio.get_event_loop() log.info('starting server on %s:%s', options.host, options.port) - http_server = HTTPServer(app, xheaders=options.xheaders) + http_server = HTTPServer(app, xheaders=options.xheaders, max_body_size=options.max_body_size) http_server.bind(options.port, options.host, reuse_port=options.reuse_port) http_server.start() @@ -150,12 +160,14 @@ async def _init_app(frontik_app: FrontikApplication) -> None: frontik_app.worker_state.init_workers_count_down.value -= 1 log.info('worker is up, remaining workers = %s', frontik_app.worker_state.init_workers_count_down.value) + frontik_app.service_discovery.register_service() + async def _deinit_app(app: FrontikApplication) -> None: deinit_futures: list[Optional[Union[Future, Coroutine]]] = [] deinit_futures.extend([integration.deinitialize_app(app) for integration in app.available_integrations]) - app.upstream_manager.deregister_service_and_close() + app.service_discovery.deregister_service_and_close() try: await asyncio.gather(*[future for future in deinit_futures if future]) @@ -163,6 +175,10 @@ async def _deinit_app(app: FrontikApplication) -> None: except Exception as e: log.exception('failed to deinit, deinit returned: %s', e) + await asyncio.sleep(options.stop_timeout) + if app.http_client_factory is not None: + await app.http_client_factory.http_client.client_session.close() + def anyio_noop(*_args, **_kwargs): raise RuntimeError(f'trying to use non async {_args[0]}') diff --git a/frontik/service_discovery.py b/frontik/service_discovery.py index 
2117dba52..479db8feb 100644 --- a/frontik/service_discovery.py +++ b/frontik/service_discovery.py @@ -1,5 +1,7 @@ +import abc import itertools import logging +import pickle import socket from random import shuffle from threading import Lock @@ -68,36 +70,53 @@ def _get_hostname_or_raise(node_name: str) -> str: return node_name -class UpstreamManager: - def __init__( - self, - upstreams: dict[str, Upstream], - statsd_client: Union[StatsDClient, StatsDClientStub], - upstreams_lock: Optional[Lock], - send_to_all_workers: Optional[Callable], - with_consul: bool, - app_name: str, - ) -> None: - self.with_consul: bool = with_consul +class ServiceDiscovery(abc.ABC): + @abc.abstractmethod + def get_upstreams_unsafe(self) -> dict[str, Upstream]: + pass + + @abc.abstractmethod + def register_service(self) -> None: + pass + + @abc.abstractmethod + def deregister_service_and_close(self) -> None: + pass + + @abc.abstractmethod + def get_upstreams_with_lock(self) -> tuple[dict[str, Upstream], Optional[Lock]]: + pass + + @abc.abstractmethod + def set_update_shared_data_hook(self, update_shared_data_hook: Callable) -> None: + pass + + @abc.abstractmethod + def update_upstreams(self, upstreams: list[Upstream]) -> None: + pass + + @abc.abstractmethod + def send_updates(self, upstream: Optional[Upstream] = None) -> None: + pass + + +class MasterServiceDiscovery(ServiceDiscovery): + def __init__(self, statsd_client: Union[StatsDClient, StatsDClientStub], app_name: str) -> None: self._upstreams_config: dict[str, dict] = {} self._upstreams_servers: dict[str, list[Server]] = {} - self._upstreams = upstreams - self._upstreams_lock = upstreams_lock or Lock() # should be used when access self._upstreams - self._send_to_all_workers = send_to_all_workers - - if not self.with_consul: - log.info('Consul disabled, skipping') - return + self._upstreams: dict[str, Upstream] = {} + self._upstreams_lock = Lock() + self._send_to_all_workers: Optional[Callable] = None self.consul = SyncConsulClient( host=options.consul_host, port=options.consul_port, client_event_callback=ConsulMetricsTracker(statsd_client), ) - self._service_name = app_name + self.service_name = app_name self.hostname = _get_hostname_or_raise(options.node_name) - self.service_id = _make_service_id(options, service_name=self._service_name, hostname=self.hostname) + self.service_id = _make_service_id(options, service_name=self.service_name, hostname=self.hostname) self.address = _get_service_address(options) self.http_check = _create_http_check(options, self.address) self.consul_weight_watch_seconds = f'{options.consul_weight_watch_seconds}s' @@ -114,9 +133,9 @@ def __init__( cache_initial_warmup_timeout=self.consul_cache_initial_warmup_timeout_sec, consistency_mode=self.consul_weight_consistency_mode, recurse=False, - caller=self._service_name, + caller=self.service_name, ) - self.kvCache.add_listener(self._update_register, False) + self.kvCache.add_listener(self._update_register) upstream_cache = KVCache( self.consul.kv, @@ -127,9 +146,9 @@ def __init__( cache_initial_warmup_timeout=self.consul_cache_initial_warmup_timeout_sec, consistency_mode=self.consul_weight_consistency_mode, recurse=True, - caller=self._service_name, + caller=self.service_name, ) - upstream_cache.add_listener(self._update_upstreams_config, True) + upstream_cache.add_listener(self._update_upstreams_config, trigger_current=True) upstream_cache.start() allow_cross_dc = http_options.http_client_allow_cross_datacenter_requests @@ -142,22 +161,28 @@ def __init__( 
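
A condensed sketch of the contract the new ServiceDiscovery ABC captures, with made-up Mini* classes rather than the real ones: the master owns the authoritative upstream dict plus its lock and pushes pickled snapshots through the hook installed by set_update_shared_data_hook(), while a worker only applies the updates it receives.

import pickle
from random import shuffle
from threading import Lock
from typing import Callable, Optional


class MiniMasterDiscovery:
    def __init__(self) -> None:
        self._upstreams: dict[str, list[str]] = {}
        self._lock = Lock()
        self._send: Optional[Callable[[bytes], None]] = None

    def set_update_shared_data_hook(self, hook: Callable[[bytes], None]) -> None:
        self._send = hook

    def send_updates(self) -> None:
        if self._send is None:
            return
        with self._lock:                      # snapshot under the lock
            payload = pickle.dumps(list(self._upstreams.items()))
        self._send(payload)


class MiniWorkerDiscovery:
    def __init__(self, upstreams: dict[str, list[str]]) -> None:
        self.upstreams = upstreams

    def update_upstreams(self, items: list[tuple[str, list[str]]]) -> None:
        for name, servers in items:
            if name not in self.upstreams:
                shuffle(servers)              # spread initial load, as the real worker does
            # the real WorkerServiceDiscovery merges into the existing Upstream
            # instead of overwriting; replacement keeps this sketch short
            self.upstreams[name] = servers

Keeping the "wrong side" methods as no-ops or hard RuntimeErrors, as the real MasterServiceDiscovery/WorkerServiceDiscovery do, makes it explicit which process is allowed to talk to Consul and which one may only consume updates.
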
watch_seconds=self.consul_weight_watch_seconds, backoff_delay_seconds=self.consul_cache_backoff_delay_seconds, dc=dc, - caller=self._service_name, + caller=self.service_name, ) - health_cache.add_listener(self._update_upstreams_service, True) + health_cache.add_listener(self._update_upstreams_service, trigger_current=True) health_cache.start() if options.fail_start_on_empty_upstream: - self._check_empty_upstreams_on_startup() + self.__check_empty_upstreams_on_startup() + + def set_update_shared_data_hook(self, update_shared_data_hook: Callable) -> None: + self._send_to_all_workers = update_shared_data_hook + + def get_upstreams_with_lock(self) -> tuple[dict[str, Upstream], Lock]: + return self._upstreams, self._upstreams_lock + + def get_upstreams_unsafe(self) -> dict[str, Upstream]: + return self._upstreams def _update_register(self, key, new_value): weight = _get_weight_or_default(new_value) self._sync_register(weight) def register_service(self) -> None: - if not self.with_consul: - return - weight = _get_weight_or_default(self.kvCache.get_value()) self._sync_register(weight) self.kvCache.start() @@ -170,27 +195,21 @@ def _sync_register(self, weight: int) -> None: 'check': self.http_check, 'tags': options.consul_tags, 'weights': Weight.weights(weight, 0), - 'caller': self._service_name, + 'caller': self.service_name, } - if self.consul.agent.service.register(self._service_name, **register_params): + if self.consul.agent.service.register(self.service_name, **register_params): log.info('Successfully registered service %s', register_params) else: raise Exception(f'Failed to register {register_params}') def deregister_service_and_close(self) -> None: - if not self.with_consul: - return - self.kvCache.stop() - if self.consul.agent.service.deregister(self.service_id, self._service_name): + if self.consul.agent.service.deregister(self.service_id, self.service_name): log.info('Successfully deregistered service %s', self.service_id) else: log.info('Failed to deregister service %s normally', self.service_id) - def get_upstreams(self) -> dict[str, Upstream]: - return self._upstreams - - def _check_empty_upstreams_on_startup(self) -> None: + def __check_empty_upstreams_on_startup(self) -> None: empty_upstreams = [k for k, v in self._upstreams.items() if not v.servers] if empty_upstreams: msg = f'failed startup application, because for next upstreams got empty servers: {empty_upstreams}' @@ -235,7 +254,7 @@ def send_updates(self, upstream: Optional[Upstream] = None) -> None: return with self._upstreams_lock: upstreams = list(self._upstreams.values()) if upstream is None else [upstream] - self._send_to_all_workers(upstreams) + self._send_to_all_workers(pickle.dumps(upstreams)) def _create_upstream(self, key: str) -> Upstream: servers = self._combine_servers(key) @@ -250,22 +269,48 @@ def _combine_servers(self, key: str) -> list[Server]: servers_from_all_dc += servers return servers_from_all_dc + def update_upstreams(self, upstreams: list[Upstream]) -> None: + raise RuntimeError('master should not serve upstream updates') + + +class WorkerServiceDiscovery(ServiceDiscovery): + def __init__(self, upstreams: dict[str, Upstream]) -> None: + self.upstreams = upstreams + def update_upstreams(self, upstreams: list[Upstream]) -> None: for upstream in upstreams: - self._update_upstream(upstream) + self.__update_upstream(upstream) - def _update_upstream(self, upstream: Upstream) -> None: - current_upstream = self._upstreams.get(upstream.name) + def __update_upstream(self, upstream: Upstream) -> None: + 
current_upstream = self.upstreams.get(upstream.name) if current_upstream is None: shuffle(upstream.servers) - self._upstreams[upstream.name] = upstream + self.upstreams[upstream.name] = upstream log.debug('add %s upstream: %s', upstream.name, str(upstream)) return current_upstream.update(upstream) log.debug('update %s upstream: %s', upstream.name, str(upstream)) + def get_upstreams_unsafe(self) -> dict[str, Upstream]: + return self.upstreams + + def register_service(self) -> None: + pass + + def deregister_service_and_close(self) -> None: + pass + + def get_upstreams_with_lock(self) -> tuple[dict[str, Upstream], Optional[Lock]]: + return {}, None + + def set_update_shared_data_hook(self, update_shared_data_hook: Callable) -> None: + raise RuntimeError('worker should not use update hook') + + def send_updates(self, upstream: Optional[Upstream] = None) -> None: + raise RuntimeError('worker should not use send updates') + class ConsulMetricsTracker(ClientEventCallback): def __init__(self, statsd_client: Union[StatsDClient, StatsDClientStub]) -> None: diff --git a/frontik/testing.py b/frontik/testing.py index 828737813..67ad42474 100644 --- a/frontik/testing.py +++ b/frontik/testing.py @@ -62,7 +62,7 @@ async def _finish_server_setup( ) -> None: self.app = frontik_app self.port = options.port - self.http_client: AIOHttpClientWrapper = frontik_app.http_client + self.http_client: AIOHttpClientWrapper = frontik_app.http_client_factory.http_client self.use_tornado_mocks = with_tornado_mocks if with_tornado_mocks: patch_http_client(self.http_client, fail_on_unknown=False) @@ -263,9 +263,9 @@ def fetch_json(self, path: str, query: Optional[dict] = None, **kwargs: Any) -> """Fetch the request and parse JSON tree from response body.""" return json.loads(self.fetch(path, query, **kwargs).raw_body) - def patch_app_http_client(self, app: FrontikApplication) -> None: + def patch_app_http_client(self, _app: FrontikApplication) -> None: """Patches application HTTPClient to enable requests stubbing.""" - patch_http_client(app.http_client) + patch_http_client(self.http_client) def set_stub( self, @@ -280,7 +280,7 @@ def set_stub( **kwargs: Any, ) -> None: set_stub( - self._app.http_client, + self.http_client, url, request_method, response_function, @@ -293,8 +293,8 @@ def set_stub( ) def tearDown(self) -> None: - if self._app.http_client is not None: - self.io_loop.run_sync(self._app.http_client.client_session.close) + if self.http_client is not None: + self.io_loop.run_sync(self.http_client.client_session.close) # type: ignore if self.forced_client is not None: self.io_loop.run_sync(self.forced_client.client_session.close) super().tearDown() diff --git a/tests/projects/balancer_app/pages/__init__.py b/tests/projects/balancer_app/pages/__init__.py index 22796e835..354d69dd9 100644 --- a/tests/projects/balancer_app/pages/__init__.py +++ b/tests/projects/balancer_app/pages/__init__.py @@ -7,19 +7,19 @@ def check_all_servers_occupied(handler: PageHandler, name: str) -> None: - servers = handler.application.upstream_manager.get_upstreams().get(name, noop_upstream).servers + servers = handler.application.service_discovery.get_upstreams_unsafe().get(name, noop_upstream).servers if any(server.current_requests == 0 for server in servers): raise HTTPError(500, 'some servers are ignored') def check_all_requests_done(handler: PageHandler, name: str) -> None: - servers = handler.application.upstream_manager.get_upstreams().get(name, noop_upstream).servers + servers = 
handler.application.service_discovery.get_upstreams_unsafe().get(name, noop_upstream).servers if any(server.current_requests != 0 for server in servers): raise HTTPError(500, 'some servers have unfinished requests') def check_all_servers_were_occupied(handler: PageHandler, name: str) -> None: - servers = handler.application.upstream_manager.get_upstreams().get(name, noop_upstream).servers + servers = handler.application.service_discovery.get_upstreams_unsafe().get(name, noop_upstream).servers if any(server.current_requests != 0 for server in servers): raise HTTPError(500, 'some servers are ignored') if any(server.stat_requests == 0 for server in servers): diff --git a/tests/projects/balancer_app/pages/different_datacenter.py b/tests/projects/balancer_app/pages/different_datacenter.py index f3014b435..afab06a59 100644 --- a/tests/projects/balancer_app/pages/different_datacenter.py +++ b/tests/projects/balancer_app/pages/different_datacenter.py @@ -16,7 +16,7 @@ async def get_page(handler=get_current_handler()): normal_server.datacenter = 'dc2' upstream = Upstream('different_datacenter', {}, [free_server, normal_server]) - handler.application.upstream_manager.get_upstreams()['different_datacenter'] = upstream + handler.application.service_discovery.get_upstreams_unsafe()['different_datacenter'] = upstream result = await handler.post_url('different_datacenter', handler.path) for server in upstream.servers: diff --git a/tests/projects/balancer_app/pages/no_available_backend.py b/tests/projects/balancer_app/pages/no_available_backend.py index 3aa05c80a..37b0bd02b 100644 --- a/tests/projects/balancer_app/pages/no_available_backend.py +++ b/tests/projects/balancer_app/pages/no_available_backend.py @@ -9,7 +9,7 @@ @plain_router.get('/no_available_backend', cls=PageHandler) async def get_page(handler=get_current_handler()): - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['no_available_backend'] = Upstream('no_available_backend', {}, []) request = handler.post_url('no_available_backend', handler.path) diff --git a/tests/projects/balancer_app/pages/no_retry_error.py b/tests/projects/balancer_app/pages/no_retry_error.py index e55492a0b..036fb1432 100644 --- a/tests/projects/balancer_app/pages/no_retry_error.py +++ b/tests/projects/balancer_app/pages/no_retry_error.py @@ -9,7 +9,7 @@ @plain_router.get('/no_retry_error', cls=PageHandler) async def get_page(handler=get_current_handler()): - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['no_retry_error'] = Upstream('no_retry_error', {}, [get_server(handler, 'broken')]) result = await handler.post_url('no_retry_error', handler.path) diff --git a/tests/projects/balancer_app/pages/no_retry_timeout.py b/tests/projects/balancer_app/pages/no_retry_timeout.py index 56f7eaa43..b59ebacea 100644 --- a/tests/projects/balancer_app/pages/no_retry_timeout.py +++ b/tests/projects/balancer_app/pages/no_retry_timeout.py @@ -11,7 +11,7 @@ @plain_router.get('/no_retry_timeout', cls=PageHandler) async def get_page(handler=get_current_handler()): - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['no_retry_timeout'] = Upstream('no_retry_timeout', {}, [get_server(handler, 'broken')]) result = await handler.post_url('no_retry_timeout', handler.path, request_timeout=0.2) diff 
--git a/tests/projects/balancer_app/pages/profile_with_retry.py b/tests/projects/balancer_app/pages/profile_with_retry.py index 54c806e2b..a9bbc5c65 100644 --- a/tests/projects/balancer_app/pages/profile_with_retry.py +++ b/tests/projects/balancer_app/pages/profile_with_retry.py @@ -15,7 +15,7 @@ async def get_page(handler=get_current_handler()): 'profile_without_retry': UpstreamConfig(max_tries=1), 'profile_with_retry': UpstreamConfig(max_tries=2), } - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['profile_with_retry'] = Upstream( 'profile_with_retry', upstream_config, diff --git a/tests/projects/balancer_app/pages/profile_without_retry.py b/tests/projects/balancer_app/pages/profile_without_retry.py index 294db6711..abb975e34 100644 --- a/tests/projects/balancer_app/pages/profile_without_retry.py +++ b/tests/projects/balancer_app/pages/profile_without_retry.py @@ -14,7 +14,7 @@ async def get_page(handler=get_current_handler()): 'profile_without_retry': UpstreamConfig(max_tries=1), 'profile_with_retry': UpstreamConfig(max_tries=2), } - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['profile_without_retry'] = Upstream( 'profile_without_retry', upstream_config, diff --git a/tests/projects/balancer_app/pages/requests_count.py b/tests/projects/balancer_app/pages/requests_count.py index da40461c2..cd0c66255 100644 --- a/tests/projects/balancer_app/pages/requests_count.py +++ b/tests/projects/balancer_app/pages/requests_count.py @@ -11,7 +11,7 @@ @plain_router.get('/requests_count', cls=PageHandler) async def get_page(handler=get_current_handler()): - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['requests_count_async'] = Upstream('requests_count_async', {}, [get_server(handler, 'normal')]) handler.text = '' @@ -34,6 +34,6 @@ async def get_page(handler=get_current_handler()): @plain_router.post('/requests_count', cls=PageHandler) async def post_page(handler=get_current_handler()): handler.set_header('Content-Type', media_types.TEXT_PLAIN) - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() servers = upstreams['requests_count_async'].servers handler.text = str(servers[0].stat_requests) diff --git a/tests/projects/balancer_app/pages/retry_connect.py b/tests/projects/balancer_app/pages/retry_connect.py index 5190d88ba..98e6d205c 100644 --- a/tests/projects/balancer_app/pages/retry_connect.py +++ b/tests/projects/balancer_app/pages/retry_connect.py @@ -11,7 +11,7 @@ @plain_router.get('/retry_connect', cls=PageHandler) async def get_page(handler: PageHandler = get_current_handler()) -> None: - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['retry_connect'] = Upstream( 'retry_connect', {}, diff --git a/tests/projects/balancer_app/pages/retry_connect_timeout.py b/tests/projects/balancer_app/pages/retry_connect_timeout.py index d80bdd8a9..9f696bc4c 100644 --- a/tests/projects/balancer_app/pages/retry_connect_timeout.py +++ b/tests/projects/balancer_app/pages/retry_connect_timeout.py @@ -11,7 +11,7 @@ @plain_router.get('/retry_connect_timeout', cls=PageHandler) async def get_page(handler: PageHandler = 
get_current_handler()) -> None: - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['retry_connect_timeout'] = Upstream('retry_connect_timeout', {}, [get_server(handler, 'normal')]) handler.text = '' diff --git a/tests/projects/balancer_app/pages/retry_count_limit.py b/tests/projects/balancer_app/pages/retry_count_limit.py index fabc3b52f..f5357c20e 100644 --- a/tests/projects/balancer_app/pages/retry_count_limit.py +++ b/tests/projects/balancer_app/pages/retry_count_limit.py @@ -19,7 +19,7 @@ async def get_page(handler=get_current_handler()): ], ) - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['retry_count_limit'] = upstream handler.text = '' diff --git a/tests/projects/balancer_app/pages/retry_error.py b/tests/projects/balancer_app/pages/retry_error.py index 061bde4a6..69c53174e 100644 --- a/tests/projects/balancer_app/pages/retry_error.py +++ b/tests/projects/balancer_app/pages/retry_error.py @@ -11,7 +11,7 @@ @plain_router.get('/retry_error', cls=PageHandler) async def get_page(handler=get_current_handler()): - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['retry_error'] = Upstream( 'retry_error', {}, [get_server(handler, 'broken'), get_server(handler, 'normal')] ) diff --git a/tests/projects/balancer_app/pages/retry_non_idempotent_503.py b/tests/projects/balancer_app/pages/retry_non_idempotent_503.py index 98b1866d9..59d9a048b 100644 --- a/tests/projects/balancer_app/pages/retry_non_idempotent_503.py +++ b/tests/projects/balancer_app/pages/retry_non_idempotent_503.py @@ -12,7 +12,7 @@ @plain_router.get('/retry_non_idempotent_503', cls=PageHandler) async def get_page(handler=get_current_handler()): upstream_config = {Upstream.DEFAULT_PROFILE: UpstreamConfig(retry_policy={503: {'idempotent': 'true'}})} - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['retry_non_idempotent_503'] = Upstream( 'retry_non_idempotent_503', upstream_config, diff --git a/tests/projects/balancer_app/pages/retry_on_timeout.py b/tests/projects/balancer_app/pages/retry_on_timeout.py index 6615944c1..fa43f0672 100644 --- a/tests/projects/balancer_app/pages/retry_on_timeout.py +++ b/tests/projects/balancer_app/pages/retry_on_timeout.py @@ -10,7 +10,7 @@ @plain_router.get('/retry_on_timeout', cls=PageHandler) async def get_page(handler=get_current_handler()): - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['retry_on_timeout'] = Upstream( 'retry_on_timeout', {}, diff --git a/tests/projects/balancer_app/pages/slow_start.py b/tests/projects/balancer_app/pages/slow_start.py index aced5a117..ecf4909d0 100644 --- a/tests/projects/balancer_app/pages/slow_start.py +++ b/tests/projects/balancer_app/pages/slow_start.py @@ -20,7 +20,7 @@ async def get_page(handler=get_current_handler()): server_slow_start = Server('127.0.0.1:12345', 'dest_host', weight=5, dc='Test') upstream_config = {Upstream.DEFAULT_PROFILE: UpstreamConfig(slow_start_interval=0.1)} - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['slow_start'] = 
Upstream('slow_start', upstream_config, [server]) handler.text = '' @@ -46,5 +46,5 @@ async def get_page(handler=get_current_handler()): @plain_router.post('/slow_start', cls=PageHandler) async def post_page(handler=get_current_handler()): handler.set_header('Content-Type', media_types.TEXT_PLAIN) - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() handler.text = str(upstreams['slow_start'].servers[0].stat_requests) diff --git a/tests/projects/balancer_app/pages/speculative_no_retry.py b/tests/projects/balancer_app/pages/speculative_no_retry.py index a443fb4a3..abf179932 100644 --- a/tests/projects/balancer_app/pages/speculative_no_retry.py +++ b/tests/projects/balancer_app/pages/speculative_no_retry.py @@ -8,7 +8,7 @@ @plain_router.get('/speculative_no_retry', cls=PageHandler) async def get_page(handler=get_current_handler()): - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['speculative_no_retry'] = Upstream( 'speculative_no_retry', {}, diff --git a/tests/projects/balancer_app/pages/speculative_retry.py b/tests/projects/balancer_app/pages/speculative_retry.py index 9c69a978b..98dc4c5cf 100644 --- a/tests/projects/balancer_app/pages/speculative_retry.py +++ b/tests/projects/balancer_app/pages/speculative_retry.py @@ -9,7 +9,7 @@ @plain_router.get('/speculative_retry', cls=PageHandler) async def get_page(handler=get_current_handler()): - upstreams = handler.application.upstream_manager.get_upstreams() + upstreams = handler.application.service_discovery.get_upstreams_unsafe() upstreams['speculative_retry'] = Upstream( 'speculative_retry', {}, diff --git a/tests/projects/test_app/__init__.py b/tests/projects/test_app/__init__.py index 13cf0c87e..f7b3c6810 100644 --- a/tests/projects/test_app/__init__.py +++ b/tests/projects/test_app/__init__.py @@ -13,7 +13,6 @@ def __init__(self): async def init(self): await super().init() - self.http_client_factory.request_engine_builder.kafka_producer = TestKafkaProducer() def application_config(self): diff --git a/tests/test_process_fork.py b/tests/test_process_fork.py index fd1876608..58f7702c4 100644 --- a/tests/test_process_fork.py +++ b/tests/test_process_fork.py @@ -1,8 +1,10 @@ import asyncio import contextlib +import pickle import time from ctypes import c_bool, c_int from multiprocessing import Lock, Queue, Value +from typing import Callable, Optional from http_client import options as http_client_options from http_client.balancing import Server, Upstream, UpstreamConfig @@ -10,7 +12,7 @@ import frontik.process from frontik.options import options from frontik.process import WorkerState, fork_workers -from frontik.service_discovery import UpstreamManager +from frontik.service_discovery import WorkerServiceDiscovery async def worker_teardown(worker_exit_event): @@ -40,6 +42,10 @@ def prepare_upstreams(): } +def get_upstream_bytes(service_discovery): + return pickle.dumps(list(service_discovery.get_upstreams_unsafe().values())) + + def noop(*_args, **__kwargs): pass @@ -61,39 +67,32 @@ def test_pipe_buffer_overflow(self): upstreams = prepare_upstreams() num_workers = 1 worker_state = WorkerState(Value(c_int, num_workers), Value(c_bool, 0), Lock()) - control_master_state = { - 'shared_data': None, - 'worker_state': worker_state, - 'upstream_cache': None, - 'master_func_done': Queue(1), - } - control_worker_state = {'shared_data': Queue(1), 'enable_listener': 
Queue(1)} - worker_exit_event = Queue(1) - - def master_function(shared_data, upstreams_lock, send_to_all_workers): - shared_data.update(upstreams) - upstream_cache = UpstreamManager(shared_data, None, upstreams_lock, send_to_all_workers, False, 'test') - control_master_state['shared_data'] = shared_data - control_master_state['upstream_cache'] = upstream_cache - control_master_state['master_func_done'].put(True) + service_discovery = WorkerServiceDiscovery({}) + send_updates_hook: Optional[Callable] = None + + worker_queues = {'shared_data': Queue(1), 'enable_listener': Queue(1), 'worker_exit_event': Queue(1)} + + def master_after_fork_action(hook): + nonlocal send_updates_hook + send_updates_hook = hook def worker_function(): - control_master_state['worker_state'].init_workers_count_down.value -= 1 + worker_state.init_workers_count_down.value -= 1 loop = asyncio.get_event_loop() - loop.create_task(worker_teardown(worker_exit_event)) + loop.create_task(worker_teardown(worker_queues['worker_exit_event'])) loop.run_forever() - def worker_listener_handler(shared_data): - if control_worker_state['shared_data'].full(): - control_worker_state['shared_data'].get() - control_worker_state['shared_data'].put(shared_data) + def worker_listener_handler(shared_upstreams): + if worker_queues['shared_data'].full(): + worker_queues['shared_data'].get() + worker_queues['shared_data'].put(shared_upstreams) # Case 1: no worker reads, make write overflow async def delayed_listener(*_args, **__kwargs): while True: await asyncio.sleep(0) with contextlib.suppress(Exception): - control_worker_state['enable_listener'].get_nowait() + worker_queues['enable_listener'].get_nowait() break await self._orig_listener(*_args, **__kwargs) @@ -102,38 +101,38 @@ async def delayed_listener(*_args, **__kwargs): fork_workers( worker_state=worker_state, num_workers=num_workers, - worker_function=worker_function, - master_function=master_function, + master_before_fork_action=lambda: ({}, Lock()), + master_after_fork_action=master_after_fork_action, master_before_shutdown_action=lambda: None, + worker_function=worker_function, worker_listener_handler=worker_listener_handler, ) - if not worker_state.is_master: + if not worker_state.is_master: # when worker stopes it should exit return - control_master_state['master_func_done'].get(timeout=1) + assert send_updates_hook is not None + service_discovery.get_upstreams_unsafe().update(upstreams) for _i in range(500): - control_master_state['upstream_cache'].send_updates() - - resend_dict = control_master_state['worker_state'].resend_dict + send_updates_hook(get_upstream_bytes(service_discovery)) - assert bool(resend_dict), 'resend dict should not be empty' + assert bool(worker_state.resend_dict), 'resend dict should not be empty' # Case 2: wake up worker listener, check shared data is correct - control_worker_state['enable_listener'].put(True) + worker_queues['enable_listener'].put(True) time.sleep(1) - worker_shared_data = control_worker_state['shared_data'].get(timeout=2) + worker_shared_data = worker_queues['shared_data'].get(timeout=2) assert len(worker_shared_data) == 1, 'upstreams size on master and worker should be the same' # Case 3: add new upstream, check worker get it - control_master_state['shared_data']['upstream2'] = Upstream( + service_discovery.get_upstreams_unsafe()['upstream2'] = Upstream( 'upstream2', {}, [Server('12.2.3.5', 'dest_host'), Server('12.22.3.5', 'dest_host')], ) - control_master_state['upstream_cache'].send_updates() - 
control_master_state['upstream_cache'].send_updates() + send_updates_hook(get_upstream_bytes(service_discovery)) + send_updates_hook(get_upstream_bytes(service_discovery)) time.sleep(1) - worker_shared_data = control_worker_state['shared_data'].get(timeout=2) + worker_shared_data = worker_queues['shared_data'].get(timeout=2) assert len(worker_shared_data) == 2, 'upstreams size on master and worker should be the same' - worker_exit_event.put(True) + worker_queues['worker_exit_event'].put(True) diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index b4111ef0a..3ee5aec35 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -80,10 +80,10 @@ async def get_page_b(handler=get_current_handler()): handler.json.put({}) -def make_otel_provider() -> TracerProvider: +def make_otel_provider(service_name: str) -> TracerProvider: resource = Resource( attributes={ - ResourceAttributes.SERVICE_NAME: options.service_name, # type: ignore + ResourceAttributes.SERVICE_NAME: service_name, ResourceAttributes.SERVICE_VERSION: '1.2.3', ResourceAttributes.HOST_NAME: options.node_name, ResourceAttributes.CLOUD_REGION: 'test', @@ -124,13 +124,14 @@ def frontik_app(self) -> FrontikApplication: options.opentelemetry_enabled = True options.opentelemetry_sampler_ratio = 1 + app = FrontikApplication() + test_exporter = TestExporter() - provider = make_otel_provider() + provider = make_otel_provider(app.app_name) batch_span_processor = BatchSpanProcessor(test_exporter) provider.add_span_processor(batch_span_processor) trace.set_tracer_provider(provider) - app = FrontikApplication() BATCH_SPAN_PROCESSOR.append(batch_span_processor) return app diff --git a/tests/test_upstream_caches.py b/tests/test_upstream_caches.py index 75254e393..c3958bfa9 100644 --- a/tests/test_upstream_caches.py +++ b/tests/test_upstream_caches.py @@ -1,8 +1,20 @@ +from threading import Lock +from typing import Callable, Optional + from http_client import options as http_client_options -from frontik.integrations.statsd import StatsDClientStub from frontik.options import options -from frontik.service_discovery import UpstreamManager +from frontik.service_discovery import MasterServiceDiscovery + + +class StubServiceDiscovery(MasterServiceDiscovery): + def __init__(self) -> None: + self._upstreams_config: dict = {} + self._upstreams_servers: dict = {} + + self._upstreams = {} + self._upstreams_lock = Lock() + self._send_to_all_workers: Optional[Callable] = None class TestUpstreamCaches: @@ -43,12 +55,12 @@ def test_update_upstreams_servers_different_dc(self) -> None: }, ] - upstream_cache = UpstreamManager({}, StatsDClientStub(), None, None, False, 'test') - upstream_cache._update_upstreams_service('app', value_one_dc) - upstream_cache._update_upstreams_service('app', value_another_dc) + service_discovery = StubServiceDiscovery() + service_discovery._update_upstreams_service('app', value_one_dc) + service_discovery._update_upstreams_service('app', value_another_dc) - assert len(upstream_cache._upstreams_servers) == 2 - assert len(upstream_cache._upstreams['app'].servers) == 2 + assert len(service_discovery._upstreams_servers) == 2 + assert len(service_discovery.get_upstreams_unsafe()['app'].servers) == 2 def test_update_upstreams_servers_same_dc(self) -> None: options.upstreams = ['app'] @@ -66,12 +78,12 @@ def test_update_upstreams_servers_same_dc(self) -> None: }, ] - upstream_cache = UpstreamManager({}, StatsDClientStub(), None, None, False, 'test') - upstream_cache._update_upstreams_service('app', value_one_dc) - 
upstream_cache._update_upstreams_service('app', value_one_dc) + service_discovery = StubServiceDiscovery() + service_discovery._update_upstreams_service('app', value_one_dc) + service_discovery._update_upstreams_service('app', value_one_dc) - assert len(upstream_cache._upstreams_servers) == 1 - assert len(upstream_cache._upstreams['app'].servers) == 1 + assert len(service_discovery._upstreams_servers) == 1 + assert len(service_discovery.get_upstreams_unsafe()['app'].servers) == 1 def test_multiple_update_upstreams_servers_different_dc(self) -> None: options.upstreams = ['app'] @@ -102,14 +114,14 @@ def test_multiple_update_upstreams_servers_different_dc(self) -> None: }, ] - upstream_cache = UpstreamManager({}, StatsDClientStub(), None, None, False, 'test') - upstream_cache._update_upstreams_service('app', value_one_dc) - upstream_cache._update_upstreams_service('app', value_another_dc) - upstream_cache._update_upstreams_service('app', value_another_dc) - upstream_cache._update_upstreams_service('app', value_one_dc) + service_discovery = StubServiceDiscovery() + service_discovery._update_upstreams_service('app', value_one_dc) + service_discovery._update_upstreams_service('app', value_another_dc) + service_discovery._update_upstreams_service('app', value_another_dc) + service_discovery._update_upstreams_service('app', value_one_dc) - assert len(upstream_cache._upstreams_servers) == 2 - assert len(upstream_cache._upstreams['app'].servers) == 2 + assert len(service_discovery._upstreams_servers) == 2 + assert len(service_discovery.get_upstreams_unsafe()['app'].servers) == 2 def test_remove_upstreams_servers_different_dc(self) -> None: options.upstreams = ['app'] @@ -163,17 +175,20 @@ def test_remove_upstreams_servers_different_dc(self) -> None: }, ] - upstream_cache = UpstreamManager({}, StatsDClientStub(), None, None, False, 'test') - upstream_cache._update_upstreams_service('app', value_test_dc) - upstream_cache._update_upstreams_service('app', value_another_dc) + service_discovery = StubServiceDiscovery() + service_discovery._update_upstreams_service('app', value_test_dc) + service_discovery._update_upstreams_service('app', value_another_dc) - assert len(upstream_cache._upstreams_servers['app-test']) == 1 - assert len(upstream_cache._upstreams_servers['app-another']) == 2 - assert len(upstream_cache._upstreams['app'].servers) == 3 + assert len(service_discovery._upstreams_servers['app-test']) == 1 + assert len(service_discovery._upstreams_servers['app-another']) == 2 + assert len(service_discovery.get_upstreams_unsafe()['app'].servers) == 3 - upstream_cache._update_upstreams_service('app', value_another_remove_service_dc) + service_discovery._update_upstreams_service('app', value_another_remove_service_dc) - assert len(upstream_cache._upstreams_servers['app-another']) == 1 - assert upstream_cache._upstreams_servers['app-another'][0].address == '2.2.2.2:9999' - assert len(upstream_cache._upstreams['app'].servers) == 3 - assert len([server for server in upstream_cache._upstreams['app'].servers if server is not None]) == 2 + assert len(service_discovery._upstreams_servers['app-another']) == 1 + assert service_discovery._upstreams_servers['app-another'][0].address == '2.2.2.2:9999' + assert len(service_discovery.get_upstreams_unsafe()['app'].servers) == 3 + assert ( + len([server for server in service_discovery.get_upstreams_unsafe()['app'].servers if server is not None]) + == 2 + )
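
A small end-to-end sketch of the serialization contract the reworked fork test exercises (the helper names here are hypothetical): the master side pickles a flat list of Upstream objects, exactly like get_upstream_bytes() above, and the worker side unpickles the payload read from its pipe before applying it, as _worker_listener_handler does via update_upstreams().

import pickle

from http_client.balancing import Server, Upstream


def serialize_upstreams(upstreams: dict[str, Upstream]) -> bytes:
    # same shape as get_upstream_bytes() in the test: a flat list of Upstream objects
    return pickle.dumps(list(upstreams.values()))


def apply_on_worker(payload: bytes, worker_upstreams: dict[str, Upstream]) -> None:
    # simplified stand-in for WorkerServiceDiscovery.update_upstreams(): the real
    # method shuffles new upstreams and merges existing ones instead of overwriting
    for upstream in pickle.loads(payload):
        worker_upstreams[upstream.name] = upstream


if __name__ == '__main__':
    master_upstreams = {'app': Upstream('app', {}, [Server('12.2.3.5', 'dest_host')])}
    worker_upstreams: dict[str, Upstream] = {}
    apply_on_worker(serialize_upstreams(master_upstreams), worker_upstreams)
    assert list(worker_upstreams) == ['app']
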