diff --git a/changelog.d/9311.feature b/changelog.d/9311.feature new file mode 100644 index 000000000000..293f2118e5f8 --- /dev/null +++ b/changelog.d/9311.feature @@ -0,0 +1 @@ +Add hook to spam checker modules that allow checking file uploads and remote downloads. diff --git a/docs/spam_checker.md b/docs/spam_checker.md index 5b4f6428e658..47a27bf85c58 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md @@ -61,6 +61,9 @@ class ExampleSpamChecker: async def check_registration_for_spam(self, email_threepid, username, request_info): return RegistrationBehaviour.ALLOW # allow all registrations + + async def check_media_file_for_spam(self, file_wrapper, file_info): + return False # allow all media ``` ## Configuration diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index e7e3a7b9a45a..8cfc0bb3cbb1 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -17,6 +17,8 @@ import inspect from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from synapse.rest.media.v1._base import FileInfo +from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import Collection from synapse.util.async_helpers import maybe_awaitable @@ -214,3 +216,48 @@ async def check_registration_for_spam( return behaviour return RegistrationBehaviour.ALLOW + + async def check_media_file_for_spam( + self, file_wrapper: ReadableFileWrapper, file_info: FileInfo + ) -> bool: + """Checks if a piece of newly uploaded media should be blocked. + + This will be called for local uploads, downloads of remote media, each + thumbnail generated for those, and web pages/images used for URL + previews. + + Note that care should be taken to not do blocking IO operations in the + main thread. For example, to get the contents of a file a module + should do:: + + async def check_media_file_for_spam( + self, file: ReadableFileWrapper, file_info: FileInfo + ) -> bool: + buffer = BytesIO() + await file.write_chunks_to(buffer.write) + + if buffer.getvalue() == b"Hello World": + return True + + return False + + + Args: + file: An object that allows reading the contents of the media. + file_info: Metadata about the file. + + Returns: + True if the media should be blocked or False if it should be + allowed. + """ + + for spam_checker in self.spam_checkers: + # For backwards compatibility, only run if the method exists on the + # spam checker + checker = getattr(spam_checker, "check_media_file_for_spam", None) + if checker: + spam = await maybe_awaitable(checker(file_wrapper, file_info)) + if spam: + return True + + return False diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 89cdd605aa91..aba6d689a800 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -16,13 +16,17 @@ import logging import os import shutil -from typing import IO, TYPE_CHECKING, Any, Optional, Sequence +from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence + +import attr from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender +from synapse.api.errors import NotFoundError from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer from ._base import FileInfo, Responder @@ -58,6 +62,8 @@ def __init__( self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers + self.spam_checker = hs.get_spam_checker() + self.clock = hs.get_clock() async def store_file(self, source: IO, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other @@ -127,18 +133,29 @@ async def finish(): f.flush() f.close() + spam = await self.spam_checker.check_media_file_for_spam( + ReadableFileWrapper(self.clock, fname), file_info + ) + if spam: + logger.info("Blocking media due to spam checker") + # Note that we'll delete the stored media, due to the + # try/except below. The media also won't be stored in + # the DB. + raise SpamMediaException() + for provider in self.storage_providers: await provider.store_file(path, file_info) finished_called[0] = True yield f, fname, finish - except Exception: + except Exception as e: try: os.remove(fname) except Exception: pass - raise + + raise e from None if not finished_called: raise Exception("Finished callback not called") @@ -302,3 +319,39 @@ def write_to_consumer(self, consumer: IConsumer) -> Deferred: def __exit__(self, exc_type, exc_val, exc_tb): self.open_file.close() + + +class SpamMediaException(NotFoundError): + """The media was blocked by a spam checker, so we simply 404 the request (in + the same way as if it was quarantined). + """ + + +@attr.s(slots=True) +class ReadableFileWrapper: + """Wrapper that allows reading a file in chunks, yielding to the reactor, + and writing to a callback. + + This is simplified `FileSender` that takes an IO object rather than an + `IConsumer`. + """ + + CHUNK_SIZE = 2 ** 14 + + clock = attr.ib(type=Clock) + path = attr.ib(type=str) + + async def write_chunks_to(self, callback: Callable[[bytes], None]): + """Reads the file in chunks and calls the callback with each chunk. + """ + + with open(self.path, "rb") as file: + while True: + chunk = file.read(self.CHUNK_SIZE) + if not chunk: + break + + callback(chunk) + + # We yield to the reactor by sleeping for 0 seconds. + await self.clock.sleep(0) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 6da76ae994af..1136277794a1 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -22,6 +22,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_string +from synapse.rest.media.v1.media_storage import SpamMediaException if TYPE_CHECKING: from synapse.app.homeserver import HomeServer @@ -86,9 +87,14 @@ async def _async_render_POST(self, request: Request) -> None: # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion - content_uri = await self.media_repo.create_content( - media_type, upload_name, request.content, content_length, requester.user - ) + try: + content_uri = await self.media_repo.create_content( + media_type, upload_name, request.content, content_length, requester.user + ) + except SpamMediaException: + # For uploading of media we want to respond with a 400, instead of + # the default 404, as that would just be confusing. + raise SynapseError(400, "Bad content") logger.info("Uploaded content with URI %r", content_uri) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index a6c6985173a6..c279eb49e366 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -30,6 +30,8 @@ from twisted.internet.defer import Deferred from synapse.logging.context import make_deferred_yieldable +from synapse.rest import admin +from synapse.rest.client.v1 import login from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage @@ -37,6 +39,7 @@ from tests import unittest from tests.server import FakeSite, make_request +from tests.utils import default_config class MediaStorageTests(unittest.HomeserverTestCase): @@ -398,3 +401,94 @@ def test_x_robots_tag_header(self): headers.getRawHeaders(b"X-Robots-Tag"), [b"noindex, nofollow, noarchive, noimageindex"], ) + + +class TestSpamChecker: + """A spam checker module that rejects all media that includes the bytes + `evil`. + """ + + def __init__(self, config, api): + self.config = config + self.api = api + + def parse_config(config): + return config + + async def check_event_for_spam(self, foo): + return False # allow all events + + async def user_may_invite(self, inviter_userid, invitee_userid, room_id): + return True # allow all invites + + async def user_may_create_room(self, userid): + return True # allow all room creations + + async def user_may_create_room_alias(self, userid, room_alias): + return True # allow all room aliases + + async def user_may_publish_room(self, userid, room_id): + return True # allow publishing of all rooms + + async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool: + buf = BytesIO() + await file_wrapper.write_chunks_to(buf.write) + + return b"evil" in buf.getvalue() + + +class SpamCheckerTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + admin.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + # Allow for uploading and downloading to/from the media repo + self.media_repo = hs.get_media_repository_resource() + self.download_resource = self.media_repo.children[b"download"] + self.upload_resource = self.media_repo.children[b"upload"] + + def default_config(self): + config = default_config("test") + + config.update( + { + "spam_checker": [ + { + "module": TestSpamChecker.__module__ + ".TestSpamChecker", + "config": {}, + } + ] + } + ) + + return config + + def test_upload_innocent(self): + """Attempt to upload some innocent data that should be allowed. + """ + + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + self.helper.upload_media( + self.upload_resource, image_data, tok=self.tok, expect_code=200 + ) + + def test_upload_ban(self): + """Attempt to upload some data that includes bytes "evil", which should + get rejected by the spam checker. + """ + + data = b"Some evil data" + + self.helper.upload_media( + self.upload_resource, data, tok=self.tok, expect_code=400 + )