Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix cache call signature to accept on_invalidate. #8684

Merged
merged 6 commits into from
Oct 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8684.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix typing info on cache call signature to accept `on_invalidate`.
38 changes: 27 additions & 11 deletions scripts-dev/mypy_synapse_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@

from typing import Callable, Optional

from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType
from mypy.types import CallableType, NoneType


class SynapsePlugin(Plugin):
Expand All @@ -40,8 +41,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:

It already has *almost* the correct signature, except:

1. the `self` argument needs to be marked as "bound"; and
2. any `cache_context` argument should be removed.
1. the `self` argument needs to be marked as "bound";
2. any `cache_context` argument should be removed;
3. an optional keyword argument `on_invalidated` should be added.
"""

# First we mark this as a bound function signature.
Expand All @@ -58,19 +60,33 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
context_arg_index = idx
break

arg_types = list(signature.arg_types)
arg_names = list(signature.arg_names)
arg_kinds = list(signature.arg_kinds)

if context_arg_index:
arg_types = list(signature.arg_types)
arg_types.pop(context_arg_index)

arg_names = list(signature.arg_names)
arg_names.pop(context_arg_index)

arg_kinds = list(signature.arg_kinds)
arg_kinds.pop(context_arg_index)

signature = signature.copy_modified(
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
)
# Third, we add an optional "on_invalidate" argument.
#
# This is a callable which accepts no input and returns nothing.
calltyp = CallableType(
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
arg_types=[],
arg_kinds=[],
arg_names=[],
ret_type=NoneType(),
fallback=ctx.api.named_generic_type("builtins.function", []),
)

arg_types.append(calltyp)
arg_names.append("on_invalidate")
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.

signature = signature.copy_modified(
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
)

return signature

Expand Down
12 changes: 7 additions & 5 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

MYPY = False
if MYPY:
import synapse.server
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -101,7 +101,7 @@
class BasePresenceHandler(abc.ABC):
"""Parts of the PresenceHandler that are shared between workers and master"""

def __init__(self, hs: "synapse.server.HomeServer"):
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()

Expand Down Expand Up @@ -199,7 +199,7 @@ async def bump_presence_active_time(self, user: UserID):


class PresenceHandler(BasePresenceHandler):
def __init__(self, hs: "synapse.server.HomeServer"):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.is_mine_id = hs.is_mine_id
Expand Down Expand Up @@ -1011,7 +1011,7 @@ def format_user_presence_state(state, now, include_user_id=True):


class PresenceEventSource:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
# We can't call get_presence_handler here because there's a cycle:
#
# Presence -> Notifier -> PresenceEventSource -> Presence
Expand Down Expand Up @@ -1071,12 +1071,14 @@ async def get_new_events(

users_interested_in = await self._get_interested_in(user, explicit_room_id)

user_ids_changed = set()
user_ids_changed = set() # type: Collection[str]
changed = None
if from_key:
changed = stream_change_cache.get_all_entities_changed(from_key)

if changed is not None and len(changed) < 500:
assert isinstance(user_ids_changed, set)

# For small deltas, its quicker to get all changes and then
# work out if we share a room or they're in our presence list
get_updates_counter.labels("stream").inc()
Expand Down