diff --git a/changelog.d/9958.feature b/changelog.d/9958.feature new file mode 100644 index 000000000000..d86ba36519f4 --- /dev/null +++ b/changelog.d/9958.feature @@ -0,0 +1 @@ +Reduce memory usage when joining very large rooms over federation. diff --git a/mypy.ini b/mypy.ini index ea655a0d4d92..1d1d1ea0f25f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -174,3 +174,6 @@ ignore_missing_imports = True [mypy-pympler.*] ignore_missing_imports = True + +[mypy-ijson.*] +ignore_missing_imports = True diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index a5b6a611952b..e0e9f5d0beeb 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -55,6 +55,7 @@ ) from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json +from synapse.federation.transport.client import SendJoinResponse from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.utils import log_function from synapse.types import JsonDict, get_domain_from_id @@ -665,19 +666,10 @@ async def send_join( """ async def send_request(destination) -> Dict[str, Any]: - content = await self._do_send_join(destination, pdu) + response = await self._do_send_join(room_version, destination, pdu) - logger.debug("Got content: %s", content) - - state = [ - event_from_pdu_json(p, room_version, outlier=True) - for p in content.get("state", []) - ] - - auth_chain = [ - event_from_pdu_json(p, room_version, outlier=True) - for p in content.get("auth_chain", []) - ] + state = response.state + auth_chain = response.auth_events pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)} @@ -752,11 +744,14 @@ async def send_request(destination) -> Dict[str, Any]: return await self._try_destination_list("send_join", destinations, send_request) - async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict: + async def _do_send_join( + self, room_version: RoomVersion, destination: str, pdu: EventBase + ) -> SendJoinResponse: time_now = self._clock.time_msec() try: return await self.transport_layer.send_join_v2( + room_version=room_version, destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, @@ -771,17 +766,14 @@ async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict: logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") - resp = await self.transport_layer.send_join_v1( + return await self.transport_layer.send_join_v1( + room_version=room_version, destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - # We expect the v1 API to respond with [200, content], so we only return the - # content. - return resp[1] - async def send_invite( self, destination: str, diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index ada322a81e14..2352d6d0f4f9 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -17,13 +17,19 @@ import urllib from typing import Any, Dict, List, Optional +import attr +import ijson + from synapse.api.constants import Membership from synapse.api.errors import Codes, HttpResponseException, SynapseError +from synapse.api.room_versions import RoomVersion from synapse.api.urls import ( FEDERATION_UNSTABLE_PREFIX, FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) +from synapse.events import EventBase, make_event_from_dict +from synapse.http.matrixfederationclient import ByteParser from synapse.logging.utils import log_function from synapse.types import JsonDict @@ -240,21 +246,36 @@ async def make_membership_event( return content @log_function - async def send_join_v1(self, destination, room_id, event_id, content): + async def send_join_v1( + self, + room_version, + destination, + room_id, + event_id, + content, + ) -> "SendJoinResponse": path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = await self.client.put_json( - destination=destination, path=path, data=content + destination=destination, + path=path, + data=content, + parser=SendJoinParser(room_version, v1_api=True), ) return response @log_function - async def send_join_v2(self, destination, room_id, event_id, content): + async def send_join_v2( + self, room_version, destination, room_id, event_id, content + ) -> "SendJoinResponse": path = _create_v2_path("/send_join/%s/%s", room_id, event_id) response = await self.client.put_json( - destination=destination, path=path, data=content + destination=destination, + path=path, + data=content, + parser=SendJoinParser(room_version, v1_api=False), ) return response @@ -1052,3 +1073,59 @@ def _create_v2_path(path, *args): str """ return _create_path(FEDERATION_V2_PREFIX, path, *args) + + +@attr.s(slots=True, auto_attribs=True) +class SendJoinResponse: + """The parsed response of a `/send_join` request.""" + + auth_events: List[EventBase] + state: List[EventBase] + + +@ijson.coroutine +def _event_list_parser(room_version: RoomVersion, events: List[EventBase]): + """Helper function for use with `ijson.items_coro` to parse an array of + events and add them to the given list. + """ + + while True: + obj = yield + event = make_event_from_dict(obj, room_version) + events.append(event) + + +class SendJoinParser(ByteParser[SendJoinResponse]): + """A parser for the response to `/send_join` requests. + + Args: + room_version: The version of the room. + v1_api: Whether the response is in the v1 format. + """ + + CONTENT_TYPE = "application/json" + + def __init__(self, room_version: RoomVersion, v1_api: bool): + self._response = SendJoinResponse([], []) + + # The V1 API has the shape of `[200, {...}]`, which we handle by + # prefixing with `item.*`. + prefix = "item." if v1_api else "" + + self._coro_state = ijson.items_coro( + _event_list_parser(room_version, self._response.state), + prefix + "state.item", + ) + self._coro_auth = ijson.items_coro( + _event_list_parser(room_version, self._response.auth_events), + prefix + "auth_chain.item", + ) + + def write(self, data: bytes) -> int: + self._coro_state.send(data) + self._coro_auth.send(data) + + return len(data) + + def finish(self) -> SendJoinResponse: + return self._response diff --git a/synapse/http/client.py b/synapse/http/client.py index 5f40f16e24d6..1ca6624fd5da 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -813,7 +813,12 @@ def dataReceived(self, data: bytes) -> None: if self.deferred.called: return - self.stream.write(data) + try: + self.stream.write(data) + except Exception: + self.deferred.errback() + return + self.length += len(data) # The first time the maximum size is exceeded, error and cancel the # connection. dataReceived might be called again if data was received diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index bb837b7b1979..f5503b394b37 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import cgi import codecs import logging @@ -19,13 +20,24 @@ import typing import urllib.parse from io import BytesIO, StringIO -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import ( + Callable, + Dict, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, + overload, +) import attr import treq from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json +from typing_extensions import Literal from twisted.internet import defer from twisted.internet.error import DNSLookupError @@ -48,6 +60,7 @@ BlacklistingAgentWrapper, BlacklistingReactorWrapper, BodyExceededMaxSize, + ByteWriteable, encode_query_args, read_body_with_max_size, ) @@ -88,6 +101,27 @@ QueryArgs = Dict[str, Union[str, List[str]]] +T = TypeVar("T") + + +class ByteParser(ByteWriteable, Generic[T], abc.ABC): + """A `ByteWriteable` that has an additional `finish` function that returns + the parsed data. + """ + + CONTENT_TYPE = abc.abstractproperty() # type: str # type: ignore + """The expected content type of the response, e.g. `application/json`. If + the content type doesn't match we fail the request. + """ + + @abc.abstractmethod + def finish(self) -> T: + """Called when response has finished streaming and the parser should + return the final result (or error). + """ + pass + + @attr.s(slots=True, frozen=True) class MatrixFederationRequest: method = attr.ib(type=str) @@ -148,15 +182,32 @@ def get_json(self) -> Optional[JsonDict]: return self.json -async def _handle_json_response( +class JsonParser(ByteParser[Union[JsonDict, list]]): + """A parser that buffers the response and tries to parse it as JSON.""" + + CONTENT_TYPE = "application/json" + + def __init__(self): + self._buffer = StringIO() + self._binary_wrapper = BinaryIOWrapper(self._buffer) + + def write(self, data: bytes) -> int: + return self._binary_wrapper.write(data) + + def finish(self) -> Union[JsonDict, list]: + return json_decoder.decode(self._buffer.getvalue()) + + +async def _handle_response( reactor: IReactorTime, timeout_sec: float, request: MatrixFederationRequest, response: IResponse, start_ms: int, -) -> JsonDict: + parser: ByteParser[T], +) -> T: """ - Reads the JSON body of a response, with a timeout + Reads the body of a response with a timeout and sends it to a parser Args: reactor: twisted reactor, for the timeout @@ -164,23 +215,21 @@ async def _handle_json_response( request: the request that triggered the response response: response to the request start_ms: Timestamp when request was made + parser: The parser for the response Returns: - The parsed JSON response + The parsed response """ + try: - check_content_type_is_json(response.headers) + check_content_type_is(response.headers, parser.CONTENT_TYPE) - buf = StringIO() - d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE) + d = read_body_with_max_size(response, parser, MAX_RESPONSE_SIZE) d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) - def parse(_len: int): - return json_decoder.decode(buf.getvalue()) - - d.addCallback(parse) + length = await make_deferred_yieldable(d) - body = await make_deferred_yieldable(d) + value = parser.finish() except BodyExceededMaxSize as e: # The response was too big. logger.warning( @@ -193,9 +242,9 @@ def parse(_len: int): ) raise RequestSendFailed(e, can_retry=False) from e except ValueError as e: - # The JSON content was invalid. + # The content was invalid. logger.warning( - "{%s} [%s] Failed to parse JSON response - %s %s", + "{%s} [%s] Failed to parse response - %s %s", request.txn_id, request.destination, request.method, @@ -225,16 +274,17 @@ def parse(_len: int): time_taken_secs = reactor.seconds() - start_ms / 1000 logger.info( - "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s", + "{%s} [%s] Completed request: %d %s in %.2f secs, got %d bytes - %s %s", request.txn_id, request.destination, response.code, response.phrase.decode("ascii", errors="replace"), time_taken_secs, + length, request.method, request.uri.decode("ascii"), ) - return body + return value class BinaryIOWrapper: @@ -671,6 +721,7 @@ def build_auth_headers( ) return auth_headers + @overload async def put_json( self, destination: str, @@ -683,7 +734,41 @@ async def put_json( ignore_backoff: bool = False, backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, + parser: Literal[None] = None, ) -> Union[JsonDict, list]: + ... + + @overload + async def put_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = None, + data: Optional[JsonDict] = None, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser[T]] = None, + ) -> T: + ... + + async def put_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = None, + data: Optional[JsonDict] = None, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser] = None, + ): """Sends the specified json data using PUT Args: @@ -716,6 +801,8 @@ async def put_json( of the request. Workaround for #3622 in Synapse <= v0.99.3. This will be attempted before backing off if backing off has been enabled. + parser: The parser to use to decode the response. Defaults to + parsing as JSON. Returns: Succeeds when we get a 2xx HTTP response. The @@ -756,8 +843,16 @@ async def put_json( else: _sec_timeout = self.default_timeout - body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms + if parser is None: + parser = JsonParser() + + body = await _handle_response( + self.reactor, + _sec_timeout, + request, + response, + start_ms, + parser=parser, ) return body @@ -830,12 +925,8 @@ async def post_json( else: _sec_timeout = self.default_timeout - body = await _handle_json_response( - self.reactor, - _sec_timeout, - request, - response, - start_ms, + body = await _handle_response( + self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() ) return body @@ -907,8 +998,8 @@ async def get_json( else: _sec_timeout = self.default_timeout - body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms + body = await _handle_response( + self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() ) return body @@ -975,8 +1066,8 @@ async def delete_json( else: _sec_timeout = self.default_timeout - body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms + body = await _handle_response( + self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() ) return body @@ -1068,16 +1159,16 @@ def _flatten_response_never_received(e): return repr(e) -def check_content_type_is_json(headers: Headers) -> None: +def check_content_type_is(headers: Headers, expected_content_type: str) -> None: """ Check that a set of HTTP headers have a Content-Type header, and that it - is application/json. + is the expected value.. Args: headers: headers to check Raises: - RequestSendFailed: if the Content-Type header is missing or isn't JSON + RequestSendFailed: if the Content-Type header is missing or doesn't match """ content_type_headers = headers.getRawHeaders(b"Content-Type") @@ -1089,11 +1180,10 @@ def check_content_type_is_json(headers: Headers) -> None: c_type = content_type_headers[0].decode("ascii") # only the first header val, options = cgi.parse_header(c_type) - if val != "application/json": + if val != expected_content_type: raise RequestSendFailed( RuntimeError( - "Remote server sent Content-Type header of '%s', not 'application/json'" - % c_type, + f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'", ), can_retry=False, ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 989523c82374..546231bec0c9 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -87,6 +87,7 @@ # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. "cryptography>=3.4.7", + "ijson>=3.0", ] CONDITIONAL_REQUIREMENTS = {