Skip to content

Commit

Permalink
Type hint edits for #57 (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
c24t authored and reyang committed Jul 25, 2019
1 parent 123887e commit 1f93a36
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 33 deletions.
2 changes: 1 addition & 1 deletion mypy-relaxed.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[mypy]
disallow_any_unimported = True
; disallow_any_expr = True
; disallow_any_decorated = True
disallow_any_decorated = True
; disallow_any_explicit = True
disallow_any_generics = True
disallow_subclassing_any = True
Expand Down
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[mypy]
disallow_any_unimported = True
disallow_any_expr = True
; disallow_any_decorated = True
disallow_any_decorated = True
; disallow_any_explicit = True
disallow_any_generics = True
disallow_subclassing_any = True
Expand Down
4 changes: 3 additions & 1 deletion opentelemetry-api/src/opentelemetry/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

from .base_context import BaseRuntimeContext


__all__ = ['Context']

Context: typing.Union[BaseRuntimeContext, None] = None

Context: typing.Optional[BaseRuntimeContext]

try:
from .async_context import AsyncRuntimeContext
Expand Down
19 changes: 10 additions & 9 deletions opentelemetry-api/src/opentelemetry/context/async_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextvars import ContextVar
import typing
from contextvars import ContextVar

from .base_context import BaseRuntimeContext
from . import base_context


class AsyncRuntimeContext(BaseRuntimeContext):
class Slot(BaseRuntimeContext.Slot):
def __init__(self, name: str, default: typing.Any):
class AsyncRuntimeContext(base_context.BaseRuntimeContext):
class Slot(base_context.BaseRuntimeContext.Slot):
def __init__(self, name: str, default: 'object'):
# pylint: disable=super-init-not-called
self.name = name
self.contextvar: typing.Any = ContextVar(name)
self.default = default if callable(default) else (lambda: default)
self.contextvar: 'ContextVar[object]' = ContextVar(name)
self.default: typing.Callable[..., object]
self.default = base_context.wrap_callable(default)

def clear(self) -> None:
self.contextvar.set(self.default())

def get(self) -> typing.Any:
def get(self) -> 'object':
try:
return self.contextvar.get()
except LookupError:
value = self.default()
self.set(value)
return value

def set(self, value: typing.Any) -> None:
def set(self, value: 'object') -> None:
self.contextvar.set(value)
32 changes: 19 additions & 13 deletions opentelemetry-api/src/opentelemetry/context/base_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,24 @@
import typing


def wrap_callable(target: 'object') -> typing.Callable[[], object]:
if callable(target):
return target
return lambda: target


class BaseRuntimeContext:
class Slot:
def __init__(self, name: str, default: typing.Any):
def __init__(self, name: str, default: 'object'):
raise NotImplementedError

def clear(self) -> None:
raise NotImplementedError

def get(self) -> typing.Any:
def get(self) -> 'object':
raise NotImplementedError

def set(self, value: typing.Any) -> None:
def set(self, value: 'object') -> None:
raise NotImplementedError

_lock = threading.Lock()
Expand All @@ -42,7 +48,7 @@ def clear(cls) -> None:
slot.clear()

@classmethod
def register_slot(cls, name: str, default: typing.Any = None) -> 'Slot':
def register_slot(cls, name: str, default: 'object' = None) -> 'Slot':
"""Register a context slot with an optional default value.
:type name: str
Expand All @@ -60,13 +66,13 @@ def register_slot(cls, name: str, default: typing.Any = None) -> 'Slot':
cls._slots[name] = slot
return slot

def apply(self, snapshot: typing.Dict[str, typing.Any]) -> None:
def apply(self, snapshot: typing.Dict[str, 'object']) -> None:
"""Set the current context from a given snapshot dictionary"""

for name in snapshot:
setattr(self, name, snapshot[name])

def snapshot(self) -> typing.Dict[str, typing.Any]:
def snapshot(self) -> typing.Dict[str, 'object']:
"""Return a dictionary of current slots by reference."""

keys = self._slots.keys()
Expand All @@ -75,31 +81,31 @@ def snapshot(self) -> typing.Dict[str, typing.Any]:
def __repr__(self) -> str:
return '{}({})'.format(type(self).__name__, self.snapshot())

def __getattr__(self, name: str) -> typing.Any:
def __getattr__(self, name: str) -> 'object':
if name not in self._slots:
self.register_slot(name, None)
slot = self._slots[name]
return slot.get()

def __setattr__(self, name: str, value: typing.Any) -> None:
def __setattr__(self, name: str, value: 'object') -> None:
if name not in self._slots:
self.register_slot(name, None)
slot = self._slots[name]
slot.set(value)

def with_current_context(
self,
func: typing.Callable[..., typing.Any],
) -> typing.Callable[..., typing.Any]:
func: typing.Callable[..., 'object'],
) -> typing.Callable[..., 'object']:
"""Capture the current context and apply it to the provided func.
"""

caller_context = self.snapshot()

def call_with_current_context(
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.Any:
*args: 'object',
**kwargs: 'object',
) -> 'object':
try:
backup_context = self.snapshot()
self.apply(caller_context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,30 @@
import threading
import typing

from .base_context import BaseRuntimeContext
from . import base_context


class ThreadLocalRuntimeContext(BaseRuntimeContext):
class Slot(BaseRuntimeContext.Slot):
class ThreadLocalRuntimeContext(base_context.BaseRuntimeContext):
class Slot(base_context.BaseRuntimeContext.Slot):
_thread_local = threading.local()

def __init__(self, name: str, default: typing.Any):
def __init__(self, name: str, default: 'object'):
# pylint: disable=super-init-not-called
self.name = name
self.default = default if callable(default) else (lambda: default)
self.default: typing.Callable[..., object]
self.default = base_context.wrap_callable(default)

def clear(self) -> None:
setattr(self._thread_local, self.name, self.default())

def get(self) -> typing.Any:
def get(self) -> 'object':
try:
return getattr(self._thread_local, self.name)
got: object = getattr(self._thread_local, self.name)
return got
except AttributeError:
value = self.default()
self.set(value)
return value

def set(self, value: typing.Any) -> None:
def set(self, value: 'object') -> None:
setattr(self._thread_local, self.name, value)

0 comments on commit 1f93a36

Please sign in to comment.