diff --git a/opentelemetry-api/src/opentelemetry/context/__init__.py b/opentelemetry-api/src/opentelemetry/context/__init__.py index 1d1b53e7cb2..1ac837f51c5 100644 --- a/opentelemetry-api/src/opentelemetry/context/__init__.py +++ b/opentelemetry-api/src/opentelemetry/context/__init__.py @@ -14,6 +14,7 @@ import logging import typing +from functools import wraps from os import environ from sys import version_info @@ -25,6 +26,47 @@ _RUNTIME_CONTEXT = None # type: typing.Optional[RuntimeContext] +_F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any]) + + +def _load_runtime_context(func: _F) -> _F: + """A decorator used to initialize the global RuntimeContext + + Returns: + A wrapper of the decorated method. + """ + + @wraps(func) # type: ignore + def wrapper( + *args: typing.Tuple[typing.Any, typing.Any], + **kwargs: typing.Dict[typing.Any, typing.Any] + ) -> typing.Optional[typing.Any]: + global _RUNTIME_CONTEXT # pylint: disable=global-statement + if _RUNTIME_CONTEXT is None: + # FIXME use a better implementation of a configuration manager to avoid having + # to get configuration values straight from environment variables + if version_info < (3, 5): + # contextvars are not supported in 3.4, use thread-local storage + default_context = "threadlocal_context" + else: + default_context = "contextvars_context" + + configured_context = environ.get( + "OPENTELEMETRY_CONTEXT", default_context + ) # type: str + try: + _RUNTIME_CONTEXT = next( + iter_entry_points( + "opentelemetry_context", configured_context + ) + ).load()() + except Exception: # pylint: disable=broad-except + logger.error("Failed to load context: %s", configured_context) + return func(*args, **kwargs) # type: ignore + + return wrapper # type:ignore + + def get_value(key: str, context: typing.Optional[Context] = None) -> "object": """To access the local state of a concern, the RuntimeContext API provides a function which takes a context and a key as input, @@ -33,6 +75,9 @@ def get_value(key: str, context: typing.Optional[Context] = None) -> "object": Args: key: The key of the value to retrieve. context: The context from which to retrieve the value, if None, the current context is used. + + Returns: + The value associated with the key. """ return context.get(key) if context is not None else get_current().get(key) @@ -46,91 +91,55 @@ def set_value( which contains the new value. Args: - key: The key of the entry to set - value: The value of the entry to set - context: The context to copy, if None, the current context is used - """ - if context is None: - context = get_current() - new_values = context.copy() - new_values[key] = value - return Context(new_values) + key: The key of the entry to set. + value: The value of the entry to set. + context: The context to copy, if None, the current context is used. - -def remove_value( - key: str, context: typing.Optional[Context] = None -) -> Context: - """To remove a value, this method returns a new context with the key - cleared. Note that the removed value still remains present in the old - context. - - Args: - key: The key of the entry to remove - context: The context to copy, if None, the current context is used + Returns: + A new `Context` containing the value set. """ if context is None: context = get_current() new_values = context.copy() - new_values.pop(key, None) + new_values[key] = value return Context(new_values) +@_load_runtime_context # type: ignore def get_current() -> Context: """To access the context associated with program execution, - the RuntimeContext API provides a function which takes no arguments - and returns a RuntimeContext. - """ - - global _RUNTIME_CONTEXT # pylint: disable=global-statement - if _RUNTIME_CONTEXT is None: - # FIXME use a better implementation of a configuration manager to avoid having - # to get configuration values straight from environment variables - if version_info < (3, 5): - # contextvars are not supported in 3.4, use thread-local storage - default_context = "threadlocal_context" - else: - default_context = "contextvars_context" - - configured_context = environ.get( - "OPENTELEMETRY_CONTEXT", default_context - ) # type: str - try: - _RUNTIME_CONTEXT = next( - iter_entry_points("opentelemetry_context", configured_context) - ).load()() - except Exception: # pylint: disable=broad-except - logger.error("Failed to load context: %s", configured_context) + the Context API provides a function which takes no arguments + and returns a Context. + Returns: + The current `Context` object. + """ return _RUNTIME_CONTEXT.get_current() # type:ignore -def set_current(context: Context) -> Context: - """To associate a context with program execution, the Context - API provides a function which takes a Context. +@_load_runtime_context # type: ignore +def attach(context: Context) -> object: + """Associates a Context with the caller's current execution unit. Returns + a token that can be used to restore the previous Context. Args: - context: The context to use as current. - """ - old_context = get_current() - _RUNTIME_CONTEXT.set_current(context) # type:ignore - return old_context - + context: The Context to set as current. -def with_current_context( - func: typing.Callable[..., "object"] -) -> typing.Callable[..., "object"]: - """Capture the current context and apply it to the provided func.""" + Returns: + A token that can be used with `detach` to reset the context. + """ + return _RUNTIME_CONTEXT.attach(context) # type:ignore - caller_context = get_current() - def call_with_current_context( - *args: "object", **kwargs: "object" - ) -> "object": - try: - backup = get_current() - set_current(caller_context) - return func(*args, **kwargs) - finally: - set_current(backup) +@_load_runtime_context # type: ignore +def detach(token: object) -> None: + """Resets the Context associated with the caller's current execution unit + to the value it had before attaching a specified Context. - return call_with_current_context + Args: + token: The Token that was returned by a previous call to attach a Context. + """ + try: + _RUNTIME_CONTEXT.detach(token) # type: ignore + except Exception: # pylint: disable=broad-except + logger.error("Failed to detach context") diff --git a/opentelemetry-api/src/opentelemetry/context/context.py b/opentelemetry-api/src/opentelemetry/context/context.py index 148312a884c..1c7cfba9634 100644 --- a/opentelemetry-api/src/opentelemetry/context/context.py +++ b/opentelemetry-api/src/opentelemetry/context/context.py @@ -29,8 +29,9 @@ class RuntimeContext(ABC): """ @abstractmethod - def set_current(self, context: Context) -> None: - """ Sets the current `Context` object. + def attach(self, context: Context) -> object: + """ Sets the current `Context` object. Returns a + token that can be used to reset to the previous `Context`. Args: context: The Context to set. @@ -40,5 +41,13 @@ def set_current(self, context: Context) -> None: def get_current(self) -> Context: """ Returns the current `Context` object. """ + @abstractmethod + def detach(self, token: object) -> None: + """ Resets Context to a previous value + + Args: + token: A reference to a previous Context. + """ + __all__ = ["Context", "RuntimeContext"] diff --git a/opentelemetry-api/src/opentelemetry/context/contextvars_context.py b/opentelemetry-api/src/opentelemetry/context/contextvars_context.py index 1fd202275a3..0d075e0776a 100644 --- a/opentelemetry-api/src/opentelemetry/context/contextvars_context.py +++ b/opentelemetry-api/src/opentelemetry/context/contextvars_context.py @@ -35,13 +35,17 @@ def __init__(self) -> None: self._CONTEXT_KEY, default=Context() ) - def set_current(self, context: Context) -> None: - """See `opentelemetry.context.RuntimeContext.set_current`.""" - self._current_context.set(context) + def attach(self, context: Context) -> object: + """See `opentelemetry.context.RuntimeContext.attach`.""" + return self._current_context.set(context) def get_current(self) -> Context: """See `opentelemetry.context.RuntimeContext.get_current`.""" return self._current_context.get() + def detach(self, token: object) -> None: + """See `opentelemetry.context.RuntimeContext.detach`.""" + self._current_context.reset(token) # type: ignore + __all__ = ["ContextVarsRuntimeContext"] diff --git a/opentelemetry-api/src/opentelemetry/context/threadlocal_context.py b/opentelemetry-api/src/opentelemetry/context/threadlocal_context.py index 899ab863262..6a0e76bb693 100644 --- a/opentelemetry-api/src/opentelemetry/context/threadlocal_context.py +++ b/opentelemetry-api/src/opentelemetry/context/threadlocal_context.py @@ -23,14 +23,20 @@ class ThreadLocalRuntimeContext(RuntimeContext): implementation is available for usage with Python 3.4. """ + class Token: + def __init__(self, context: Context) -> None: + self._context = context + _CONTEXT_KEY = "current_context" def __init__(self) -> None: self._current_context = threading.local() - def set_current(self, context: Context) -> None: - """See `opentelemetry.context.RuntimeContext.set_current`.""" + def attach(self, context: Context) -> object: + """See `opentelemetry.context.RuntimeContext.attach`.""" + current = self.get_current() setattr(self._current_context, self._CONTEXT_KEY, context) + return self.Token(current) def get_current(self) -> Context: """See `opentelemetry.context.RuntimeContext.get_current`.""" @@ -43,5 +49,12 @@ def get_current(self) -> Context: ) # type: Context return context + def detach(self, token: object) -> None: + """See `opentelemetry.context.RuntimeContext.detach`.""" + if not isinstance(token, self.Token): + raise ValueError("invalid token") + # pylint: disable=protected-access + setattr(self._current_context, self._CONTEXT_KEY, token._context) + __all__ = ["ThreadLocalRuntimeContext"] diff --git a/opentelemetry-api/src/opentelemetry/distributedcontext/__init__.py b/opentelemetry-api/src/opentelemetry/distributedcontext/__init__.py index a89d9825502..dbc7b7e79bd 100644 --- a/opentelemetry-api/src/opentelemetry/distributedcontext/__init__.py +++ b/opentelemetry-api/src/opentelemetry/distributedcontext/__init__.py @@ -17,7 +17,7 @@ import typing from contextlib import contextmanager -from opentelemetry.context import get_value, set_current, set_value +from opentelemetry.context import attach, get_value, set_value from opentelemetry.context.context import Context PRINTABLE = frozenset( @@ -142,4 +142,4 @@ def distributed_context_from_context( def with_distributed_context( dctx: DistributedContext, context: typing.Optional[Context] = None ) -> None: - set_current(set_value(_DISTRIBUTED_CONTEXT_KEY, dctx, context=context)) + attach(set_value(_DISTRIBUTED_CONTEXT_KEY, dctx, context=context)) diff --git a/opentelemetry-api/tests/context/base_context.py b/opentelemetry-api/tests/context/base_context.py new file mode 100644 index 00000000000..66e6df97a2d --- /dev/null +++ b/opentelemetry-api/tests/context/base_context.py @@ -0,0 +1,77 @@ +# Copyright 2020, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from logging import ERROR + +from opentelemetry import context + + +def do_work() -> None: + context.attach(context.set_value("say", "bar")) + + +class ContextTestCases: + class BaseTest(unittest.TestCase): + def setUp(self) -> None: + self.previous_context = context.get_current() + + def tearDown(self) -> None: + context.attach(self.previous_context) + + def test_context(self): + self.assertIsNone(context.get_value("say")) + empty = context.get_current() + second = context.set_value("say", "foo") + + self.assertEqual(context.get_value("say", context=second), "foo") + + do_work() + self.assertEqual(context.get_value("say"), "bar") + third = context.get_current() + + self.assertIsNone(context.get_value("say", context=empty)) + self.assertEqual(context.get_value("say", context=second), "foo") + self.assertEqual(context.get_value("say", context=third), "bar") + + def test_set_value(self): + first = context.set_value("a", "yyy") + second = context.set_value("a", "zzz") + third = context.set_value("a", "---", first) + self.assertEqual("yyy", context.get_value("a", context=first)) + self.assertEqual("zzz", context.get_value("a", context=second)) + self.assertEqual("---", context.get_value("a", context=third)) + self.assertEqual(None, context.get_value("a")) + + def test_attach(self): + context.attach(context.set_value("a", "yyy")) + + token = context.attach(context.set_value("a", "zzz")) + self.assertEqual("zzz", context.get_value("a")) + + context.detach(token) + self.assertEqual("yyy", context.get_value("a")) + + with self.assertLogs(level=ERROR): + context.detach("some garbage") + + def test_detach_out_of_order(self): + t1 = context.attach(context.set_value("c", 1)) + self.assertEqual(context.get_current(), {"c": 1}) + t2 = context.attach(context.set_value("c", 2)) + self.assertEqual(context.get_current(), {"c": 2}) + context.detach(t1) + self.assertEqual(context.get_current(), {}) + context.detach(t2) + self.assertEqual(context.get_current(), {"c": 1}) diff --git a/opentelemetry-api/tests/context/test_context.py b/opentelemetry-api/tests/context/test_context.py index 2536e5149be..8942a333ed6 100644 --- a/opentelemetry-api/tests/context/test_context.py +++ b/opentelemetry-api/tests/context/test_context.py @@ -19,12 +19,12 @@ def do_work() -> None: - context.set_current(context.set_value("say", "bar")) + context.attach(context.set_value("say", "bar")) class TestContext(unittest.TestCase): def setUp(self): - context.set_current(Context()) + context.attach(Context()) def test_context(self): self.assertIsNone(context.get_value("say")) @@ -55,11 +55,10 @@ def test_context_is_immutable(self): context.get_current()["test"] = "cant-change-immutable" def test_set_current(self): - context.set_current(context.set_value("a", "yyy")) + context.attach(context.set_value("a", "yyy")) - old_context = context.set_current(context.set_value("a", "zzz")) - self.assertEqual("yyy", context.get_value("a", context=old_context)) + token = context.attach(context.set_value("a", "zzz")) self.assertEqual("zzz", context.get_value("a")) - context.set_current(old_context) + context.detach(token) self.assertEqual("yyy", context.get_value("a")) diff --git a/opentelemetry-api/tests/context/test_contextvars_context.py b/opentelemetry-api/tests/context/test_contextvars_context.py index ebc15d6d9a3..d19ac5ca126 100644 --- a/opentelemetry-api/tests/context/test_contextvars_context.py +++ b/opentelemetry-api/tests/context/test_contextvars_context.py @@ -17,6 +17,8 @@ from opentelemetry import context +from .base_context import ContextTestCases + try: import contextvars # pylint: disable=unused-import from opentelemetry.context.contextvars_context import ( @@ -26,43 +28,14 @@ raise unittest.SkipTest("contextvars not available") -def do_work() -> None: - context.set_current(context.set_value("say", "bar")) - - -class TestContextVarsContext(unittest.TestCase): - def setUp(self): - self.previous_context = context.get_current() - - def tearDown(self): - context.set_current(self.previous_context) - - @patch( - "opentelemetry.context._RUNTIME_CONTEXT", ContextVarsRuntimeContext() # type: ignore - ) - def test_context(self): - self.assertIsNone(context.get_value("say")) - empty = context.get_current() - second = context.set_value("say", "foo") - - self.assertEqual(context.get_value("say", context=second), "foo") - - do_work() - self.assertEqual(context.get_value("say"), "bar") - third = context.get_current() +class TestContextVarsContext(ContextTestCases.BaseTest): + def setUp(self) -> None: + super(TestContextVarsContext, self).setUp() + self.mock_runtime = patch.object( + context, "_RUNTIME_CONTEXT", ContextVarsRuntimeContext(), + ) + self.mock_runtime.start() - self.assertIsNone(context.get_value("say", context=empty)) - self.assertEqual(context.get_value("say", context=second), "foo") - self.assertEqual(context.get_value("say", context=third), "bar") - - @patch( - "opentelemetry.context._RUNTIME_CONTEXT", ContextVarsRuntimeContext() # type: ignore - ) - def test_set_value(self): - first = context.set_value("a", "yyy") - second = context.set_value("a", "zzz") - third = context.set_value("a", "---", first) - self.assertEqual("yyy", context.get_value("a", context=first)) - self.assertEqual("zzz", context.get_value("a", context=second)) - self.assertEqual("---", context.get_value("a", context=third)) - self.assertEqual(None, context.get_value("a")) + def tearDown(self) -> None: + super(TestContextVarsContext, self).tearDown() + self.mock_runtime.stop() diff --git a/opentelemetry-api/tests/context/test_threadlocal_context.py b/opentelemetry-api/tests/context/test_threadlocal_context.py index aca6b69de72..342163020ed 100644 --- a/opentelemetry-api/tests/context/test_threadlocal_context.py +++ b/opentelemetry-api/tests/context/test_threadlocal_context.py @@ -12,50 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest from unittest.mock import patch from opentelemetry import context from opentelemetry.context.threadlocal_context import ThreadLocalRuntimeContext +from .base_context import ContextTestCases -def do_work() -> None: - context.set_current(context.set_value("say", "bar")) +class TestThreadLocalContext(ContextTestCases.BaseTest): + def setUp(self) -> None: + super(TestThreadLocalContext, self).setUp() + self.mock_runtime = patch.object( + context, "_RUNTIME_CONTEXT", ThreadLocalRuntimeContext(), + ) + self.mock_runtime.start() -class TestThreadLocalContext(unittest.TestCase): - def setUp(self): - self.previous_context = context.get_current() - - def tearDown(self): - context.set_current(self.previous_context) - - @patch( - "opentelemetry.context._RUNTIME_CONTEXT", ThreadLocalRuntimeContext() # type: ignore - ) - def test_context(self): - self.assertIsNone(context.get_value("say")) - empty = context.get_current() - second = context.set_value("say", "foo") - - self.assertEqual(context.get_value("say", context=second), "foo") - - do_work() - self.assertEqual(context.get_value("say"), "bar") - third = context.get_current() - - self.assertIsNone(context.get_value("say", context=empty)) - self.assertEqual(context.get_value("say", context=second), "foo") - self.assertEqual(context.get_value("say", context=third), "bar") - - @patch( - "opentelemetry.context._RUNTIME_CONTEXT", ThreadLocalRuntimeContext() # type: ignore - ) - def test_set_value(self): - first = context.set_value("a", "yyy") - second = context.set_value("a", "zzz") - third = context.set_value("a", "---", first) - self.assertEqual("yyy", context.get_value("a", context=first)) - self.assertEqual("zzz", context.get_value("a", context=second)) - self.assertEqual("---", context.get_value("a", context=third)) - self.assertEqual(None, context.get_value("a")) + def tearDown(self) -> None: + super(TestThreadLocalContext, self).tearDown() + self.mock_runtime.stop() diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py index 7ce0ea3a836..dd0169ea9f7 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py @@ -543,12 +543,11 @@ def use_span( ) -> Iterator[trace_api.Span]: """See `opentelemetry.trace.Tracer.use_span`.""" try: - context_snapshot = context_api.get_current() - context_api.set_current(context_api.set_value(SPAN_KEY, span)) + token = context_api.attach(context_api.set_value(SPAN_KEY, span)) try: yield span finally: - context_api.set_current(context_snapshot) + context_api.detach(token) except Exception as error: # pylint: disable=broad-except if ( diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py index 0a1b1c8041d..0f96808ea88 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py @@ -19,7 +19,7 @@ import typing from enum import Enum -from opentelemetry.context import get_current, set_current, set_value +from opentelemetry.context import attach, detach, get_current, set_value from opentelemetry.trace import DefaultSpan from opentelemetry.util import time_ns @@ -75,14 +75,13 @@ def on_start(self, span: Span) -> None: pass def on_end(self, span: Span) -> None: - backup_context = get_current() - set_current(set_value("suppress_instrumentation", True)) + token = attach(set_value("suppress_instrumentation", True)) try: self.span_exporter.export((span,)) # pylint: disable=broad-except except Exception: logger.exception("Exception while exporting Span.") - set_current(backup_context) + detach(token) def shutdown(self) -> None: self.span_exporter.shutdown() @@ -202,8 +201,7 @@ def export(self) -> None: else: self.spans_list[idx] = span idx += 1 - backup_context = get_current() - set_current(set_value("suppress_instrumentation", True)) + token = attach(set_value("suppress_instrumentation", True)) try: # Ignore type b/c the Optional[None]+slicing is too "clever" # for mypy @@ -211,7 +209,7 @@ def export(self) -> None: # pylint: disable=broad-except except Exception: logger.exception("Exception while exporting Span batch.") - set_current(backup_context) + detach(token) if notify_flush: with self.flush_condition: diff --git a/opentelemetry-sdk/tests/context/test_asyncio.py b/opentelemetry-sdk/tests/context/test_asyncio.py index 22773a80cd6..ea7ebbddbf8 100644 --- a/opentelemetry-sdk/tests/context/test_asyncio.py +++ b/opentelemetry-sdk/tests/context/test_asyncio.py @@ -63,8 +63,7 @@ def submit_another_task(self, name): self.loop.create_task(self.task(name)) def setUp(self): - self.previous_context = context.get_current() - context.set_current(context.Context()) + self.token = context.attach(context.Context()) self.tracer_provider = trace.TracerProvider() self.tracer = self.tracer_provider.get_tracer(__name__) self.memory_exporter = InMemorySpanExporter() @@ -73,7 +72,7 @@ def setUp(self): self.loop = asyncio.get_event_loop() def tearDown(self): - context.set_current(self.previous_context) + context.detach(self.token) @patch( "opentelemetry.context._RUNTIME_CONTEXT", ContextVarsRuntimeContext()