Skip to content

Commit

Permalink
fix: replaced contextmanager by my own agnosticcontextmanager
Browse files Browse the repository at this point in the history
  • Loading branch information
QuentinN42 committed Jan 28, 2024
1 parent 33741dc commit acdfa03
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 11 deletions.
37 changes: 34 additions & 3 deletions opentelemetry-api/src/opentelemetry/trace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@
"""


import asyncio
import contextlib
import functools
import os
import typing
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -116,6 +119,34 @@
logger = getLogger(__name__)


class _AgnosticContextManager(contextlib._GeneratorContextManager):
def __call__(self, func):
if asyncio.iscoroutinefunction(func):

@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
with self._recreate_cm():
return await func(*args, **kwargs)

return async_wrapper
else:

@functools.wraps(func)
def wrapper(*args, **kwargs):
with self._recreate_cm():
return func(*args, **kwargs)

return wrapper


def agnosticcontextmanager(func):
@functools.wraps(func)
def helper(*args, **kwds):
return _AgnosticContextManager(func, args, kwds)

return helper


class _LinkBase(ABC):
def __init__(self, context: "SpanContext") -> None:
self._context = context
Expand Down Expand Up @@ -327,7 +358,7 @@ def start_span(
The newly-created span.
"""

@contextmanager
@agnosticcontextmanager
@abstractmethod
def start_as_current_span(
self,
Expand Down Expand Up @@ -434,7 +465,7 @@ def _tracer(self) -> Tracer:
def start_span(self, *args, **kwargs) -> Span: # type: ignore
return self._tracer.start_span(*args, **kwargs) # type: ignore

@contextmanager # type: ignore
@agnosticcontextmanager # type: ignore
def start_as_current_span(self, *args, **kwargs) -> Iterator[Span]: # type: ignore
with self._tracer.start_as_current_span(*args, **kwargs) as span: # type: ignore
yield span
Expand All @@ -460,7 +491,7 @@ def start_span(
# pylint: disable=unused-argument,no-self-use
return INVALID_SPAN

@contextmanager
@agnosticcontextmanager
def start_as_current_span(
self,
name: str,
Expand Down
3 changes: 1 addition & 2 deletions opentelemetry-api/tests/trace/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# pylint: disable=W0212,W0222,W0221
import typing
import unittest
from contextlib import contextmanager

from opentelemetry import trace
from opentelemetry.test.globals_test import TraceGlobalsTest
Expand All @@ -40,7 +39,7 @@ class TestTracer(trace.NoOpTracer):
def start_span(self, *args, **kwargs):
return TestSpan(INVALID_SPAN_CONTEXT)

@contextmanager
@trace.agnosticcontextmanager
def start_as_current_span(self, *args, **kwargs): # type: ignore
with trace.use_span(self.start_span(*args, **kwargs)) as span: # type: ignore
yield span
Expand Down
8 changes: 4 additions & 4 deletions opentelemetry-api/tests/trace/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
# limitations under the License.


import time
import asyncio
from contextlib import contextmanager
import time
from unittest import TestCase
from unittest.mock import Mock

Expand All @@ -24,6 +23,7 @@
NoOpTracer,
Span,
Tracer,
agnosticcontextmanager,
get_current_span,
)

Expand All @@ -47,7 +47,7 @@ class MockTracer(Tracer):
def start_span(self, *args, **kwargs):
return INVALID_SPAN

@contextmanager
@agnosticcontextmanager
def start_as_current_span(self, *args, **kwargs): # type: ignore
mock_call()
yield INVALID_SPAN
Expand Down Expand Up @@ -80,7 +80,7 @@ class MockTracer(Tracer):
def start_span(self, *args, **kwargs):
return INVALID_SPAN

@contextmanager
@agnosticcontextmanager
def start_as_current_span(self, *args, **kwargs): # type: ignore
mock_call()
i = time.monotonic()
Expand Down
3 changes: 1 addition & 2 deletions opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import threading
import traceback
import typing
from contextlib import contextmanager
from os import environ
from time import time_ns
from types import MappingProxyType, TracebackType
Expand Down Expand Up @@ -1014,7 +1013,7 @@ def __init__(
self._span_limits = span_limits
self._instrumentation_scope = instrumentation_scope

@contextmanager
@trace_api.agnosticcontextmanager
def start_as_current_span(
self,
name: str,
Expand Down

0 comments on commit acdfa03

Please sign in to comment.