Skip to content

Commit

Permalink
unify ServerHttpProtocol and RequestHandler; replace FileSender with …
Browse files Browse the repository at this point in the history
…FileResponse
  • Loading branch information
fafhrd91 committed Feb 27, 2017
1 parent 0248cd9 commit 9cfa4fe
Show file tree
Hide file tree
Showing 20 changed files with 517 additions and 754 deletions.
4 changes: 1 addition & 3 deletions aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .http_websocket import WSMsgType, WSCloseCode, WSMessage, WebSocketError # noqa
from .streams import * # noqa
from .multipart import * # noqa
from .file_sender import FileSender # noqa
from .cookiejar import CookieJar # noqa
from .payload import * # noqa
from .payload_streamer import * # noqa
Expand All @@ -32,8 +31,7 @@
payload.__all__ + # noqa
payload_streamer.__all__ + # noqa
streams.__all__ + # noqa
('hdrs', 'FileSender',
'HttpVersion', 'HttpVersion10', 'HttpVersion11',
('hdrs', 'HttpVersion', 'HttpVersion10', 'HttpVersion11',
'WSMsgType', 'MsgType', 'WSCloseCode',
'WebSocketError', 'WSMessage', 'CookieJar',

Expand Down
4 changes: 3 additions & 1 deletion aiohttp/http.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from yarl import URL # noqa

from .http_exceptions import HttpProcessingError
from .http_message import (RESPONSES, SERVER_SOFTWARE, HttpMessage,
HttpVersion, HttpVersion10, HttpVersion11,
Expand All @@ -13,7 +15,7 @@

# .http_message
'RESPONSES', 'SERVER_SOFTWARE',
'HttpMessage', 'Request', 'Response', 'PayloadWriter',
'HttpMessage', 'Request', 'PayloadWriter',
'HttpVersion', 'HttpVersion10', 'HttpVersion11',

# .http_parser
Expand Down
26 changes: 21 additions & 5 deletions aiohttp/http_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sys
import zlib
from urllib.parse import SplitResult
from wsgiref.handlers import format_date_time

import yarl
from multidict import CIMultiDict, istr
Expand Down Expand Up @@ -36,7 +35,7 @@

class PayloadWriter(AbstractPayloadWriter):

def __init__(self, stream, loop):
def __init__(self, stream, loop, acquire=True):
if loop is None:
loop = asyncio.get_event_loop()

Expand All @@ -53,13 +52,29 @@ def __init__(self, stream, loop):
self._compress = None
self._drain_waiter = None

self._replacement = None

if self._stream.available:
self._transport = self._stream.transport
self._stream.available = False
else:
elif acquire:
self._stream.acquire(self.set_transport)

def replace(self, factory):
"""Hack: for internal use only """
if self._transport is not None:
self._transport = None
self._stream.available = True
return factory(self._stream, self.loop)
else:
self._replacement = factory(self._stream, self.loop, False)
return self._replacement

def set_transport(self, transport):
if self._replacement is not None:
self._replacement.set_transport(transport)
return

self._transport = transport

chunk = b''.join(self._buffer)
Expand Down Expand Up @@ -196,7 +211,7 @@ def drain(self, last=False):
class HttpMessage(PayloadWriter):
"""HttpMessage allows to write headers and payload to a stream."""

HOP_HEADERS = None # Must be set by subclass.
HOP_HEADERS = () # Must be set by subclass.

SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} aiohttp/{1}'.format(
sys.version_info, aiohttp.__version__)
Expand All @@ -205,7 +220,8 @@ class HttpMessage(PayloadWriter):
websocket = False # Upgrade: WEBSOCKET
has_chunked_hdr = False # Transfer-encoding: chunked

def __init__(self, transport, version, close, loop=None):
def __init__(self, transport,
version=HttpVersion11, close=False, loop=None):
super().__init__(transport, loop)

self.version = version
Expand Down
21 changes: 12 additions & 9 deletions aiohttp/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from aiohttp.client import _RequestContextManager

from . import ClientSession, hdrs
from .helpers import PY_35, sentinel
from .http import HttpVersion, PayloadWriter, RawRequestMessage
from .helpers import PY_35, noop, sentinel
from .http import HttpVersion, RawRequestMessage
from .signals import Signal
from .web import Application, Request, Server, UrlMappingMatchInfo

Expand Down Expand Up @@ -484,6 +484,7 @@ def make_mocked_request(method, path, headers=None, *,
version=HttpVersion(1, 1), closing=False,
app=None,
writer=sentinel,
payload_writer=sentinel,
protocol=sentinel,
transport=sentinel,
payload=sentinel,
Expand All @@ -497,6 +498,10 @@ def make_mocked_request(method, path, headers=None, *,
"""

task = mock.Mock()
loop = mock.Mock()
loop.create_future.return_value = ()

if version < HttpVersion(1, 1):
closing = True

Expand Down Expand Up @@ -526,6 +531,10 @@ def make_mocked_request(method, path, headers=None, *,
writer = mock.Mock()
writer.transport = transport

if payload_writer is sentinel:
payload_writer = mock.Mock()
payload_writer.write_eof.side_effect = noop

protocol.transport = transport
protocol.writer = writer

Expand All @@ -543,14 +552,8 @@ def timeout(*args, **kw):
time_service.timeout = mock.Mock()
time_service.timeout.side_effect = timeout

task = mock.Mock()
loop = mock.Mock()
loop.create_future.return_value = ()

w = PayloadWriter(writer, loop=loop)

req = Request(message, payload,
protocol, w, time_service, task,
protocol, payload_writer, time_service, task,
secure_proxy_ssl_header=secure_proxy_ssl_header,
client_max_size=client_max_size)

Expand Down
16 changes: 11 additions & 5 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,28 @@

from yarl import URL

from . import (hdrs, web_exceptions, web_middlewares, web_request,
web_response, web_server, web_urldispatcher, web_ws)
from . import (hdrs, web_exceptions, web_fileresponse, web_middlewares,
web_protocol, web_request, web_response, web_server,
web_urldispatcher, web_ws)
from .abc import AbstractMatchInfo, AbstractRouter
from .helpers import FrozenList
from .http import HttpVersion # noqa
from .log import access_logger, web_logger
from .signals import PostSignal, PreSignal, Signal
from .web_exceptions import * # noqa
from .web_fileresponse import * # noqa
from .web_middlewares import * # noqa
from .web_protocol import * # noqa
from .web_request import * # noqa
from .web_response import * # noqa
from .web_server import Server
from .web_urldispatcher import * # noqa
from .web_urldispatcher import PrefixedSubAppResource
from .web_ws import * # noqa

__all__ = (web_request.__all__ +
__all__ = (web_protocol.__all__ +
web_fileresponse.__all__ +
web_request.__all__ +
web_response.__all__ +
web_exceptions.__all__ +
web_urldispatcher.__all__ +
Expand Down Expand Up @@ -222,10 +227,10 @@ def cleanup(self):
"""
yield from self.on_cleanup.send(self)

def _make_request(self, message, payload, protocol, writer,
def _make_request(self, message, payload, protocol, writer, task,
_cls=web_request.Request):
return _cls(
message, payload, protocol, writer, protocol._time_service, None,
message, payload, protocol, writer, protocol._time_service, task,
secure_proxy_ssl_header=self._secure_proxy_ssl_header,
client_max_size=self._client_max_size)

Expand All @@ -250,6 +255,7 @@ def _handle(self, request):
for app in match_info.apps:
for factory in app._middlewares:
handler = yield from factory(app, handler)

resp = yield from handler(request)

assert isinstance(resp, web_response.StreamResponse), \
Expand Down
69 changes: 42 additions & 27 deletions aiohttp/file_sender.py → aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import mimetypes
import os
import pathlib

from . import hdrs
from .helpers import create_future
Expand All @@ -9,6 +10,8 @@
HTTPRequestRangeNotSatisfiable)
from .web_response import StreamResponse

__all__ = ('FileResponse',)


NOSENDFILE = bool(os.environ.get("AIOHTTP_NOSENDFILE"))

Expand Down Expand Up @@ -81,15 +84,20 @@ def write_eof(self, chunk=b''):
pass


class FileSender:
"""A helper that can be used to send files."""
class FileResponse(StreamResponse):
"""A response object can be used to send files."""

def __init__(self, path, chunk_size=256*1024, *args, **kwargs):
super().__init__(*args, **kwargs)

if isinstance(path, str):
path = pathlib.Path(path)

def __init__(self, *, resp_factory=StreamResponse, chunk_size=256*1024):
self._response_factory = resp_factory
self._path = path
self._chunk_size = chunk_size

@asyncio.coroutine
def _sendfile_system(self, request, resp, fobj, count):
def _sendfile_system(self, request, fobj, count):
# Write count bytes of fobj to resp using
# the os.sendfile system call.
#
Expand All @@ -103,14 +111,17 @@ def _sendfile_system(self, request, resp, fobj, count):

transport = request.transport
if transport.get_extra_info("sslcontext"):
yield from self._sendfile_fallback(request, resp, fobj, count)
writer = yield from self._sendfile_fallback(request, fobj, count)
else:
writer = yield from resp.prepare(
request, PayloadWriterFactory=SendfilePayloadWriter)
writer = request._writer.replace(SendfilePayloadWriter)
request._writer = writer
yield from super().prepare(request)
yield from writer.sendfile(fobj, count)

return writer

@asyncio.coroutine
def _sendfile_fallback(self, request, resp, fobj, count):
def _sendfile_fallback(self, request, fobj, count):
# Mimic the _sendfile_system() method, but without using the
# os.sendfile() system call. This should be used on systems
# that don't support the os.sendfile().
Expand All @@ -119,30 +130,33 @@ def _sendfile_fallback(self, request, resp, fobj, count):
# fobj is transferred in chunks controlled by the
# constructor's chunk_size argument.

yield from resp.prepare(request)
writer = (yield from super().prepare(request))

resp.set_tcp_cork(True)
self.set_tcp_cork(True)
try:
chunk_size = self._chunk_size

chunk = fobj.read(chunk_size)
while True:
yield from resp.write(chunk)
yield from writer.write(chunk)
count = count - chunk_size
if count <= 0:
break
chunk = fobj.read(min(chunk_size, count))
finally:
resp.set_tcp_nodelay(True)
self.set_tcp_nodelay(True)

yield from writer.drain()

if hasattr(os, "sendfile") and not NOSENDFILE: # pragma: no cover
_sendfile = _sendfile_system
else: # pragma: no cover
_sendfile = _sendfile_fallback

@asyncio.coroutine
def send(self, request, filepath):
"""Send filepath to client using request."""
def prepare(self, request):
filepath = self._path

gzip = False
if 'gzip' in request.headers.get(hdrs.ACCEPT_ENCODING, ''):
gzip_path = filepath.with_name(filepath.name + '.gz')
Expand All @@ -155,7 +169,8 @@ def send(self, request, filepath):

modsince = request.if_modified_since
if modsince is not None and st.st_mtime <= modsince.timestamp():
raise HTTPNotModified()
self.set_status(HTTPNotModified.status_code)
return (yield from super().prepare(request))

ct, encoding = mimetypes.guess_type(str(filepath))
if not ct:
Expand All @@ -170,7 +185,8 @@ def send(self, request, filepath):
start = rng.start
end = rng.stop
except ValueError:
raise HTTPRequestRangeNotSatisfiable
self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
return (yield from super().prepare(request))

# If a range request has been made, convert start, end slice notation
# into file pointer offset and count
Expand All @@ -192,18 +208,17 @@ def send(self, request, filepath):
# the current length of the selected representation).
count = file_size - start

resp = self._response_factory(status=status)
resp.content_type = ct
self.set_status(status)
self.content_type = ct
if encoding:
resp.headers[hdrs.CONTENT_ENCODING] = encoding
self.headers[hdrs.CONTENT_ENCODING] = encoding
if gzip:
resp.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
resp.last_modified = st.st_mtime
self.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
self.last_modified = st.st_mtime
self.content_length = count

resp.content_length = count
with filepath.open('rb') as f:
with filepath.open('rb') as fobj:
if start:
f.seek(start)
yield from self._sendfile(request, resp, f, count)
fobj.seek(start)

return resp
return (yield from self._sendfile(request, fobj, count))
Loading

0 comments on commit 9cfa4fe

Please sign in to comment.