Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Generator Wrappers #890

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 47 additions & 39 deletions newrelic/common/async_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
is_generator_function,
is_async_generator_function,
)
from newrelic.packages import six


def evaluate_wrapper(wrapper_string, wrapped, trace):
Expand Down Expand Up @@ -61,56 +62,63 @@ def wrapper(*args, **kwargs):
return wrapped


def generator_wrapper(wrapped, trace):
@functools.wraps(wrapped)
def wrapper(*args, **kwargs):
g = wrapped(*args, **kwargs)
value = None
with trace:
while True:
if six.PY3:
def generator_wrapper(wrapped, trace):
WRAPPER = textwrap.dedent("""
@functools.wraps(wrapped)
def wrapper(*args, **kwargs):
with trace:
result = yield from wrapped(*args, **kwargs)
return result
""")

try:
return evaluate_wrapper(WRAPPER, wrapped, trace)
except:
return wrapped
else:
def generator_wrapper(wrapped, trace):
@functools.wraps(wrapped)
def wrapper(*args, **kwargs):
g = wrapped(*args, **kwargs)
with trace:
try:
yielded = g.send(value)
yielded = g.send(None)
while True:
try:
sent = yield yielded
except GeneratorExit as e:
g.close()
raise
except BaseException as e:
yielded = g.throw(e)
else:
yielded = g.send(sent)
except StopIteration:
break

try:
value = yield yielded
except BaseException as e:
value = yield g.throw(type(e), e)

return wrapper
return
return wrapper


def async_generator_wrapper(wrapped, trace):
WRAPPER = textwrap.dedent("""
@functools.wraps(wrapped)
async def wrapper(*args, **kwargs):
g = wrapped(*args, **kwargs)
value = None
with trace:
while True:
try:
yielded = await g.asend(value)
except StopAsyncIteration as e:
# The underlying async generator has finished, return propagates a new StopAsyncIteration
return
except StopIteration as e:
# The call to async_generator_asend.send() should raise a StopIteration containing the yielded value
yielded = e.value

try:
value = yield yielded
except BaseException as e:
# An exception was thrown with .athrow(), propagate to the original async generator.
# Return value logic must be identical to .asend()
try:
yielded = await g.asend(None)
while True:
try:
value = yield await g.athrow(type(e), e)
except StopAsyncIteration as e:
# The underlying async generator has finished, return propagates a new StopAsyncIteration
return
except StopIteration as e:
# The call to async_generator_athrow.send() should raise a StopIteration containing a yielded value
value = yield e.value
sent = yield yielded
except GeneratorExit as e:
await g.aclose()
raise
except BaseException as e:
yielded = await g.athrow(e)
else:
yielded = await g.asend(sent)
except StopAsyncIteration:
return
""")

try:
Expand Down
35 changes: 35 additions & 0 deletions tests/agent_features/_test_async_generator_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,41 @@ async def _test():
event_loop.run_until_complete(_test())


@validate_transaction_metrics(
"test_multiple_throws_yield_a_value",
background_task=True,
scoped_metrics=[("Function/agen", 1)],
rollup_metrics=[("Function/agen", 1)],
)
def test_multiple_throws_yield_a_value(event_loop):
@function_trace(name="agen")
async def agen():
value = None
for _ in range(4):
try:
yield value
value = "bar"
except MyException:
value = "foo"


@background_task(name="test_multiple_throws_yield_a_value")
async def _test():
gen = agen()

# kickstart the coroutine
assert await gen.asend(None) is None
assert await gen.athrow(MyException) == "foo"
assert await gen.athrow(MyException) == "foo"
assert await gen.asend(None) == "bar"

# finish consumption of the coroutine if necessary
async for _ in gen:
pass

event_loop.run_until_complete(_test())


@validate_transaction_metrics(
"test_athrow_does_not_yield_a_value",
background_task=True,
Expand Down
31 changes: 31 additions & 0 deletions tests/agent_features/test_coroutine_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,37 @@ def coro():
pass


@validate_transaction_metrics(
"test_multiple_throws_yield_a_value",
background_task=True,
scoped_metrics=[("Function/coro", 1)],
rollup_metrics=[("Function/coro", 1)],
)
@background_task(name="test_multiple_throws_yield_a_value")
def test_multiple_throws_yield_a_value():
@function_trace(name="coro")
def coro():
value = None
for _ in range(4):
try:
yield value
value = "bar"
except MyException:
value = "foo"

c = coro()

# kickstart the coroutine
assert next(c) is None
assert c.throw(MyException) == "foo"
assert c.throw(MyException) == "foo"
assert next(c) == "bar"

# finish consumption of the coroutine if necessary
for _ in c:
pass


@pytest.mark.parametrize(
"trace",
[
Expand Down