diff --git a/changelog.d/11556.misc b/changelog.d/11556.misc new file mode 100644 index 000000000000..53b26aa676b4 --- /dev/null +++ b/changelog.d/11556.misc @@ -0,0 +1 @@ +Add missing type hints to `synapse.logging.context`. diff --git a/mypy.ini b/mypy.ini index 4551302c8292..186732204438 100644 --- a/mypy.ini +++ b/mypy.ini @@ -167,6 +167,9 @@ disallow_untyped_defs = True [mypy-synapse.http.server] disallow_untyped_defs = True +[mypy-synapse.logging.context] +disallow_untyped_defs = True + [mypy-synapse.metrics.*] disallow_untyped_defs = True diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index 4ff3c6de5feb..429234d7ae7f 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -17,11 +17,12 @@ from typing import Any, List, Optional, Type, Union from twisted.internet import protocol +from twisted.internet.defer import Deferred class RedisProtocol(protocol.Protocol): def publish(self, channel: str, message: bytes): ... - async def ping(self) -> None: ... - async def set( + def ping(self) -> "Deferred[None]": ... + def set( self, key: str, value: Any, @@ -29,8 +30,8 @@ class RedisProtocol(protocol.Protocol): pexpire: Optional[int] = None, only_if_not_exists: bool = False, only_if_exists: bool = False, - ) -> None: ... - async def get(self, key: str) -> Any: ... + ) -> "Deferred[None]": ... + def get(self, key: str) -> "Deferred[Any]": ... class SubscriberProtocol(RedisProtocol): def __init__(self, *args, **kwargs): ... diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8e37e76206ac..cf067b56c6b4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -30,7 +30,6 @@ from prometheus_client import Counter, Gauge, Histogram -from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure @@ -67,7 +66,7 @@ from synapse.storage.databases.main.lock import Lock from synapse.types import JsonDict, get_domain_from_id from synapse.util import glob_to_regex, json_decoder, unwrapFirstError -from synapse.util.async_helpers import Linearizer, concurrently_execute +from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_server_name @@ -360,13 +359,13 @@ async def _handle_incoming_transaction( # want to block things like to device messages from reaching clients # behind the potentially expensive handling of PDUs. pdu_results, _ = await make_deferred_yieldable( - defer.gatherResults( - [ + gather_results( + ( run_in_background( self._handle_pdus_in_txn, origin, transaction, request_time ), run_in_background(self._handle_edus_in_txn, origin, transaction), - ], + ), consumeErrors=True, ).addErrback(unwrapFirstError) ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1ea837d08211..26b8e3f43c40 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -360,31 +360,34 @@ async def try_backfill(domains: List[str]) -> bool: logger.debug("calling resolve_state_groups in _maybe_backfill") resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) - states = await make_deferred_yieldable( + states_list = await make_deferred_yieldable( defer.gatherResults( [resolve(room_id, [e]) for e in event_ids], consumeErrors=True ) ) - # dict[str, dict[tuple, str]], a map from event_id to state map of - # event_ids. - states = dict(zip(event_ids, [s.state for s in states])) + # A map from event_id to state map of event_ids. + state_ids: Dict[str, StateMap[str]] = dict( + zip(event_ids, [s.state for s in states_list]) + ) state_map = await self.store.get_events( - [e_id for ids in states.values() for e_id in ids.values()], + [e_id for ids in state_ids.values() for e_id in ids.values()], get_prev_content=False, ) - states = { + + # A map from event_id to state map of events. + state_events: Dict[str, StateMap[EventBase]] = { key: { k: state_map[e_id] for k, e_id in state_dict.items() if e_id in state_map } - for key, state_dict in states.items() + for key, state_dict in state_ids.items() } for e_id in event_ids: - likely_extremeties_domains = get_domains_from_state(states[e_id]) + likely_extremeties_domains = get_domains_from_state(state_events[e_id]) success = await try_backfill( [ diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 9cd21e7f2b3c..9ab723ff975e 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -13,21 +13,27 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple - -from twisted.internet import defer +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import SynapseError +from synapse.events import EventBase from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.handlers.receipts import ReceiptEventSource from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage.roommember import RoomsForUser from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID +from synapse.types import ( + JsonDict, + Requester, + RoomStreamToken, + StateMap, + StreamToken, + UserID, +) from synapse.util import unwrapFirstError -from synapse.util.async_helpers import concurrently_execute +from synapse.util.async_helpers import concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client @@ -190,14 +196,13 @@ async def handle_room(event: RoomsForUser) -> None: ) deferred_room_state = run_in_background( self.state_store.get_state_for_events, [event.event_id] - ) - deferred_room_state.addCallback( - lambda states: states[event.event_id] + ).addCallback( + lambda states: cast(StateMap[EventBase], states[event.event_id]) ) (messages, token), current_state = await make_deferred_yieldable( - defer.gatherResults( - [ + gather_results( + ( run_in_background( self.store.get_recent_events_for_room, event.room_id, @@ -205,7 +210,7 @@ async def handle_room(event: RoomsForUser) -> None: end_token=room_end_token, ), deferred_room_state, - ] + ) ) ).addErrback(unwrapFirstError) @@ -454,8 +459,8 @@ async def get_receipts() -> List[JsonDict]: return receipts presence, receipts, (messages, token) = await make_deferred_yieldable( - defer.gatherResults( - [ + gather_results( + ( run_in_background(get_presence), run_in_background(get_receipts), run_in_background( @@ -464,7 +469,7 @@ async def get_receipts() -> List[JsonDict]: limit=limit, end_token=now_token.room_key, ), - ], + ), consumeErrors=True, ).addErrback(unwrapFirstError) ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 38409fef38d9..5e3d3886eb1d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -21,7 +21,6 @@ from canonicaljson import encode_canonical_json -from twisted.internet import defer from twisted.internet.interfaces import IDelayedCall from synapse import event_auth @@ -57,7 +56,7 @@ from synapse.storage.state import StateFilter from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester from synapse.util import json_decoder, json_encoder, log_failure -from synapse.util.async_helpers import Linearizer, unwrapFirstError +from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client @@ -1168,9 +1167,9 @@ async def handle_new_client_event( # We now persist the event (and update the cache in parallel, since we # don't want to block on it). - result = await make_deferred_yieldable( - defer.gatherResults( - [ + result, _ = await make_deferred_yieldable( + gather_results( + ( run_in_background( self._persist_event, requester=requester, @@ -1182,12 +1181,12 @@ async def handle_new_client_event( run_in_background( self.cache_joined_hosts_for_event, event, context ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), - ], + ), consumeErrors=True, ) ).addErrback(unwrapFirstError) - return result[0] + return result async def _persist_event( self, diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 1238bfd28726..a8a520f80944 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -25,6 +25,7 @@ from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.interfaces import ( + IProtocol, IProtocolFactory, IReactorCore, IStreamClientEndpoint, @@ -309,12 +310,14 @@ def __init__( self._srv_resolver = srv_resolver - def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred: + def connect( + self, protocol_factory: IProtocolFactory + ) -> "defer.Deferred[IProtocol]": """Implements IStreamClientEndpoint interface""" return run_in_background(self._do_connect, protocol_factory) - async def _do_connect(self, protocol_factory: IProtocolFactory) -> None: + async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol: first_exception = None server_list = await self._resolve_server() diff --git a/synapse/logging/context.py b/synapse/logging/context.py index d8ae3188b7da..25e78cc82fcd 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -22,20 +22,33 @@ See doc/log_contexts.rst for details on how this works. """ -import inspect import logging import threading import typing import warnings -from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, +) import attr from typing_extensions import Literal from twisted.internet import defer, threads +from twisted.python.threadpool import ThreadPool if TYPE_CHECKING: from synapse.logging.scopecontextmanager import _LogContextScope + from synapse.types import ISynapseReactor logger = logging.getLogger(__name__) @@ -66,7 +79,7 @@ def get_thread_resource_usage() -> "Optional[resource.struct_rusage]": # a hook which can be set during testing to assert that we aren't abusing logcontexts. -def logcontext_error(msg: str): +def logcontext_error(msg: str) -> None: logger.warning(msg) @@ -223,22 +236,19 @@ def __init__(self) -> None: def __str__(self) -> str: return "sentinel" - def copy_to(self, record): - pass - - def start(self, rusage: "Optional[resource.struct_rusage]"): + def start(self, rusage: "Optional[resource.struct_rusage]") -> None: pass - def stop(self, rusage: "Optional[resource.struct_rusage]"): + def stop(self, rusage: "Optional[resource.struct_rusage]") -> None: pass - def add_database_transaction(self, duration_sec): + def add_database_transaction(self, duration_sec: float) -> None: pass - def add_database_scheduled(self, sched_sec): + def add_database_scheduled(self, sched_sec: float) -> None: pass - def record_event_fetch(self, event_count): + def record_event_fetch(self, event_count: int) -> None: pass def __bool__(self) -> Literal[False]: @@ -379,7 +389,12 @@ def __enter__(self) -> "LoggingContext": ) return self - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: """Restore the logging context in thread local storage to the state it was before this context was entered. Returns: @@ -399,17 +414,6 @@ def __exit__(self, type, value, traceback) -> None: # recorded against the correct metrics. self.finished = True - def copy_to(self, record) -> None: - """Copy logging fields from this context to a log record or - another LoggingContext - """ - - # we track the current request - record.request = self.request - - # we also track the current scope: - record.scope = self.scope - def start(self, rusage: "Optional[resource.struct_rusage]") -> None: """ Record that this logcontext is currently running. @@ -626,7 +630,12 @@ def __init__( def __enter__(self) -> None: self._old_context = set_current_context(self._new_context) - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: context = set_current_context(self._old_context) if context != self._new_context: @@ -711,16 +720,61 @@ def nested_logging_context(suffix: str) -> LoggingContext: ) -def preserve_fn(f): +R = TypeVar("R") + + +@overload +def preserve_fn( # type: ignore[misc] + f: Callable[..., Awaitable[R]], +) -> Callable[..., "defer.Deferred[R]"]: + # The `type: ignore[misc]` above suppresses + # "Overloaded function signatures 1 and 2 overlap with incompatible return types" + ... + + +@overload +def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]: + ... + + +def preserve_fn( + f: Union[ + Callable[..., R], + Callable[..., Awaitable[R]], + ] +) -> Callable[..., "defer.Deferred[R]"]: """Function decorator which wraps the function with run_in_background""" - def g(*args, **kwargs): + def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]": return run_in_background(f, *args, **kwargs) return g -def run_in_background(f, *args, **kwargs) -> defer.Deferred: +@overload +def run_in_background( # type: ignore[misc] + f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any +) -> "defer.Deferred[R]": + # The `type: ignore[misc]` above suppresses + # "Overloaded function signatures 1 and 2 overlap with incompatible return types" + ... + + +@overload +def run_in_background( + f: Callable[..., R], *args: Any, **kwargs: Any +) -> "defer.Deferred[R]": + ... + + +def run_in_background( + f: Union[ + Callable[..., R], + Callable[..., Awaitable[R]], + ], + *args: Any, + **kwargs: Any, +) -> "defer.Deferred[R]": """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the deferred returned by the function completes. @@ -751,6 +805,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: # At this point we should have a Deferred, if not then f was a synchronous # function, wrap it in a Deferred for consistency. if not isinstance(res, defer.Deferred): + # `res` is not a `Deferred` and not a `Coroutine`. + # There are no other types of `Awaitable`s we expect to encounter in Synapse. + assert not isinstance(res, Awaitable) + return defer.succeed(res) if res.called and not res.paused: @@ -778,13 +836,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: return res -def make_deferred_yieldable(deferred): - """Given a deferred (or coroutine), make it follow the Synapse logcontext - rules: +T = TypeVar("T") + - If the deferred has completed (or is not actually a Deferred), essentially - does nothing (just returns another completed deferred with the - result/failure). +def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Given a deferred, make it follow the Synapse logcontext rules: + + If the deferred has completed, essentially does nothing (just returns another + completed deferred with the result/failure). If the deferred has not yet completed, resets the logcontext before returning a deferred. Then, when the deferred completes, restores the @@ -792,16 +851,6 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ - if inspect.isawaitable(deferred): - # If we're given a coroutine we convert it to a deferred so that we - # run it and find out if it immediately finishes, it it does then we - # don't need to fiddle with log contexts at all and can return - # immediately. - deferred = defer.ensureDeferred(deferred) - - if not isinstance(deferred, defer.Deferred): - return deferred - if deferred.called and not deferred.paused: # it looks like this deferred is ready to run any callbacks we give it # immediately. We may as well optimise out the logcontext faffery. @@ -823,7 +872,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: return result -def defer_to_thread(reactor, f, *args, **kwargs): +def defer_to_thread( + reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any +) -> "defer.Deferred[R]": """ Calls the function `f` using a thread from the reactor's default threadpool and returns the result as a Deferred. @@ -855,7 +906,13 @@ def defer_to_thread(reactor, f, *args, **kwargs): return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs) -def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): +def defer_to_threadpool( + reactor: "ISynapseReactor", + threadpool: ThreadPool, + f: Callable[..., R], + *args: Any, + **kwargs: Any, +) -> "defer.Deferred[R]": """ A wrapper for twisted.internet.threads.deferToThreadpool, which handles logcontexts correctly. @@ -897,7 +954,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): assert isinstance(curr_context, LoggingContext) parent_context = curr_context - def g(): + def g() -> R: with LoggingContext(str(curr_context), parent_context=parent_context): return f(*args, **kwargs) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 20ce294209ad..bde99ea8787b 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -30,9 +30,11 @@ Iterator, Optional, Set, + Tuple, TypeVar, Union, cast, + overload, ) import attr @@ -234,6 +236,59 @@ def yieldable_gather_results( ).addErrback(unwrapFirstError) +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") + + +@overload +def gather_results( + deferredList: Tuple[()], consumeErrors: bool = ... +) -> "defer.Deferred[Tuple[()]]": + ... + + +@overload +def gather_results( + deferredList: Tuple["defer.Deferred[T1]"], + consumeErrors: bool = ..., +) -> "defer.Deferred[Tuple[T1]]": + ... + + +@overload +def gather_results( + deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"], + consumeErrors: bool = ..., +) -> "defer.Deferred[Tuple[T1, T2]]": + ... + + +@overload +def gather_results( + deferredList: Tuple[ + "defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]" + ], + consumeErrors: bool = ..., +) -> "defer.Deferred[Tuple[T1, T2, T3]]": + ... + + +def gather_results( # type: ignore[misc] + deferredList: Tuple["defer.Deferred[T1]", ...], + consumeErrors: bool = False, +) -> "defer.Deferred[Tuple[T1, ...]]": + """Combines a tuple of `Deferred`s into a single `Deferred`. + + Wraps `defer.gatherResults` to provide type annotations that support heterogenous + lists of `Deferred`s. + """ + # The `type: ignore[misc]` above suppresses + # "Overloaded function implementation cannot produce return type of signature 1/2/3" + deferred = defer.gatherResults(deferredList, consumeErrors=consumeErrors) + return deferred.addCallback(tuple) + + @attr.s(slots=True) class _LinearizerEntry: # The number of things executing. @@ -352,7 +407,7 @@ def _await_lock(self, key: Hashable) -> defer.Deferred: logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) - new_defer = make_deferred_yieldable(defer.Deferred()) + new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred()) entry.deferreds[new_defer] = 1 def cb(_r: None) -> "defer.Deferred[None]": diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py index 470f4f91a59b..e325f44da328 100644 --- a/synapse/util/caches/cached_call.py +++ b/synapse/util/caches/cached_call.py @@ -76,6 +76,7 @@ async def get(self) -> TV: # Fire off the callable now if this is our first time if not self._deferred: + assert self._callable is not None self._deferred = run_in_background(self._callable) # we will never need the callable again, so make sure it can be GCed diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index de2adacd70dc..46771a401b50 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -142,6 +142,7 @@ def _writer(self) -> None: def wait(self) -> "Deferred[None]": """Returns a deferred that resolves when finished writing to file""" + assert self._finished_deferred is not None return make_deferred_yieldable(self._finished_deferred) def _resume_paused_producer(self) -> None: diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 5d9c4665aa58..621b0f9fcdf0 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -152,46 +152,11 @@ def test_make_deferred_yieldable_with_chained_deferreds(self): # now it should be restored self._check_test_key("one") - @defer.inlineCallbacks - def test_make_deferred_yieldable_on_non_deferred(self): - """Check that make_deferred_yieldable does the right thing when its - argument isn't actually a deferred""" - - with LoggingContext("one"): - d1 = make_deferred_yieldable("bum") - self._check_test_key("one") - - r = yield d1 - self.assertEqual(r, "bum") - self._check_test_key("one") - def test_nested_logging_context(self): with LoggingContext("foo"): nested_context = nested_logging_context(suffix="bar") self.assertEqual(nested_context.name, "foo-bar") - @defer.inlineCallbacks - def test_make_deferred_yieldable_with_await(self): - # an async function which returns an incomplete coroutine, but doesn't - # follow the synapse rules. - - async def blocking_function(): - d = defer.Deferred() - reactor.callLater(0, d.callback, None) - await d - - sentinel_context = current_context() - - with LoggingContext("one"): - d1 = make_deferred_yieldable(blocking_function()) - # make sure that the context was reset by make_deferred_yieldable - self.assertIs(current_context(), sentinel_context) - - yield d1 - - # now it should be restored - self._check_test_key("one") - # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on