Skip to content

Commit

Permalink
force HasState Protocol instead
Browse files Browse the repository at this point in the history
  • Loading branch information
kratsg authored and matthewfeickert committed Sep 4, 2022
1 parent d3faff6 commit 1a204e3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
29 changes: 15 additions & 14 deletions src/pyhf/tensor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,23 @@
from pyhf import exceptions
from pyhf import events
from pyhf.optimize import OptimizerRetriever
from pyhf.typing import TensorBackend, Optimizer, TypedDict
from types import ModuleType

_default_backend: TensorBackend = BackendRetriever.numpy_backend()
_default_optimizer: Optimizer = OptimizerRetriever.scipy_optimizer() # type: ignore[no-untyped-call]
from pyhf.typing import TensorBackend, Optimizer, TypedDict, Protocol


class State(TypedDict):
default: tuple[TensorBackend, Optimizer]
current: tuple[TensorBackend, Optimizer]


class Module(ModuleType):
state: State = {
'default': (_default_backend, _default_optimizer),
'current': (_default_backend, _default_optimizer),
}
class HasState(Protocol):
state: State


this = Module(__name__)
this: HasState = sys.modules[__name__]
this.state = {
'default': (None, None), # type: ignore[typeddict-item]
'current': (None, None), # type: ignore[typeddict-item]
}


def get_backend(default: bool = False) -> tuple[TensorBackend, Optimizer]:
Expand Down Expand Up @@ -52,6 +49,13 @@ def get_backend(default: bool = False) -> tuple[TensorBackend, Optimizer]:
return this.state['current']


_default_backend: TensorBackend = BackendRetriever.numpy_backend()
_default_optimizer: Optimizer = OptimizerRetriever.scipy_optimizer() # type: ignore[no-untyped-call]

this.state['default'] = (_default_backend, _default_optimizer)
this.state['current'] = this.state['default']


@events.register('change_backend')
def set_backend(
backend: str | bytes | TensorBackend,
Expand Down Expand Up @@ -193,6 +197,3 @@ def set_backend(
events.trigger("optimizer_changed")()
# set up any other globals for backend
new_backend._setup()


sys.modules[__name__] = this
1 change: 1 addition & 0 deletions src/pyhf/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"Workspace",
"Literal",
"TypedDict",
"Protocol",
)


Expand Down

0 comments on commit 1a204e3

Please sign in to comment.