diff --git a/cell2location/models/_cell2location_model.py b/cell2location/models/_cell2location_model.py index 36afa834..2fac5313 100755 --- a/cell2location/models/_cell2location_model.py +++ b/cell2location/models/_cell2location_model.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import scanpy +import torch from anndata import AnnData from pyro import clear_param_store from pyro.infer import Trace_ELBO, TraceEnum_ELBO @@ -160,6 +161,12 @@ def setup_anndata( adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) + def train_compiled(self, compile_mode=None, compile_dynamic=None, **kwargs): + self.train(**kwargs, max_steps=1) + self.module._model = torch.compile(self.module.model, mode=compile_mode, dynamic=compile_dynamic) + self.module._guide = torch.compile(self.module.guide, mode=compile_mode, dynamic=compile_dynamic) + self.train(**kwargs) + def train( self, max_epochs: int = 30000,