Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Finish converting the media repo code to async / await. #7947

Merged
merged 7 commits into from
Jul 27, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7947.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert more media code to async/await.
clokep marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 5 additions & 3 deletions synapse/rest/media/v1/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import logging
import os
import urllib
from typing import Awaitable

from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender

from synapse.api.errors import Codes, SynapseError, cs_error
Expand Down Expand Up @@ -240,14 +242,14 @@ class Responder(object):
held can be cleaned up.
"""

def write_to_consumer(self, consumer):
def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer

Args:
consumer (IConsumer)
consumer: The consumer to stream into.

Returns:
Deferred: Resolves once the response has finished being written
Resolves once the response has finished being written
"""
pass

Expand Down
105 changes: 61 additions & 44 deletions synapse/rest/media/v1/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
import logging
import os
import shutil
from typing import Dict, Tuple
from typing import IO, Dict, Optional, Tuple

import twisted.internet.error
import twisted.web.http
from twisted.web.http import Request
from twisted.web.resource import Resource

from synapse.api.errors import (
Expand All @@ -40,6 +41,7 @@

from ._base import (
FileInfo,
Responder,
get_filename_from_headers,
respond_404,
respond_with_responder,
Expand Down Expand Up @@ -135,19 +137,24 @@ def mark_recently_accessed(self, server_name, media_id):
self.recently_accessed_locals.add(media_id)

async def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
self,
media_type: str,
upload_name: str,
content: IO,
content_length: int,
auth_user: str,
) -> str:
"""Store uploaded content for a local user and return the mxc URL

Args:
media_type(str): The content type of the file
upload_name(str): The name of the file
media_type: The content type of the file
upload_name: The name of the file
content: A file like object that is the content to store
content_length(int): The length of the content
auth_user(str): The user_id of the uploader
content_length: The length of the content
auth_user: The user_id of the uploader

Returns:
Deferred[str]: The mxc url of the stored content
The mxc url of the stored content
"""
media_id = random_string(24)

Expand All @@ -170,19 +177,20 @@ async def create_content(

return "mxc://%s/%s" % (self.server_name, media_id)

async def get_local_media(self, request, media_id, name):
async def get_local_media(
self, request: Request, media_id: str, name: Optional[str]
) -> None:
"""Responds to reqests for local media, if exists, or returns 404.

Args:
request(twisted.web.http.Request)
media_id (str): The media ID of the content. (This is the same as
request: The incoming request.
media_id: The media ID of the content. (This is the same as
the file_id for local content.)
name (str|None): Optional name that, if specified, will be used as
name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.

Returns:
Deferred: Resolves once a response has successfully been written
to request
Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
Expand All @@ -203,20 +211,20 @@ async def get_local_media(self, request, media_id, name):
request, responder, media_type, media_length, upload_name
)

async def get_remote_media(self, request, server_name, media_id, name):
async def get_remote_media(
self, request: Request, server_name: str, media_id: str, name: Optional[str]
) -> None:
"""Respond to requests for remote media.

Args:
request(twisted.web.http.Request)
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
name (str|None): Optional name that, if specified, will be used as
request: The incoming request.
server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the remote server).
name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.

Returns:
Deferred: Resolves once a response has successfully been written
to request
Resolves once a response has successfully been written to request
"""
if (
self.federation_domain_whitelist is not None
Expand Down Expand Up @@ -245,17 +253,16 @@ async def get_remote_media(self, request, server_name, media_id, name):
else:
respond_404(request)

async def get_remote_media_info(self, server_name, media_id):
async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
"""Gets the media info associated with the remote file, downloading
if necessary.

Args:
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the remote server).

Returns:
Deferred[dict]: The media_info of the file
The media info of the file
"""
if (
self.federation_domain_whitelist is not None
Expand All @@ -278,7 +285,9 @@ async def get_remote_media_info(self, server_name, media_id):

return media_info

async def _get_remote_media_impl(self, server_name, media_id):
async def _get_remote_media_impl(
self, server_name: str, media_id: str
) -> Tuple[Optional[Responder], dict]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.

Expand All @@ -288,7 +297,7 @@ async def _get_remote_media_impl(self, server_name, media_id):
remote server).

Returns:
Deferred[(Responder, media_info)]
A tuple of responder and the media info of the file.
"""
media_info = await self.store.get_cached_remote_media(server_name, media_id)

Expand Down Expand Up @@ -319,19 +328,21 @@ async def _get_remote_media_impl(self, server_name, media_id):
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info

async def _download_remote_file(self, server_name, media_id, file_id):
async def _download_remote_file(
self, server_name: str, media_id: str, file_id: str
) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.

Args:
server_name (str): Originating server
media_id (str): The media ID of the content (as defined by the
server_name: Originating server
media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
file_id (str): Local file ID
file_id: Local file ID

Returns:
Deferred[MediaInfo]
The media info of the file.
"""

file_info = FileInfo(server_name=server_name, file_id=file_id)
Expand Down Expand Up @@ -549,25 +560,31 @@ async def generate_remote_exact_thumbnail(
return output_path

async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
self,
server_name: Optional[str],
media_id: str,
file_id: str,
media_type: str,
url_cache: bool = False,
) -> Optional[dict]:
"""Generate and store thumbnails for an image.

Args:
server_name (str|None): The server name if remote media, else None if local
media_id (str): The media ID of the content. (This is the same as
server_name: The server name if remote media, else None if local
media_id: The media ID of the content. (This is the same as
the file_id for local content)
file_id (str): Local file ID
media_type (str): The content type of the file
url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
file_id: Local file ID
media_type: The content type of the file
url_cache: If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer

Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image
Dict with "width" and "height" keys of original image or None if the
media cannot be thumbnailed.
"""
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
return None

input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
Expand All @@ -584,7 +601,7 @@ async def _generate_thumbnails(
m_height,
self.max_image_pixels,
)
return
return None

if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread(
Expand Down
52 changes: 28 additions & 24 deletions synapse/rest/media/v1/media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import inspect
import logging
import os
import shutil
from typing import Optional
from typing import IO, TYPE_CHECKING, Optional, Sequence

from twisted.protocols.basic import FileSender

from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer

from ._base import FileInfo, Responder
from .filepath import MediaFilePaths

if TYPE_CHECKING:
from synapse.server import HomeServer

from .storage_provider import StorageProvider

logger = logging.getLogger(__name__)

Expand All @@ -34,20 +39,25 @@ class MediaStorage(object):
"""Responsible for storing/fetching files from local sources.

Args:
hs (synapse.server.Homeserver)
local_media_directory (str): Base path where we store media on disk
filepaths (MediaFilePaths)
storage_providers ([StorageProvider]): List of StorageProvider that are
used to fetch and store files.
hs
local_media_directory: Base path where we store media on disk
filepaths
storage_providers: List of StorageProvider that are used to fetch and store files.
"""

def __init__(self, hs, local_media_directory, filepaths, storage_providers):
def __init__(
self,
hs: "HomeServer",
local_media_directory: str,
filepaths: MediaFilePaths,
storage_providers: Sequence["StorageProvider"],
):
self.hs = hs
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers

async def store_file(self, source, file_info: FileInfo) -> str:
async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
configured storage providers

Expand All @@ -69,7 +79,7 @@ async def store_file(self, source, file_info: FileInfo) -> str:
return fname

@contextlib.contextmanager
def store_into_file(self, file_info):
def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
described by file_info.

Expand All @@ -85,7 +95,7 @@ def store_into_file(self, file_info):
error.

Args:
file_info (FileInfo): Info about the file to store
file_info: Info about the file to store

Example:

Expand Down Expand Up @@ -147,10 +157,10 @@ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
# Fetch is supposed to return an Awaitable, but guard against
# improper implementations.
if inspect.isawaitable(res):
res = await res
res = await res # type: ignore
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
return res # type: ignore

return None

Expand Down Expand Up @@ -178,29 +188,23 @@ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
# Fetch is supposed to return an Awaitable, but guard against
# improper implementations.
if inspect.isawaitable(res):
res = await res
if res:
with res:
res = await res # type: ignore
if res: # type: ignore
with res: # type: ignore
clokep marked this conversation as resolved.
Show resolved Hide resolved
consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.hs.get_reactor()
)
await res.write_to_consumer(consumer)
await res.write_to_consumer(consumer) # type: ignore
await consumer.wait()
return local_path

raise Exception("file could not be found")

def _file_info_to_path(self, file_info):
def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path.

The path is suitable for storing files under a directory, e.g. used to
store files on local FS under the base media repository directory.

Args:
file_info (FileInfo)

Returns:
str
"""
if file_info.url_cache:
if file_info.thumbnail:
Expand Down
Loading