A template repository that showcases a streamlined approach to wrapping PyTorch models from research repositories into easy-to-use objects with clear, typed APIs.
- model.py: wraps the original model as a PyTorch module with an enhanced API
- manager.py: wraps the original model with an enhanced inference-only API
- checkpoint.py: provides easy reference to checkpoints
- checkpoints/: checkpoint files [1]
- port/: files from the original repository after tweaks such as import renaming [2]
- source/: unchanged files of the original repository [3]
As is (understandably) often the case with research code, the forward method of the toy model to wrap has a rather obscure signature:

    class SourceModel:
        def forward(self, x, y, a=1, b=0.0):
            ...

In contrast, the Model wrapper from PyTorch Mediator has a well-documented API that is easy to use:

    class Model(torch.nn.Module):
        def forward(
            self,
            foo: torch.Tensor,  # B, 3, H, W
            bar: torch.Tensor,  # B, 3, H, W
            optional_args: "Model.OptionalArgs" = OptionalArgs(),
        ) -> "Model.Output":
            """
            Run inference with the model.

            :param foo: First input image.
            :param bar: Second input image.
            :param optional_args: Optional arguments.
            :return: Output of the model.
            """
            ...

The classes Model.OptionalArgs and Model.Output appearing in the signature are also appropriately documented.
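
To make the idea concrete, here is a minimal, dependency-free sketch of how such a wrapper might translate a typed call into the source model's positional signature. It is an illustration under stated assumptions, not the repository's actual code: the field names `scale`, `offset`, and `result` are hypothetical, the toy computation in `SourceModel.forward` is invented so the example can run, and plain lists stand in for tensors. The real wrapper subclasses `torch.nn.Module`.

```python
from dataclasses import dataclass
from typing import Any


class SourceModel:
    """Stand-in for the research model, with its original, opaque signature."""

    def forward(self, x, y, a=1, b=0.0):
        # Toy computation (an assumption) so the sketch runs end to end.
        return [a * (xi + yi) + b for xi, yi in zip(x, y)]


class Model:
    """Typed wrapper; the real one would subclass torch.nn.Module."""

    @dataclass(frozen=True)
    class OptionalArgs:
        scale: int = 1       # hypothetical name for the source model's `a`
        offset: float = 0.0  # hypothetical name for the source model's `b`

    @dataclass(frozen=True)
    class Output:
        result: Any  # hypothetical field holding the raw source output

    def __init__(self) -> None:
        self.source = SourceModel()

    def forward(
        self,
        foo,
        bar,
        optional_args: "Model.OptionalArgs" = OptionalArgs(),
    ) -> "Model.Output":
        # Map the descriptive names back onto the source model's arguments.
        raw = self.source.forward(
            foo, bar, a=optional_args.scale, b=optional_args.offset
        )
        return Model.Output(result=raw)


model = Model()
out = model.forward([1.0, 2.0], [3.0, 4.0], Model.OptionalArgs(scale=2, offset=0.5))
print(out.result)  # [8.5, 12.5]
```

Grouping the optional parameters into a frozen dataclass keeps the default instance safe to share across calls, and returning a dataclass instead of a bare value gives callers named, typed fields.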
The Manager wrapper also provides a typed API, in this case for inference only. In addition, it offers a factory method for easy instantiation:

    class Manager:
        @staticmethod
        def create(
            checkpoint: Checkpoint = Checkpoint.v1_1,
            device: torch.device = torch.device("cuda"),
            half: bool = False,
        ) -> "Manager":
            """
            Create a new manager for the model.

            :param checkpoint: Checkpoint to load.
            :param device: Device to load the model on.
            :param half: Whether to use half precision.
            :return: New manager.
            """
            ...

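
As a rough illustration of how the factory and the checkpoint reference could fit together, the sketch below models `Checkpoint` as an enum that resolves to a file under checkpoints/. Several details are assumptions made for the example: the `v1_0` member, the `.pt` file names, and the `path` property are hypothetical, and `device` is a plain string here so the code runs without PyTorch (the real signature takes a `torch.device`).

```python
from dataclasses import dataclass
from enum import Enum
from pathlib import Path


class Checkpoint(Enum):
    """Named checkpoints; members other than v1_1 and all file names are hypothetical."""

    v1_0 = "v1_0.pt"
    v1_1 = "v1_1.pt"

    @property
    def path(self) -> Path:
        # Checkpoint files are assumed to live under checkpoints/.
        return Path("checkpoints") / self.value


@dataclass(frozen=True)
class Manager:
    checkpoint: Checkpoint
    device: str  # stand-in for torch.device in this sketch
    half: bool

    @staticmethod
    def create(
        checkpoint: Checkpoint = Checkpoint.v1_1,
        device: str = "cuda",
        half: bool = False,
    ) -> "Manager":
        # A real implementation would build the model here, load weights from
        # checkpoint.path, move it to the device, and optionally cast to half.
        return Manager(checkpoint=checkpoint, device=device, half=half)


manager = Manager.create(device="cpu")
print(manager.checkpoint.path.as_posix())  # checkpoints/v1_1.pt
```

Referring to checkpoints through an enum means callers never type file paths by hand, and an unknown checkpoint name fails immediately with an attribute error rather than at load time.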
Footnotes

1. You will probably want to use DVC or similar tools for large checkpoints.
2. My suggestion is to fix only what is strictly needed, i.e. imports, bugs, and possibly unoptimized code.
3. Having read-only source files always at hand is useful when we want to verify what changed compared to the original code via tools like diff.