Skip to content

Commit

Permalink
Make tracer.start_as_current_span() decorator work with async funct…
Browse files Browse the repository at this point in the history
…ions (#3633)

* test: add a minimal test that reproduce the bug

Signed-off-by: QuentinN42 <[email protected]>

* fix: replaced contextmanager by my own agnosticcontextmanager

* docs: changelog

Signed-off-by: QuentinN42 <[email protected]>

* docs: typo

Signed-off-by: QuentinN42 <[email protected]>

* fix: linting

Signed-off-by: QuentinN42 <[email protected]>

* feat: reimplement the contextlib._GeneratorContextManager inside the trace api

Signed-off-by: QuentinN42 <[email protected]>

* refactor: merge tests

Signed-off-by: QuentinN42 <[email protected]>

* fix: make agnosticcontextmanager as protected

Signed-off-by: QuentinN42 <[email protected]>

* fix: changelog

Signed-off-by: QuentinN42 <[email protected]>

* feat: start typing

Signed-off-by: QuentinN42 <[email protected]>

* fix: revert to acdfa03 implementation and fix mypy

Signed-off-by: QuentinN42 <[email protected]>

* fix: use super call on the synchronous branch

Co-authored-by: Aaron Abbott <[email protected]>

* fix: typing compliant with 3.8

Signed-off-by: QuentinN42 <[email protected]>

* fix: black

Signed-off-by: QuentinN42 <[email protected]>

* feat: added tests and fixed lint in python 3.10

Signed-off-by: QuentinN42 <[email protected]>

* feat: use_span use the _agnosticcontextmanager

Signed-off-by: QuentinN42 <[email protected]>

* docs: explain why we have an overriden class

Signed-off-by: QuentinN42 <[email protected]>

* fix: use typing.Generic for pre 3.9 compatibility

Signed-off-by: QuentinN42 <[email protected]>

* fix: typo

Signed-off-by: QuentinN42 <[email protected]>

* fix: mypy api/src

Signed-off-by: QuentinN42 <[email protected]>

* fix: ignore reference to privat attributes

Signed-off-by: QuentinN42 <[email protected]>

* fix: mv cm inside test

Signed-off-by: QuentinN42 <[email protected]>

* fix: define __call__ as Coroutine and not awaitable

Signed-off-by: QuentinN42 <[email protected]>

* fix: mypy green

Signed-off-by: QuentinN42 <[email protected]>

* fix: py38 tests ok

Signed-off-by: QuentinN42 <[email protected]>

* fix: reimplementing __enter__ to avoid the type error.

Signed-off-by: QuentinN42 <[email protected]>

* fix: lint

Signed-off-by: QuentinN42 <[email protected]>

* fix: mypy

Signed-off-by: QuentinN42 <[email protected]>

* test: rm test_wraps_contextlib

Signed-off-by: QuentinN42 <[email protected]>

* docs: document why we are overriding the contextlib._GeneratorContextManager class

Signed-off-by: QuentinN42 <[email protected]>

* docs: mv feat to unreleased section

Signed-off-by: QuentinN42 <[email protected]>

* test: rename lst with a more readable name

Signed-off-by: QuentinN42 <[email protected]>

* Remove unused type ignore comment

* Fix missing symbol

The missing symbol error was caused by a rebase on main and subsequent
force push by me, sorry.

---------

Signed-off-by: QuentinN42 <[email protected]>
Co-authored-by: Aaron Abbott <[email protected]>
Co-authored-by: Diego Hurtado <[email protected]>
  • Loading branch information
3 people authored Mar 20, 2024
1 parent d6321d6 commit 5a6da15
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ notes=FIXME,
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
contextmanager-decorators=contextlib.contextmanager, _agnosticcontextmanager

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- Make `tracer.start_as_current_span()` decorator work with async functions
([#3633](https://github.com/open-telemetry/opentelemetry-python/pull/3633))
- Fix python 3.12 deprecation warning
([#3751](https://github.com/open-telemetry/opentelemetry-python/pull/3751))
- bump mypy to 0.982
Expand Down
10 changes: 5 additions & 5 deletions opentelemetry-api/src/opentelemetry/trace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
import os
import typing
from abc import ABC, abstractmethod
from contextlib import contextmanager
from enum import Enum
from logging import getLogger
from typing import Iterator, Optional, Sequence, cast
Expand Down Expand Up @@ -109,6 +108,7 @@
)
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.util import types
from opentelemetry.util._decorator import _agnosticcontextmanager
from opentelemetry.util._once import Once
from opentelemetry.util._providers import _load_provider

Expand Down Expand Up @@ -324,7 +324,7 @@ def start_span(
The newly-created span.
"""

@contextmanager
@_agnosticcontextmanager
@abstractmethod
def start_as_current_span(
self,
Expand Down Expand Up @@ -431,7 +431,7 @@ def _tracer(self) -> Tracer:
def start_span(self, *args, **kwargs) -> Span: # type: ignore
return self._tracer.start_span(*args, **kwargs) # type: ignore

@contextmanager # type: ignore
@_agnosticcontextmanager # type: ignore
def start_as_current_span(self, *args, **kwargs) -> Iterator[Span]:
with self._tracer.start_as_current_span(*args, **kwargs) as span: # type: ignore
yield span
Expand All @@ -457,7 +457,7 @@ def start_span(
# pylint: disable=unused-argument,no-self-use
return INVALID_SPAN

@contextmanager
@_agnosticcontextmanager
def start_as_current_span(
self,
name: str,
Expand Down Expand Up @@ -543,7 +543,7 @@ def get_tracer_provider() -> TracerProvider:
return cast("TracerProvider", _TRACER_PROVIDER)


@contextmanager
@_agnosticcontextmanager
def use_span(
span: Span,
end_on_exit: bool = False,
Expand Down
81 changes: 81 additions & 0 deletions opentelemetry-api/src/opentelemetry/util/_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import contextlib
import functools
import typing
from typing import Callable, Generic, Iterator, TypeVar

V = TypeVar("V")
R = TypeVar("R") # Return type
Pargs = TypeVar("Pargs") # Generic type for arguments
Pkwargs = TypeVar("Pkwargs") # Generic type for arguments

if hasattr(typing, "ParamSpec"):
# only available in python 3.10+
# https://peps.python.org/pep-0612/
P = typing.ParamSpec("P") # Generic type for all arguments


class _AgnosticContextManager(
contextlib._GeneratorContextManager, Generic[R] # type: ignore # FIXME use contextlib._GeneratorContextManager[R] when we drop the python 3.8 support
): # pylint: disable=protected-access
"""Context manager that can decorate both async and sync functions.
This is an overridden version of the contextlib._GeneratorContextManager
class that will decorate async functions with an async context manager
to end the span AFTER the entire async function coroutine finishes.
Else it will report near zero spans durations for async functions.
We are overriding the contextlib._GeneratorContextManager class as
reimplementing it is a lot of code to maintain and this class (even if it's
marked as protected) doesn't seems like to be evolving a lot.
For more information, see:
https://github.com/open-telemetry/opentelemetry-python/pull/3633
"""

def __enter__(self) -> R:
"""Reimplementing __enter__ to avoid the type error.
The original __enter__ method returns Any type, but we want to return R.
"""
del self.args, self.kwds, self.func # type: ignore
try:
return next(self.gen) # type: ignore
except StopIteration:
raise RuntimeError("generator didn't yield") from None

def __call__(self, func: V) -> V:
if asyncio.iscoroutinefunction(func):

@functools.wraps(func) # type: ignore
async def async_wrapper(*args: Pargs, **kwargs: Pkwargs) -> R:
with self._recreate_cm(): # type: ignore
return await func(*args, **kwargs) # type: ignore

return async_wrapper # type: ignore
return super().__call__(func) # type: ignore


def _agnosticcontextmanager(
func: "Callable[P, Iterator[R]]",
) -> "Callable[P, _AgnosticContextManager[R]]":
@functools.wraps(func)
def helper(*args: Pargs, **kwargs: Pkwargs) -> _AgnosticContextManager[R]:
return _AgnosticContextManager(func, args, kwargs)

return helper
4 changes: 2 additions & 2 deletions opentelemetry-api/tests/trace/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# pylint: disable=W0212,W0222,W0221
import typing
import unittest
from contextlib import contextmanager

from opentelemetry import trace
from opentelemetry.test.globals_test import TraceGlobalsTest
Expand All @@ -24,6 +23,7 @@
NonRecordingSpan,
Span,
)
from opentelemetry.util._decorator import _agnosticcontextmanager


class TestProvider(trace.NoOpTracerProvider):
Expand All @@ -40,7 +40,7 @@ class TestTracer(trace.NoOpTracer):
def start_span(self, *args, **kwargs):
return TestSpan(INVALID_SPAN_CONTEXT)

@contextmanager
@_agnosticcontextmanager # pylint: disable=protected-access
def start_as_current_span(self, *args, **kwargs): # type: ignore
with trace.use_span(self.start_span(*args, **kwargs)) as span: # type: ignore
yield span
Expand Down
37 changes: 25 additions & 12 deletions opentelemetry-api/tests/trace/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
# limitations under the License.


from contextlib import contextmanager
import asyncio
from unittest import TestCase
from unittest.mock import Mock

from opentelemetry.trace import (
INVALID_SPAN,
NoOpTracer,
Span,
Tracer,
_agnosticcontextmanager,
get_current_span,
)

Expand All @@ -39,29 +39,42 @@ def test_start_as_current_span_context_manager(self):
self.assertIsInstance(span, Span)

def test_start_as_current_span_decorator(self):

mock_call = Mock()
# using a list to track the mock call order
calls = []

class MockTracer(Tracer):
def start_span(self, *args, **kwargs):
return INVALID_SPAN

@contextmanager
@_agnosticcontextmanager # pylint: disable=protected-access
def start_as_current_span(self, *args, **kwargs): # type: ignore
mock_call()
calls.append(1)
yield INVALID_SPAN
calls.append(9)

mock_tracer = MockTracer()

# test 1 : sync function
@mock_tracer.start_as_current_span("name")
def function(): # type: ignore
pass
def function_sync(data: str) -> int:
calls.append(5)
return len(data)

function() # type: ignore
function() # type: ignore
function() # type: ignore
calls = []
res = function_sync("123")
self.assertEqual(res, 3)
self.assertEqual(calls, [1, 5, 9])

self.assertEqual(mock_call.call_count, 3)
# test 2 : async function
@mock_tracer.start_as_current_span("name")
async def function_async(data: str) -> int:
calls.append(5)
return len(data)

calls = []
res = asyncio.run(function_async("123"))
self.assertEqual(res, 3)
self.assertEqual(calls, [1, 5, 9])

def test_get_current_span(self):
with self.tracer.start_as_current_span("test") as span:
Expand Down
68 changes: 68 additions & 0 deletions opentelemetry-api/tests/util/test_contextmanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import unittest
from typing import Callable, Iterator

from opentelemetry.util._decorator import _agnosticcontextmanager


@_agnosticcontextmanager
def cm() -> Iterator[int]:
yield 3


@_agnosticcontextmanager
def cm_call_when_done(f: Callable[[], None]) -> Iterator[int]:
yield 3
f()


class TestContextManager(unittest.TestCase):
def test_sync_with(self):
with cm() as val:
self.assertEqual(val, 3)

def test_decorate_sync_func(self):
@cm()
def sync_func(a: str) -> str:
return a + a

res = sync_func("a")
self.assertEqual(res, "aa")

def test_decorate_async_func(self):
# Test that a universal context manager decorating an async function runs it's cleanup
# code after the entire async function coroutine finishes. This silently fails when
# using the normal @contextmanager decorator, which runs it's __exit__() after the
# un-started coroutine is returned.
#
# To see this behavior, change cm_call_when_done() to
# be decorated with @contextmanager.

events = []

@cm_call_when_done(lambda: events.append("cm_done"))
async def async_func(a: str) -> str:
events.append("start_async_func")
await asyncio.sleep(0)
events.append("finish_sleep")
return a + a

res = asyncio.run(async_func("a"))
self.assertEqual(res, "aa")
self.assertEqual(
events, ["start_async_func", "finish_sleep", "cm_done"]
)
4 changes: 2 additions & 2 deletions opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import threading
import traceback
import typing
from contextlib import contextmanager
from os import environ
from time import time_ns
from types import MappingProxyType, TracebackType
Expand Down Expand Up @@ -66,6 +65,7 @@
from opentelemetry.trace import SpanContext
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.util import types
from opentelemetry.util._decorator import _agnosticcontextmanager

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1038,7 +1038,7 @@ def __init__(
self._span_limits = span_limits
self._instrumentation_scope = instrumentation_scope

@contextmanager
@_agnosticcontextmanager # pylint: disable=protected-access
def start_as_current_span(
self,
name: str,
Expand Down

0 comments on commit 5a6da15

Please sign in to comment.