diff --git a/src/pyhf/tensor/manager.py b/src/pyhf/tensor/manager.py index 32dd46de53..830f3fb095 100644 --- a/src/pyhf/tensor/manager.py +++ b/src/pyhf/tensor/manager.py @@ -5,11 +5,7 @@ from pyhf import exceptions from pyhf import events from pyhf.optimize import OptimizerRetriever -from pyhf.typing import TensorBackend, Optimizer, TypedDict -from types import ModuleType - -_default_backend: TensorBackend = BackendRetriever.numpy_backend() -_default_optimizer: Optimizer = OptimizerRetriever.scipy_optimizer() # type: ignore[no-untyped-call] +from pyhf.typing import TensorBackend, Optimizer, TypedDict, Protocol class State(TypedDict): @@ -17,14 +13,15 @@ class State(TypedDict): current: tuple[TensorBackend, Optimizer] -class Module(ModuleType): - state: State = { - 'default': (_default_backend, _default_optimizer), - 'current': (_default_backend, _default_optimizer), - } +class HasState(Protocol): + state: State -this = Module(__name__) +this: HasState = sys.modules[__name__] +this.state = { + 'default': (None, None), # type: ignore[typeddict-item] + 'current': (None, None), # type: ignore[typeddict-item] +} def get_backend(default: bool = False) -> tuple[TensorBackend, Optimizer]: @@ -52,6 +49,13 @@ def get_backend(default: bool = False) -> tuple[TensorBackend, Optimizer]: return this.state['current'] +_default_backend: TensorBackend = BackendRetriever.numpy_backend() +_default_optimizer: Optimizer = OptimizerRetriever.scipy_optimizer() # type: ignore[no-untyped-call] + +this.state['default'] = (_default_backend, _default_optimizer) +this.state['current'] = this.state['default'] + + @events.register('change_backend') def set_backend( backend: str | bytes | TensorBackend, @@ -193,6 +197,3 @@ def set_backend( events.trigger("optimizer_changed")() # set up any other globals for backend new_backend._setup() - - -sys.modules[__name__] = this diff --git a/src/pyhf/typing.py b/src/pyhf/typing.py index 93b1b3fbe0..ee874d5b17 100644 --- a/src/pyhf/typing.py +++ b/src/pyhf/typing.py @@ -27,6 +27,7 @@ "Workspace", "Literal", "TypedDict", + "Protocol", )