-
Notifications
You must be signed in to change notification settings - Fork 518
/
backend.py
201 lines (167 loc) · 4.78 KB
/
backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
abstractmethod,
)
from enum import (
Flag,
auto,
)
from typing import (
TYPE_CHECKING,
Callable,
ClassVar,
)
from deepmd.utils.plugin import (
PluginVariant,
make_plugin_registry,
)
if TYPE_CHECKING:
from argparse import (
Namespace,
)
from deepmd.infer.deep_eval import (
DeepEvalBackend,
)
from deepmd.utils.neighbor_stat import (
NeighborStat,
)
class Backend(PluginVariant, make_plugin_registry("backend")):
    r"""General backend class.

    Concrete backends subclass this, register themselves under one or more
    keys via ``Backend.register``, and declare the class-level metadata
    ``name``, ``features`` and ``suffixes``.

    Examples
    --------
    >>> @Backend.register("tf")
    ... @Backend.register("tensorflow")
    ... class TensorFlowBackend(Backend):
    ...     pass
    """

    @staticmethod
    def get_backend(key: str) -> type["Backend"]:
        """Get the backend by key.

        Parameters
        ----------
        key : str
            the key of a backend

        Returns
        -------
        type[Backend]
            the backend class registered under ``key``
        """
        return Backend.get_class_by_type(key)

    @staticmethod
    def get_backends() -> dict[str, type["Backend"]]:
        """Get all the registered backends.

        Returns
        -------
        dict[str, type[Backend]]
            all the registered backends, keyed by the name each backend
            was registered under
        """
        return Backend.get_plugins()

    @staticmethod
    def get_backends_by_feature(
        feature: "Backend.Feature",
    ) -> dict[str, type["Backend"]]:
        """Get all the registered backends with a specific feature.

        A backend is included when ``backend.features & feature`` is
        non-empty. If ``feature`` combines several flags, a backend that
        supports *any* one of them is therefore included.

        Parameters
        ----------
        feature : Backend.Feature
            the feature flag

        Returns
        -------
        dict[str, type[Backend]]
            the subset of :meth:`get_backends` whose ``features`` overlap
            with ``feature``, keyed by registered name
        """
        return {
            key: backend
            for key, backend in Backend.get_backends().items()
            if backend.features & feature
        }

    @staticmethod
    def detect_backend_by_model(filename: str) -> type["Backend"]:
        """Detect the backend of the given model file.

        Matching is case-insensitive: both the file name and every
        registered suffix are lower-cased before comparison. Backends are
        tried in the order :meth:`get_backends` yields them, and the first
        backend with a matching suffix is returned.

        Parameters
        ----------
        filename : str
            The model file name. It is passed through ``str()``, so path-like
            objects such as ``pathlib.Path`` are accepted as well.

        Returns
        -------
        type[Backend]
            The backend class whose ``suffixes`` match the file name.

        Raises
        ------
        ValueError
            If no registered backend declares a matching suffix.
        """
        # Keep the caller's spelling for the error message; match on a
        # lower-cased copy so ".PB" and ".pb" are treated alike.
        original_name = str(filename)
        lowered_name = original_name.lower()
        for backend in Backend.get_backends().values():
            for suffix in backend.suffixes:
                if lowered_name.endswith(suffix.lower()):
                    return backend
        raise ValueError(
            f"Cannot detect the backend of the model file {original_name}."
        )

    class Feature(Flag):
        """Feature flag to indicate whether the backend supports certain features."""

        ENTRY_POINT = auto()
        """Support entry point hook."""
        DEEP_EVAL = auto()
        """Support Deep Eval backend."""
        NEIGHBOR_STAT = auto()
        """Support neighbor statistics."""
        IO = auto()
        """Support IO hook."""

    name: ClassVar[str] = "Unknown"
    """The formal name of the backend.

    To be consistent, this name should be also registered in the plugin system."""

    features: ClassVar[Feature] = Feature(0)
    """The features of the backend."""
    suffixes: ClassVar[list[str]] = []
    """The supported suffixes of the saved model.

    The first element is considered as the default suffix."""

    @abstractmethod
    def is_available(self) -> bool:
        """Check if the backend is available.

        Returns
        -------
        bool
            Whether the backend is available.
        """

    @property
    @abstractmethod
    def entry_point_hook(self) -> Callable[["Namespace"], None]:
        """The entry point hook of the backend.

        Returns
        -------
        Callable[[Namespace], None]
            The entry point hook of the backend.
        """
        pass

    @property
    @abstractmethod
    def deep_eval(self) -> type["DeepEvalBackend"]:
        """The Deep Eval backend of the backend.

        Returns
        -------
        type[DeepEvalBackend]
            The Deep Eval backend of the backend.
        """
        pass

    @property
    @abstractmethod
    def neighbor_stat(self) -> type["NeighborStat"]:
        """The neighbor statistics of the backend.

        Returns
        -------
        type[NeighborStat]
            The neighbor statistics of the backend.
        """
        pass

    @property
    @abstractmethod
    def serialize_hook(self) -> Callable[[str], dict]:
        """The serialize hook to convert the model file to a dictionary.

        Returns
        -------
        Callable[[str], dict]
            The serialize hook of the backend.
        """
        pass

    @property
    @abstractmethod
    def deserialize_hook(self) -> Callable[[str, dict], None]:
        """The deserialize hook to convert the dictionary to a model file.

        Returns
        -------
        Callable[[str, dict], None]
            The deserialize hook of the backend.
        """
        pass