Skip to content

Commit

Permalink
Use custom classes to pass client signals parameters #2686
Browse files Browse the repository at this point in the history
This will allow aiohttp development add new parameters in the future without
break the signals signature.
  • Loading branch information
pfreixes committed Jan 30, 2018
1 parent 5397413 commit 071cf45
Show file tree
Hide file tree
Showing 7 changed files with 318 additions and 103 deletions.
1 change: 1 addition & 0 deletions CHANGES/2686.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use custom classes to pass client signals parameters
14 changes: 7 additions & 7 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,14 +660,14 @@ async def _resolve_host(self, host, port, traces=None):

if traces:
for trace in traces:
await trace.send_dns_resolvehost_start()
await trace.send_dns_resolvehost_start(host)

res = (await self._resolver.resolve(
host, port, family=self._family))

if traces:
for trace in traces:
await trace.send_dns_resolvehost_end()
await trace.send_dns_resolvehost_end(host)

return res

Expand All @@ -678,26 +678,26 @@ async def _resolve_host(self, host, port, traces=None):

if traces:
for trace in traces:
await trace.send_dns_cache_hit()
await trace.send_dns_cache_hit(host)

return self._cached_hosts.next_addrs(key)

if key in self._throttle_dns_events:
if traces:
for trace in traces:
await trace.send_dns_cache_hit()
await trace.send_dns_cache_hit(host)
await self._throttle_dns_events[key].wait()
else:
if traces:
for trace in traces:
await trace.send_dns_cache_miss()
await trace.send_dns_cache_miss(host)
self._throttle_dns_events[key] = \
EventResultOrError(self._loop)
try:

if traces:
for trace in traces:
await trace.send_dns_resolvehost_start()
await trace.send_dns_resolvehost_start(host)

addrs = await \
asyncio.shield(self._resolver.resolve(host,
Expand All @@ -706,7 +706,7 @@ async def _resolve_host(self, host, port, traces=None):
loop=self._loop)
if traces:
for trace in traces:
await trace.send_dns_resolvehost_end()
await trace.send_dns_resolvehost_end(host)

self._cached_hosts.add(key, addrs)
self._throttle_dns_events[key].set()
Expand Down
163 changes: 123 additions & 40 deletions aiohttp/tracing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from types import SimpleNamespace

import attr
from multidict import CIMultiDict
from yarl import URL

from .client_reqrep import ClientResponse
from .signals import Signal


__all__ = ('TraceConfig',)
__all__ = (
'TraceConfig', 'TraceRequestStartParams', 'TraceRequestEndParams',
'TraceRequestExceptionParams', 'TraceConnectionQueuedStartParams',
'TraceConnectionQueuedEndParams', 'TraceConnectionCreateStartParams',
'TraceConnectionCreateEndParams', 'TraceConnectionReuseconnParams',
'TraceDnsResolveHostStartParams', 'TraceDnsResolveHostEndParams',
'TraceDnsCacheHitParams', 'TraceDnsCacheMissParams'
)


class TraceConfig:
Expand Down Expand Up @@ -100,6 +112,90 @@ def on_dns_cache_miss(self):
return self._on_dns_cache_miss


@attr.s(frozen=True, slots=True)
class TraceRequestStartParams:
""" Parameters sent by the `on_request_start` signal"""
method = attr.ib(type=str)
url = attr.ib(type=URL)
headers = attr.ib(type=CIMultiDict)


@attr.s(frozen=True, slots=True)
class TraceRequestEndParams:
""" Parameters sent by the `on_request_end` signal"""
method = attr.ib(type=str)
url = attr.ib(type=URL)
headers = attr.ib(type=CIMultiDict)
resp = attr.ib(type=ClientResponse)


@attr.s(frozen=True, slots=True)
class TraceRequestExceptionParams:
""" Parameters sent by the `on_request_exception` signal"""
method = attr.ib(type=str)
url = attr.ib(type=URL)
headers = attr.ib(type=CIMultiDict)
exception = attr.ib(type=Exception)


@attr.s(frozen=True, slots=True)
class TraceRequestRedirectParams:
""" Parameters sent by the `on_request_redirect` signal"""
method = attr.ib(type=str)
url = attr.ib(type=URL)
headers = attr.ib(type=CIMultiDict)
resp = attr.ib(type=ClientResponse)


@attr.s(frozen=True, slots=True)
class TraceConnectionQueuedStartParams:
""" Parameters sent by the `on_connection_queued_start` signal"""


@attr.s(frozen=True, slots=True)
class TraceConnectionQueuedEndParams:
""" Parameters sent by the `on_connection_queued_end` signal"""


@attr.s(frozen=True, slots=True)
class TraceConnectionCreateStartParams:
""" Parameters sent by the `on_connection_create_start` signal"""


@attr.s(frozen=True, slots=True)
class TraceConnectionCreateEndParams:
""" Parameters sent by the `on_connection_create_end` signal"""


@attr.s(frozen=True, slots=True)
class TraceConnectionReuseconnParams:
""" Parameters sent by the `on_connection_reuseconn` signal"""


@attr.s(frozen=True, slots=True)
class TraceDnsResolveHostStartParams:
""" Parameters sent by the `on_dns_resolvehost_start` signal"""
host = attr.ib(type=str)


@attr.s(frozen=True, slots=True)
class TraceDnsResolveHostEndParams:
""" Parameters sent by the `on_dns_resolvehost_end` signal"""
host = attr.ib(type=str)


@attr.s(frozen=True, slots=True)
class TraceDnsCacheHitParams:
""" Parameters sent by the `on_dns_cache_hit` signal"""
host = attr.ib(type=str)


@attr.s(frozen=True, slots=True)
class TraceDnsCacheMissParams:
""" Parameters sent by the `on_dns_cache_miss` signal"""
host = attr.ib(type=str)


class Trace:
""" Internal class used to keep together the main dependencies used
at the moment of send a signal."""
Expand All @@ -109,106 +205,93 @@ def __init__(self, session, trace_config, trace_config_ctx):
self._trace_config_ctx = trace_config_ctx
self._session = session

async def send_request_start(self, *args, **kwargs):
async def send_request_start(self, method, url, headers):
return await self._trace_config.on_request_start.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceRequestStartParams(method, url, headers)
)

async def send_request_end(self, *args, **kwargs):
async def send_request_end(self, method, url, headers, response):
return await self._trace_config.on_request_end.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceRequestEndParams(method, url, headers, response)
)

async def send_request_exception(self, *args, **kwargs):
async def send_request_exception(self, method, url, headers, exception):
return await self._trace_config.on_request_exception.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceRequestExceptionParams(method, url, headers, exception)
)

async def send_request_redirect(self, *args, **kwargs):
async def send_request_redirect(self, method, url, headers, response):
return await self._trace_config._on_request_redirect.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceRequestRedirectParams(method, url, headers, response)
)

async def send_connection_queued_start(self, *args, **kwargs):
async def send_connection_queued_start(self):
return await self._trace_config.on_connection_queued_start.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceConnectionQueuedStartParams()
)

async def send_connection_queued_end(self, *args, **kwargs):
async def send_connection_queued_end(self):
return await self._trace_config.on_connection_queued_end.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceConnectionQueuedEndParams()
)

async def send_connection_create_start(self, *args, **kwargs):
async def send_connection_create_start(self):
return await self._trace_config.on_connection_create_start.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceConnectionCreateStartParams()
)

async def send_connection_create_end(self, *args, **kwargs):
async def send_connection_create_end(self):
return await self._trace_config.on_connection_create_end.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceConnectionCreateEndParams()
)

async def send_connection_reuseconn(self, *args, **kwargs):
async def send_connection_reuseconn(self):
return await self._trace_config.on_connection_reuseconn.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceConnectionReuseconnParams()
)

async def send_dns_resolvehost_start(self, *args, **kwargs):
async def send_dns_resolvehost_start(self, host):
return await self._trace_config.on_dns_resolvehost_start.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceDnsResolveHostStartParams(host)
)

async def send_dns_resolvehost_end(self, *args, **kwargs):
async def send_dns_resolvehost_end(self, host):
return await self._trace_config.on_dns_resolvehost_end.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceDnsResolveHostEndParams(host)
)

async def send_dns_cache_hit(self, *args, **kwargs):
async def send_dns_cache_hit(self, host):
return await self._trace_config.on_dns_cache_hit.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceDnsCacheHitParams(host)
)

async def send_dns_cache_miss(self, *args, **kwargs):
async def send_dns_cache_miss(self, host):
return await self._trace_config.on_dns_cache_miss.send(
self._session,
self._trace_config_ctx,
*args,
**kwargs
TraceDnsCacheMissParams(host)
)
10 changes: 5 additions & 5 deletions docs/client_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,10 @@ disabled. The following snippet shows how the start and the end
signals of a request flow can be followed::

async def on_request_start(
session, trace_config_ctx, method, host, port, headers):
session, trace_config_ctx, params):
print("Starting request")

async def on_request_end(session, trace_config_ctx, resp):
async def on_request_end(session, trace_config_ctx, params):
print("Ending request")

trace_config = aiohttp.TraceConfig()
Expand Down Expand Up @@ -259,10 +259,10 @@ share the state through to the different signals that belong to the
same request and to the same :class:`TraceConfig` class, perhaps::

async def on_request_start(
session, trace_config_ctx, method, host, port, headers):
session, trace_config_ctx, params):
trace_config_ctx.start = session.loop.time()

async def on_request_end(session, trace_config_ctx, resp):
async def on_request_end(session, trace_config_ctx, params):
elapsed = session.loop.time() - trace_config_ctx.start
print("Request took {}".format(elapsed))

Expand All @@ -280,7 +280,7 @@ factory. This param is useful to pass data that is only available at
request time, perhaps::

async def on_request_start(
session, trace_config_ctx, method, host, port, headers):
session, trace_config_ctx, params):
print(trace_config_ctx.trace_request_ctx)


Expand Down
Loading

0 comments on commit 071cf45

Please sign in to comment.