Skip to content

Commit

Permalink
fix for later versions
Browse files Browse the repository at this point in the history
  • Loading branch information
thehesiod committed Jan 7, 2019
1 parent f055c55 commit a1fdf15
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 10 deletions.
4 changes: 2 additions & 2 deletions ddtrace/contrib/asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def _trace_method(self, method, query, rowcount_method, extra_tags, *args,
return result

@asyncio.coroutine
def prepare(self, stmt_name, query, timeout):
def prepare(self, stmt_name, query, timeout, *args, **kwargs):
result = yield from self._trace_method(
self.__wrapped__.prepare, query, None, {},
stmt_name, query, timeout) # noqa: E999
stmt_name, query, timeout, *args, **kwargs) # noqa: E999
return result

@asyncio.coroutine
Expand Down
32 changes: 25 additions & 7 deletions ddtrace/contrib/asyncpg/patch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import inspect
import warnings

# 3p
from asyncpg.protocol import Protocol as orig_Protocol
Expand Down Expand Up @@ -81,14 +83,28 @@ def _patched_connect(connect_func, _, args, kwargs):
return conn


def _get_parsed_tags(*,
dsn=None, host=None, port=None, user=None, password=None,
database=None, ssl=None, connect_timeout=60, server_settings=None):
_connect_args = inspect.signature(asyncpg.connection.connect).parameters
_parse_connect_dsn_and_args_params = \
inspect.signature(
asyncpg.connect_utils._parse_connect_dsn_and_args).parameters

_connect_parse_arg_mapping = {
'connect_timeout': 'timeout'
}


def _get_parsed_tags(**connect_kwargs):
parse_args = dict()

for param_name, param in _parse_connect_dsn_and_args_params.items():
# Grab param from connect_kwargs
connect_param_name = _connect_parse_arg_mapping.get(param_name, param_name)
default = _connect_args[connect_param_name].default
parse_args[param_name] = connect_kwargs.get(param_name, default)

try:
addrs, params = asyncpg.connect_utils._parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user, password=password,
database=database, ssl=ssl, connect_timeout=connect_timeout,
server_settings=server_settings)
addrs, params, *_ = asyncpg.connect_utils._parse_connect_dsn_and_args(
**parse_args)

tags = {
net.TARGET_HOST: addrs[0][0],
Expand Down Expand Up @@ -126,6 +142,8 @@ def _patched_acquire(acquire_func, instance, args, kwargs):
if len(connect_args) == 1:
kwargs_copy['dsn'] = connect_args[0]
tags = _get_parsed_tags(**kwargs_copy)
else:
warnings.warn("Unrecognized parameters to asyncpg connect")

pin = _create_pin(tags)

Expand Down
11 changes: 10 additions & 1 deletion tests/contrib/asyncpg/test_asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def assert_conn_is_traced(self, tracer, db, service):
# Ensure we can run a query and it's correctly traced
q = 'select \'foobarblah\''
start = time.time()
rows = await db.fetch(q)
rows = await db.fetch(q, timeout=5)
end = time.time()
eq_(rows, [('foobarblah',)])
assert rows
Expand Down Expand Up @@ -110,6 +110,15 @@ async def assert_conn_is_traced(self, tracer, db, service):
eq_(span.meta['out.port'], TEST_PORT)
eq_(span.span_type, 'sql')

@mark_sync
async def test_pool_dsn(self):
Pin(None, tracer=self.tracer).onto(asyncpg)
dsn = 'postgresql://%(user)s:%(password)s@%(host)s:%(port)s/%(database)s' % POSTGRES_CONFIG
async with asyncpg.create_pool(dsn,
min_size=1, max_size=1) as pool:
async with pool.acquire() as conn:
await conn.execute('select 1;')

@mark_sync
async def test_copy_from(self):
# This test is here to ensure we don't break the params
Expand Down

0 comments on commit a1fdf15

Please sign in to comment.