diff --git a/.pylintrc b/.pylintrc index ab11620d772..4f37a7e10f0 100644 --- a/.pylintrc +++ b/.pylintrc @@ -166,7 +166,7 @@ notes=FIXME, # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. -contextmanager-decorators=contextlib.contextmanager +contextmanager-decorators=contextlib.contextmanager, _agnosticcontextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular diff --git a/CHANGELOG.md b/CHANGELOG.md index 50e6f2125d1..ca3d0687bf1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- Make `tracer.start_as_current_span()` decorator work with async functions + ([#3633](https://github.com/open-telemetry/opentelemetry-python/pull/3633)) - Fix python 3.12 deprecation warning ([#3751](https://github.com/open-telemetry/opentelemetry-python/pull/3751)) - bump mypy to 0.982 diff --git a/opentelemetry-api/src/opentelemetry/trace/__init__.py b/opentelemetry-api/src/opentelemetry/trace/__init__.py index 8910fd27518..3b6295e259d 100644 --- a/opentelemetry-api/src/opentelemetry/trace/__init__.py +++ b/opentelemetry-api/src/opentelemetry/trace/__init__.py @@ -77,7 +77,6 @@ import os import typing from abc import ABC, abstractmethod -from contextlib import contextmanager from enum import Enum from logging import getLogger from typing import Iterator, Optional, Sequence, cast @@ -109,6 +108,7 @@ ) from opentelemetry.trace.status import Status, StatusCode from opentelemetry.util import types +from opentelemetry.util._decorator import _agnosticcontextmanager from opentelemetry.util._once import Once from opentelemetry.util._providers import _load_provider @@ -324,7 +324,7 @@ def start_span( The newly-created span. """ - @contextmanager + @_agnosticcontextmanager @abstractmethod def start_as_current_span( self, @@ -431,7 +431,7 @@ def _tracer(self) -> Tracer: def start_span(self, *args, **kwargs) -> Span: # type: ignore return self._tracer.start_span(*args, **kwargs) # type: ignore - @contextmanager # type: ignore + @_agnosticcontextmanager # type: ignore def start_as_current_span(self, *args, **kwargs) -> Iterator[Span]: with self._tracer.start_as_current_span(*args, **kwargs) as span: # type: ignore yield span @@ -457,7 +457,7 @@ def start_span( # pylint: disable=unused-argument,no-self-use return INVALID_SPAN - @contextmanager + @_agnosticcontextmanager def start_as_current_span( self, name: str, @@ -543,7 +543,7 @@ def get_tracer_provider() -> TracerProvider: return cast("TracerProvider", _TRACER_PROVIDER) -@contextmanager +@_agnosticcontextmanager def use_span( span: Span, end_on_exit: bool = False, diff --git a/opentelemetry-api/src/opentelemetry/util/_decorator.py b/opentelemetry-api/src/opentelemetry/util/_decorator.py new file mode 100644 index 00000000000..233f29ff79d --- /dev/null +++ b/opentelemetry-api/src/opentelemetry/util/_decorator.py @@ -0,0 +1,81 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import contextlib +import functools +import typing +from typing import Callable, Generic, Iterator, TypeVar + +V = TypeVar("V") +R = TypeVar("R") # Return type +Pargs = TypeVar("Pargs") # Generic type for arguments +Pkwargs = TypeVar("Pkwargs") # Generic type for arguments + +if hasattr(typing, "ParamSpec"): + # only available in python 3.10+ + # https://peps.python.org/pep-0612/ + P = typing.ParamSpec("P") # Generic type for all arguments + + +class _AgnosticContextManager( + contextlib._GeneratorContextManager, Generic[R] # type: ignore # FIXME use contextlib._GeneratorContextManager[R] when we drop the python 3.8 support +): # pylint: disable=protected-access + """Context manager that can decorate both async and sync functions. + + This is an overridden version of the contextlib._GeneratorContextManager + class that will decorate async functions with an async context manager + to end the span AFTER the entire async function coroutine finishes. + + Else it will report near zero spans durations for async functions. + + We are overriding the contextlib._GeneratorContextManager class as + reimplementing it is a lot of code to maintain and this class (even if it's + marked as protected) doesn't seems like to be evolving a lot. + + For more information, see: + https://github.com/open-telemetry/opentelemetry-python/pull/3633 + """ + + def __enter__(self) -> R: + """Reimplementing __enter__ to avoid the type error. + + The original __enter__ method returns Any type, but we want to return R. + """ + del self.args, self.kwds, self.func # type: ignore + try: + return next(self.gen) # type: ignore + except StopIteration: + raise RuntimeError("generator didn't yield") from None + + def __call__(self, func: V) -> V: + if asyncio.iscoroutinefunction(func): + + @functools.wraps(func) # type: ignore + async def async_wrapper(*args: Pargs, **kwargs: Pkwargs) -> R: + with self._recreate_cm(): # type: ignore + return await func(*args, **kwargs) # type: ignore + + return async_wrapper # type: ignore + return super().__call__(func) # type: ignore + + +def _agnosticcontextmanager( + func: "Callable[P, Iterator[R]]", +) -> "Callable[P, _AgnosticContextManager[R]]": + @functools.wraps(func) + def helper(*args: Pargs, **kwargs: Pkwargs) -> _AgnosticContextManager[R]: + return _AgnosticContextManager(func, args, kwargs) + + return helper diff --git a/opentelemetry-api/tests/trace/test_proxy.py b/opentelemetry-api/tests/trace/test_proxy.py index e48a2157aec..8c367afb6da 100644 --- a/opentelemetry-api/tests/trace/test_proxy.py +++ b/opentelemetry-api/tests/trace/test_proxy.py @@ -15,7 +15,6 @@ # pylint: disable=W0212,W0222,W0221 import typing import unittest -from contextlib import contextmanager from opentelemetry import trace from opentelemetry.test.globals_test import TraceGlobalsTest @@ -24,6 +23,7 @@ NonRecordingSpan, Span, ) +from opentelemetry.util._decorator import _agnosticcontextmanager class TestProvider(trace.NoOpTracerProvider): @@ -40,7 +40,7 @@ class TestTracer(trace.NoOpTracer): def start_span(self, *args, **kwargs): return TestSpan(INVALID_SPAN_CONTEXT) - @contextmanager + @_agnosticcontextmanager # pylint: disable=protected-access def start_as_current_span(self, *args, **kwargs): # type: ignore with trace.use_span(self.start_span(*args, **kwargs)) as span: # type: ignore yield span diff --git a/opentelemetry-api/tests/trace/test_tracer.py b/opentelemetry-api/tests/trace/test_tracer.py index a7ad589ae60..fae836d564f 100644 --- a/opentelemetry-api/tests/trace/test_tracer.py +++ b/opentelemetry-api/tests/trace/test_tracer.py @@ -13,15 +13,15 @@ # limitations under the License. -from contextlib import contextmanager +import asyncio from unittest import TestCase -from unittest.mock import Mock from opentelemetry.trace import ( INVALID_SPAN, NoOpTracer, Span, Tracer, + _agnosticcontextmanager, get_current_span, ) @@ -39,29 +39,42 @@ def test_start_as_current_span_context_manager(self): self.assertIsInstance(span, Span) def test_start_as_current_span_decorator(self): - - mock_call = Mock() + # using a list to track the mock call order + calls = [] class MockTracer(Tracer): def start_span(self, *args, **kwargs): return INVALID_SPAN - @contextmanager + @_agnosticcontextmanager # pylint: disable=protected-access def start_as_current_span(self, *args, **kwargs): # type: ignore - mock_call() + calls.append(1) yield INVALID_SPAN + calls.append(9) mock_tracer = MockTracer() + # test 1 : sync function @mock_tracer.start_as_current_span("name") - def function(): # type: ignore - pass + def function_sync(data: str) -> int: + calls.append(5) + return len(data) - function() # type: ignore - function() # type: ignore - function() # type: ignore + calls = [] + res = function_sync("123") + self.assertEqual(res, 3) + self.assertEqual(calls, [1, 5, 9]) - self.assertEqual(mock_call.call_count, 3) + # test 2 : async function + @mock_tracer.start_as_current_span("name") + async def function_async(data: str) -> int: + calls.append(5) + return len(data) + + calls = [] + res = asyncio.run(function_async("123")) + self.assertEqual(res, 3) + self.assertEqual(calls, [1, 5, 9]) def test_get_current_span(self): with self.tracer.start_as_current_span("test") as span: diff --git a/opentelemetry-api/tests/util/test_contextmanager.py b/opentelemetry-api/tests/util/test_contextmanager.py new file mode 100644 index 00000000000..f26882c6c79 --- /dev/null +++ b/opentelemetry-api/tests/util/test_contextmanager.py @@ -0,0 +1,68 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import unittest +from typing import Callable, Iterator + +from opentelemetry.util._decorator import _agnosticcontextmanager + + +@_agnosticcontextmanager +def cm() -> Iterator[int]: + yield 3 + + +@_agnosticcontextmanager +def cm_call_when_done(f: Callable[[], None]) -> Iterator[int]: + yield 3 + f() + + +class TestContextManager(unittest.TestCase): + def test_sync_with(self): + with cm() as val: + self.assertEqual(val, 3) + + def test_decorate_sync_func(self): + @cm() + def sync_func(a: str) -> str: + return a + a + + res = sync_func("a") + self.assertEqual(res, "aa") + + def test_decorate_async_func(self): + # Test that a universal context manager decorating an async function runs it's cleanup + # code after the entire async function coroutine finishes. This silently fails when + # using the normal @contextmanager decorator, which runs it's __exit__() after the + # un-started coroutine is returned. + # + # To see this behavior, change cm_call_when_done() to + # be decorated with @contextmanager. + + events = [] + + @cm_call_when_done(lambda: events.append("cm_done")) + async def async_func(a: str) -> str: + events.append("start_async_func") + await asyncio.sleep(0) + events.append("finish_sleep") + return a + a + + res = asyncio.run(async_func("a")) + self.assertEqual(res, "aa") + self.assertEqual( + events, ["start_async_func", "finish_sleep", "cm_done"] + ) diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py index f937ece8958..7c0a194c82a 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py @@ -21,7 +21,6 @@ import threading import traceback import typing -from contextlib import contextmanager from os import environ from time import time_ns from types import MappingProxyType, TracebackType @@ -66,6 +65,7 @@ from opentelemetry.trace import SpanContext from opentelemetry.trace.status import Status, StatusCode from opentelemetry.util import types +from opentelemetry.util._decorator import _agnosticcontextmanager logger = logging.getLogger(__name__) @@ -1038,7 +1038,7 @@ def __init__( self._span_limits = span_limits self._instrumentation_scope = instrumentation_scope - @contextmanager + @_agnosticcontextmanager # pylint: disable=protected-access def start_as_current_span( self, name: str,