From af9bc17be18dcd5bb1440fca4855fbfcde1199c0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Jul 2020 13:48:08 -0400 Subject: [PATCH 1/7] Convert the StorageProvider to async/await. --- synapse/rest/media/v1/storage_provider.py | 56 ++++++++++------------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 858680be266c..56e7f36a8720 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -16,62 +16,56 @@ import logging import os import shutil - -from twisted.internet import defer +from typing import Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background +from ._base import FileInfo, Responder from .media_storage import FileResponder logger = logging.getLogger(__name__) -class StorageProvider(object): +class StorageProvider: """A storage provider is a service that can store uploaded media and retrieve them. """ - def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo): """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. Args: - path (str): Relative path of file in local cache - file_info (FileInfo) - - Returns: - Deferred + path: Relative path of file in local cache + file_info: The metadata of the file. """ - pass - def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """Attempt to fetch the file described by file_info and stream it into writer. Args: - path (str): Relative path of file in local cache - file_info (FileInfo) + path: Relative path of file in local cache + file_info: The metadata of the file. Returns: - Deferred(Responder): Returns a Responder if the provider has the file, - otherwise returns None. + Returns a Responder if the provider has the file, otherwise returns None. """ - pass class StorageProviderWrapper(StorageProvider): """Wraps a storage provider and provides various config options Args: - backend (StorageProvider) - store_local (bool): Whether to store new local files or not. - store_synchronous (bool): Whether to wait for file to be successfully + backend: The storage provider to wrap. + store_local: Whether to store new local files or not. + store_synchronous: Whether to wait for file to be successfully uploaded, or todo the upload in the background. - store_remote (bool): Whether remote media should be uploaded + store_remote: Whether remote media should be uploaded """ - def __init__(self, backend, store_local, store_synchronous, store_remote): + def __init__(self, backend: StorageProvider, store_local: bool, store_synchronous: bool, store_remote: bool): self.backend = backend self.store_local = store_local self.store_synchronous = store_synchronous @@ -80,15 +74,15 @@ def __init__(self, backend, store_local, store_synchronous, store_remote): def __str__(self): return "StorageProviderWrapper[%s]" % (self.backend,) - def store_file(self, path, file_info): + async def store_file(self, path, file_info): if not file_info.server_name and not self.store_local: - return defer.succeed(None) + return None if file_info.server_name and not self.store_remote: - return defer.succeed(None) + return None if self.store_synchronous: - return self.backend.store_file(path, file_info) + return await self.backend.store_file(path, file_info) else: # TODO: Handle errors. def store(): @@ -98,10 +92,10 @@ def store(): logger.exception("Error storing file") run_in_background(store) - return defer.succeed(None) + return None - def fetch(self, path, file_info): - return self.backend.fetch(path, file_info) + async def fetch(self, path, file_info): + return await self.backend.fetch(path, file_info) class FileStorageProviderBackend(StorageProvider): @@ -120,7 +114,7 @@ def __init__(self, hs, config): def __str__(self): return "FileStorageProviderBackend[%s]" % (self.base_directory,) - def store_file(self, path, file_info): + async def store_file(self, path, file_info): """See StorageProvider.store_file""" primary_fname = os.path.join(self.cache_directory, path) @@ -130,11 +124,11 @@ def store_file(self, path, file_info): if not os.path.exists(dirname): os.makedirs(dirname) - return defer_to_thread( + return await defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) - def fetch(self, path, file_info): + async def fetch(self, path, file_info): """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path) From a57f88af3aad73000bb499ef2331941ab1fbbbe7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 24 Jul 2020 11:14:47 -0400 Subject: [PATCH 2/7] Remove other references to Deferreds in media code. --- synapse/rest/media/v1/_base.py | 8 +- synapse/rest/media/v1/media_repository.py | 105 ++++++++++-------- synapse/rest/media/v1/media_storage.py | 50 +++++---- synapse/rest/media/v1/preview_url_resource.py | 10 +- synapse/rest/media/v1/storage_provider.py | 8 +- 5 files changed, 105 insertions(+), 76 deletions(-) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 9a847130c0c9..20ddb9550b29 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -17,7 +17,9 @@ import logging import os import urllib +from typing import Awaitable +from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.api.errors import Codes, SynapseError, cs_error @@ -240,14 +242,14 @@ class Responder(object): held can be cleaned up. """ - def write_to_consumer(self, consumer): + def write_to_consumer(self, consumer: IConsumer) -> Awaitable: """Stream response into consumer Args: - consumer (IConsumer) + consumer: The consumer to stream into. Returns: - Deferred: Resolves once the response has finished being written + Resolves once the response has finished being written """ pass diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 45628c07b401..6fb4039e9877 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -18,10 +18,11 @@ import logging import os import shutil -from typing import Dict, Tuple +from typing import IO, Dict, Optional, Tuple import twisted.internet.error import twisted.web.http +from twisted.web.http import Request from twisted.web.resource import Resource from synapse.api.errors import ( @@ -40,6 +41,7 @@ from ._base import ( FileInfo, + Responder, get_filename_from_headers, respond_404, respond_with_responder, @@ -135,19 +137,24 @@ def mark_recently_accessed(self, server_name, media_id): self.recently_accessed_locals.add(media_id) async def create_content( - self, media_type, upload_name, content, content_length, auth_user - ): + self, + media_type: str, + upload_name: str, + content: IO, + content_length: int, + auth_user: str, + ) -> str: """Store uploaded content for a local user and return the mxc URL Args: - media_type(str): The content type of the file - upload_name(str): The name of the file + media_type: The content type of the file + upload_name: The name of the file content: A file like object that is the content to store - content_length(int): The length of the content - auth_user(str): The user_id of the uploader + content_length: The length of the content + auth_user: The user_id of the uploader Returns: - Deferred[str]: The mxc url of the stored content + The mxc url of the stored content """ media_id = random_string(24) @@ -170,19 +177,20 @@ async def create_content( return "mxc://%s/%s" % (self.server_name, media_id) - async def get_local_media(self, request, media_id, name): + async def get_local_media( + self, request: Request, media_id: str, name: Optional[str] + ) -> None: """Responds to reqests for local media, if exists, or returns 404. Args: - request(twisted.web.http.Request) - media_id (str): The media ID of the content. (This is the same as + request: The incoming request. + media_id: The media ID of the content. (This is the same as the file_id for local content.) - name (str|None): Optional name that, if specified, will be used as + name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. Returns: - Deferred: Resolves once a response has successfully been written - to request + Resolves once a response has successfully been written to request """ media_info = await self.store.get_local_media(media_id) if not media_info or media_info["quarantined_by"]: @@ -203,20 +211,20 @@ async def get_local_media(self, request, media_id, name): request, responder, media_type, media_length, upload_name ) - async def get_remote_media(self, request, server_name, media_id, name): + async def get_remote_media( + self, request: Request, server_name: str, media_id: str, name: Optional[str] + ) -> None: """Respond to requests for remote media. Args: - request(twisted.web.http.Request) - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the - remote server). - name (str|None): Optional name that, if specified, will be used as + request: The incoming request. + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). + name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. Returns: - Deferred: Resolves once a response has successfully been written - to request + Resolves once a response has successfully been written to request """ if ( self.federation_domain_whitelist is not None @@ -245,17 +253,16 @@ async def get_remote_media(self, request, server_name, media_id, name): else: respond_404(request) - async def get_remote_media_info(self, server_name, media_id): + async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: """Gets the media info associated with the remote file, downloading if necessary. Args: - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the - remote server). + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). Returns: - Deferred[dict]: The media_info of the file + The media info of the file """ if ( self.federation_domain_whitelist is not None @@ -278,7 +285,9 @@ async def get_remote_media_info(self, server_name, media_id): return media_info - async def _get_remote_media_impl(self, server_name, media_id): + async def _get_remote_media_impl( + self, server_name: str, media_id: str + ) -> Tuple[Optional[Responder], dict]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -288,7 +297,7 @@ async def _get_remote_media_impl(self, server_name, media_id): remote server). Returns: - Deferred[(Responder, media_info)] + A tuple of responder and the media info of the file. """ media_info = await self.store.get_cached_remote_media(server_name, media_id) @@ -319,19 +328,21 @@ async def _get_remote_media_impl(self, server_name, media_id): responder = await self.media_storage.fetch_media(file_info) return responder, media_info - async def _download_remote_file(self, server_name, media_id, file_id): + async def _download_remote_file( + self, server_name: str, media_id: str, file_id: str + ) -> dict: """Attempt to download the remote file from the given server name, using the given file_id as the local id. Args: - server_name (str): Originating server - media_id (str): The media ID of the content (as defined by the + server_name: Originating server + media_id: The media ID of the content (as defined by the remote server). This is different than the file_id, which is locally generated. - file_id (str): Local file ID + file_id: Local file ID Returns: - Deferred[MediaInfo] + The media info of the file. """ file_info = FileInfo(server_name=server_name, file_id=file_id) @@ -549,25 +560,31 @@ async def generate_remote_exact_thumbnail( return output_path async def _generate_thumbnails( - self, server_name, media_id, file_id, media_type, url_cache=False - ): + self, + server_name: Optional[str], + media_id: str, + file_id: str, + media_type: str, + url_cache: bool = False, + ) -> Optional[dict]: """Generate and store thumbnails for an image. Args: - server_name (str|None): The server name if remote media, else None if local - media_id (str): The media ID of the content. (This is the same as + server_name: The server name if remote media, else None if local + media_id: The media ID of the content. (This is the same as the file_id for local content) - file_id (str): Local file ID - media_type (str): The content type of the file - url_cache (bool): If we are thumbnailing images downloaded for the URL cache, + file_id: Local file ID + media_type: The content type of the file + url_cache: If we are thumbnailing images downloaded for the URL cache, used exclusively by the url previewer Returns: - Deferred[dict]: Dict with "width" and "height" keys of original image + Dict with "width" and "height" keys of original image or None if the + media cannot be thumbnailed. """ requirements = self._get_thumbnail_requirements(media_type) if not requirements: - return + return None input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) @@ -584,7 +601,7 @@ async def _generate_thumbnails( m_height, self.max_image_pixels, ) - return + return None if thumbnailer.transpose_method is not None: m_width, m_height = await defer_to_thread( diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 66bc1c336088..bf606d741d2c 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -18,7 +18,7 @@ import logging import os import shutil -from typing import Optional +from typing import TYPE_CHECKING, IO, Sequence, Optional from twisted.protocols.basic import FileSender @@ -26,6 +26,11 @@ from synapse.util.file_consumer import BackgroundFileConsumer from ._base import FileInfo, Responder +from .filepath import MediaFilePaths + +if TYPE_CHECKING: + from synapse.server import HomeServer + from .storage_provider import StorageProvider logger = logging.getLogger(__name__) @@ -34,20 +39,25 @@ class MediaStorage(object): """Responsible for storing/fetching files from local sources. Args: - hs (synapse.server.Homeserver) - local_media_directory (str): Base path where we store media on disk - filepaths (MediaFilePaths) - storage_providers ([StorageProvider]): List of StorageProvider that are - used to fetch and store files. + hs + local_media_directory: Base path where we store media on disk + filepaths + storage_providers: List of StorageProvider that are used to fetch and store files. """ - def __init__(self, hs, local_media_directory, filepaths, storage_providers): + def __init__( + self, + hs: "HomeServer", + local_media_directory: str, + filepaths: MediaFilePaths, + storage_providers: Sequence["StorageProvider"], + ): self.hs = hs self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers - async def store_file(self, source, file_info: FileInfo) -> str: + async def store_file(self, source: IO, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other configured storage providers @@ -69,7 +79,7 @@ async def store_file(self, source, file_info: FileInfo) -> str: return fname @contextlib.contextmanager - def store_into_file(self, file_info): + def store_into_file(self, file_info: FileInfo): """Context manager used to get a file like object to write into, as described by file_info. @@ -85,7 +95,7 @@ def store_into_file(self, file_info): error. Args: - file_info (FileInfo): Info about the file to store + file_info: Info about the file to store Example: @@ -147,10 +157,10 @@ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: # Fetch is supposed to return an Awaitable, but guard against # improper implementations. if inspect.isawaitable(res): - res = await res + res = await res # type: ignore if res: logger.debug("Streaming %s from %s", path, provider) - return res + return res # type: ignore return None @@ -178,29 +188,23 @@ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: # Fetch is supposed to return an Awaitable, but guard against # improper implementations. if inspect.isawaitable(res): - res = await res - if res: - with res: + res = await res # type: ignore + if res: # type: ignore + with res: # type: ignore consumer = BackgroundFileConsumer( open(local_path, "wb"), self.hs.get_reactor() ) - await res.write_to_consumer(consumer) + await res.write_to_consumer(consumer) # type: ignore await consumer.wait() return local_path raise Exception("file could not be found") - def _file_info_to_path(self, file_info): + def _file_info_to_path(self, file_info: FileInfo) -> str: """Converts file_info into a relative path. The path is suitable for storing files under a directory, e.g. used to store files on local FS under the base media repository directory. - - Args: - file_info (FileInfo) - - Returns: - str """ if file_info.url_cache: if file_info.thumbnail: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index e52c86c798f1..a2534ac8f217 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -171,16 +171,16 @@ async def _async_render_GET(self, request): og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) respond_with_json_bytes(request, 200, og, send_cors=True) - async def _do_preview(self, url, user, ts): + async def _do_preview(self, url: str, user: str, ts: int) -> bytes: """Check the db, and download the URL and build a preview Args: - url (str): - user (str): - ts (int): + url: The URL to preview. + user: The user requesting the preview. + ts: The timestamp requested for the preview. Returns: - Deferred[bytes]: json-encoded og data + json-encoded og data """ # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 56e7f36a8720..a33f56e8068d 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -65,7 +65,13 @@ class StorageProviderWrapper(StorageProvider): store_remote: Whether remote media should be uploaded """ - def __init__(self, backend: StorageProvider, store_local: bool, store_synchronous: bool, store_remote: bool): + def __init__( + self, + backend: StorageProvider, + store_local: bool, + store_synchronous: bool, + store_remote: bool, + ): self.backend = backend self.store_local = store_local self.store_synchronous = store_synchronous From 5a5f5cb7bcc95af1e31c8a2f9fc5c571fa0f1960 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 24 Jul 2020 11:46:48 -0400 Subject: [PATCH 3/7] Add changelog. --- changelog.d/7947.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/7947.misc diff --git a/changelog.d/7947.misc b/changelog.d/7947.misc new file mode 100644 index 000000000000..f72e24853684 --- /dev/null +++ b/changelog.d/7947.misc @@ -0,0 +1 @@ + Convert more media code to async/await. From 5f3402e4483b018191be4b53b4e29ad2a4cd7c9c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 24 Jul 2020 12:22:46 -0400 Subject: [PATCH 4/7] Fix isort --- synapse/rest/media/v1/media_storage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index bf606d741d2c..41c297f96fd0 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -12,13 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import contextlib import inspect import logging import os import shutil -from typing import TYPE_CHECKING, IO, Sequence, Optional +from typing import IO, TYPE_CHECKING, Optional, Sequence from twisted.protocols.basic import FileSender @@ -30,6 +29,7 @@ if TYPE_CHECKING: from synapse.server import HomeServer + from .storage_provider import StorageProvider logger = logging.getLogger(__name__) From 84979a96ce19bed44feaee4974c4d0804c7f8b23 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 24 Jul 2020 15:30:57 -0400 Subject: [PATCH 5/7] Fix typo in changelog. Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/7947.misc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/7947.misc b/changelog.d/7947.misc index f72e24853684..58260764e7fa 100644 --- a/changelog.d/7947.misc +++ b/changelog.d/7947.misc @@ -1 +1 @@ - Convert more media code to async/await. +Convert more media code to async/await. From 6642e243989ceca33a1e3fe4b5c435830f7f3836 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 08:10:49 -0400 Subject: [PATCH 6/7] Mark return values as Any and clean-up type hints a bit. --- synapse/rest/media/v1/media_storage.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 41c297f96fd0..858b6d300595 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -17,7 +17,7 @@ import logging import os import shutil -from typing import IO, TYPE_CHECKING, Optional, Sequence +from typing import IO, TYPE_CHECKING, Any, Optional, Sequence from twisted.protocols.basic import FileSender @@ -153,14 +153,14 @@ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: return FileResponder(open(local_path, "rb")) for provider in self.storage_providers: - res = provider.fetch(path, file_info) - # Fetch is supposed to return an Awaitable, but guard against - # improper implementations. + res = provider.fetch(path, file_info) # type: Any + # Fetch is supposed to return an Awaitable[Responder], but guard + # against improper implementations. if inspect.isawaitable(res): - res = await res # type: ignore + res = await res if res: logger.debug("Streaming %s from %s", path, provider) - return res # type: ignore + return res return None @@ -184,17 +184,17 @@ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: os.makedirs(dirname) for provider in self.storage_providers: - res = provider.fetch(path, file_info) - # Fetch is supposed to return an Awaitable, but guard against - # improper implementations. + res = provider.fetch(path, file_info) # type: Any + # Fetch is supposed to return an Awaitable[Responder], but guard + # against improper implementations. if inspect.isawaitable(res): - res = await res # type: ignore - if res: # type: ignore - with res: # type: ignore + res = await res + if res: + with res: consumer = BackgroundFileConsumer( open(local_path, "wb"), self.hs.get_reactor() ) - await res.write_to_consumer(consumer) # type: ignore + await res.write_to_consumer(consumer) await consumer.wait() return local_path From 99aa65074d2e45ae8f5fe9f95e4d07f76777333f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 12:42:40 -0400 Subject: [PATCH 7/7] Consolidate changelog entries. --- changelog.d/7947.misc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/7947.misc b/changelog.d/7947.misc index 58260764e7fa..dfe4c03171d6 100644 --- a/changelog.d/7947.misc +++ b/changelog.d/7947.misc @@ -1 +1 @@ -Convert more media code to async/await. +Convert various parts of the codebase to async/await.