Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Implement account status endpoints (MSC3720) #12001

Merged
merged 17 commits into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12001.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints).
3 changes: 3 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,6 @@ def read_config(self, config: JsonDict, **kwargs):

# MSC3706 (server-side support for partial state in /send_join responses)
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)

# MSC3720 (Account status endpoint)
self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False)
49 changes: 49 additions & 0 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
RoomVersion,
RoomVersions,
)
from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.transport.client import SendJoinResponse
Expand Down Expand Up @@ -1526,6 +1527,54 @@ async def timestamp_to_event(
except ValueError as e:
raise InvalidResponseError(str(e))

async def get_account_status(
self, destination: str, user_ids: List[str]
) -> Tuple[JsonDict, List[str]]:
"""Retrieves account statuses for a given list of users on a given remote
homeserver.

If the request fails for any reason, all user IDs for this destination are marked
as failed.

Args:
destination: the destination to contact
user_ids: the user ID(s) for which to request account status(es)

Returns:
The account statuses, as well as the list of user IDs for which it was not
possible to retrieve a status.
"""
try:
res = await self.transport_layer.make_query(
destination=destination,
query_type="account_status",
args={"user_id": user_ids},
retry_on_dns_fail=True,
prefix=FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3720",
)
except Exception:
# If the query failed for any reason, mark all the users as failed.
return {}, user_ids

statuses = res.get("account_statuses", {})
failures = res.get("failures", [])

if not isinstance(statuses, dict) or not isinstance(failures, list):
# Make sure we're not feeding back malformed data back to the caller.
logger.warning(
"Destination %s responded with malformed data to account_status query",
destination,
)
return {}, user_ids
squahtx marked this conversation as resolved.
Show resolved Hide resolved

for user_id in user_ids:
# Any account whose status is missing is a user we failed to receive the
# status of.
if user_id not in statuses:
failures.append(user_id)
babolivier marked this conversation as resolved.
Show resolved Hide resolved

return statuses, failures


@attr.s(frozen=True, slots=True, auto_attribs=True)
class TimestampToEventResponse:
Expand Down
3 changes: 2 additions & 1 deletion synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,9 @@ async def make_query(
args: dict,
retry_on_dns_fail: bool,
ignore_backoff: bool = False,
prefix: str = FEDERATION_V1_PREFIX,
) -> JsonDict:
path = _create_v1_path("/query/%s", query_type)
path = _create_path(prefix, "/query/%s", query_type)

return await self.client.get_json(
destination=destination,
Expand Down
38 changes: 38 additions & 0 deletions synapse/federation/transport/server/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,43 @@ async def on_GET(
return 200, complexity


class AccountStatusServlet(BaseFederationServerServlet):
PATH = "/org.matrix.msc3720/query/account_status"
PREFIX = FEDERATION_UNSTABLE_PREFIX
babolivier marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self._account_handler = hs.get_account_handler()

async def on_GET(
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
) -> Tuple[int, JsonDict]:
# Handle MSC3720 account statuses requests.
# TODO: when the MSC has released into the spec, this handler should be moved
# to a query handler
if b"user_id" not in query:
raise SynapseError(
400, "Required parameter 'user_id' is missing", Codes.MISSING_PARAM
)

user_ids: List[bytes] = query[b"user_id"]
statuses, failures = await self._account_handler.get_account_statuses(
user_ids,
allow_remote=False,
)

return 200, {"account_statuses": statuses, "failures": failures}


FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationSendServlet,
FederationEventServlet,
Expand Down Expand Up @@ -797,4 +834,5 @@ async def on_GET(
FederationRoomHierarchyUnstableServlet,
FederationV1SendKnockServlet,
FederationMakeKnockServlet,
AccountStatusServlet,
babolivier marked this conversation as resolved.
Show resolved Hide resolved
)
145 changes: 145 additions & 0 deletions synapse/handlers/account.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, List, Tuple

from synapse.api.errors import Codes, SynapseError
from synapse.types import JsonDict, UserID

if TYPE_CHECKING:
from synapse.server import HomeServer


class AccountHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
self._is_mine = hs.is_mine
self._federation_client = hs.get_federation_client()

async def get_account_statuses(
self,
user_ids: List[bytes],
allow_remote: bool,
) -> Tuple[JsonDict, List[str]]:
"""Get account statuses for a list of user IDs.

If one or more account(s) belong to remote homeservers, retrieve their status(es)
over federation if allowed.

Args:
user_ids: The list of accounts to retrieve the status of.
allow_remote: Whether to try to retrieve the status of remote accounts, if
any.

Returns:
The account statuses as well as the list of users whose statuses could not be
retrieved.
babolivier marked this conversation as resolved.
Show resolved Hide resolved

Raises:
SynapseError if a required parameter is missing or malformed, or if one of
the accounts isn't local to this homeserver and allow_remote is False.
"""
statuses = {}
failures = []
remote_users: List[UserID] = []

for user_id_bytes in user_ids:
try:
raw_user_id = user_id_bytes.decode("ascii")
user_id = UserID.from_string(raw_user_id)
except (AttributeError, SynapseError):
babolivier marked this conversation as resolved.
Show resolved Hide resolved
raise SynapseError(
400,
f"Not a valid Matrix user ID: {user_id_bytes.decode('utf8')}",
babolivier marked this conversation as resolved.
Show resolved Hide resolved
Codes.INVALID_PARAM,
)

if self._is_mine(user_id):
status = await self._get_local_account_status(user_id)
statuses[user_id.to_string()] = status
squahtx marked this conversation as resolved.
Show resolved Hide resolved
else:
if not allow_remote:
raise SynapseError(
400,
f"Not a local user: {raw_user_id}",
Codes.INVALID_PARAM,
)

remote_users.append(user_id)

if allow_remote and len(remote_users) > 0:
remote_statuses, remote_failures = await self._get_remote_account_statuses(
remote_users,
)

statuses.update(remote_statuses)
failures += remote_failures

return statuses, failures

async def _get_local_account_status(self, user_id: UserID) -> JsonDict:
"""Retrieve the status of a local account.

Args:
user_id: The account to retrieve the status of.

Returns:
The account's status.
"""
status = {"exists": False}

userinfo = await self._store.get_userinfo_by_id(user_id.to_string())

if userinfo is not None:
status = {
"exists": True,
"deactivated": userinfo.is_deactivated,
}

return status

async def _get_remote_account_statuses(
self, remote_users: List[UserID]
) -> Tuple[JsonDict, List[str]]:
"""Send out federation requests to retrieve the statuses of remote accounts.

Args:
remote_users: The accounts to retrieve the statuses of.

Returns:
The statuses of the accounts, and a list of accounts for which no status
could be retrieved.
"""
# Group remote users by destination, so we only send one request per remote
# homeserver.
by_destination: Dict[str, List[str]] = {}
for user in remote_users:
if user.domain not in by_destination:
by_destination[user.domain] = []

by_destination[user.domain].append(user.to_string())

# Retrieve the statuses and failures for remote accounts.
final_statuses: JsonDict = {}
final_failures: List[str] = []
for destination, users in by_destination.items():
statuses, failures = await self._federation_client.get_account_status(
destination,
users,
)

final_statuses.update(statuses)
final_failures += failures
babolivier marked this conversation as resolved.
Show resolved Hide resolved

return final_statuses, final_failures
36 changes: 35 additions & 1 deletion synapse/rest/client/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from urllib.parse import urlparse

from twisted.web.server import Request
Expand Down Expand Up @@ -894,6 +894,37 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return 200, response


class AccountStatusRestServlet(RestServlet):
PATTERNS = client_patterns(
"/org.matrix.msc3720/account_status$", unstable=True, releases=()
)

def __init__(self, hs: "HomeServer"):
super().__init__()
self._auth = hs.get_auth()
self._store = hs.get_datastore()
self._is_mine = hs.is_mine
self._federation_client = hs.get_federation_client()
Comment on lines +907 to +909
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these things all seem to be unused. Please don't import things from HomeServer where they are not required - it increases the coupling of the code.

Fixed in #12067.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gah, sorry about that. It came from a previous version of that implem that I refactored before opening the PR, but it looks like I forgot to remove some bits. Thanks for taking care of it!

self._account_handler = hs.get_account_handler()

async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self._auth.get_user_by_req(request)

args: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
if b"user_id" not in args:
raise SynapseError(
400, "Required parameter 'user_id' is missing", Codes.MISSING_PARAM
)

user_ids: List[bytes] = args[b"user_id"]
babolivier marked this conversation as resolved.
Show resolved Hide resolved
statuses, failures = await self._account_handler.get_account_statuses(
user_ids,
allow_remote=True,
)

return 200, {"account_statuses": statuses, "failures": failures}


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
Expand All @@ -908,3 +939,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)

if hs.config.experimental.msc3720_enabled:
AccountStatusRestServlet(hs).register(http_server)
5 changes: 5 additions & 0 deletions synapse/rest/client/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.config.experimental.msc3440_enabled:
response["capabilities"]["io.element.thread"] = {"enabled": True}

if self.config.experimental.msc3720_enabled:
response["capabilities"]["org.matrix.msc3720.account_status"] = {
"enabled": True,
}

return HTTPStatus.OK, response


Expand Down
5 changes: 5 additions & 0 deletions synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
from synapse.handlers.account import AccountHandler
from synapse.handlers.account_data import AccountDataHandler
from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.admin import AdminHandler
Expand Down Expand Up @@ -807,6 +808,10 @@ def get_event_auth_handler(self) -> EventAuthHandler:
def get_external_cache(self) -> ExternalCache:
return ExternalCache(self)

@cache_in_self
def get_account_handler(self) -> AccountHandler:
return AccountHandler(self)

@cache_in_self
def get_outbound_redis_connection(self) -> "RedisProtocol":
"""
Expand Down
Loading