From b2f38944dd13f3b3024e10f97abaf3240f6cccfe Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Sun, 26 Nov 2023 13:43:32 +0000 Subject: [PATCH] option to use torch.compile (#336) --- cell2location/models/_cell2location_model.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cell2location/models/_cell2location_model.py b/cell2location/models/_cell2location_model.py index 36afa834..2fac5313 100755 --- a/cell2location/models/_cell2location_model.py +++ b/cell2location/models/_cell2location_model.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import scanpy +import torch from anndata import AnnData from pyro import clear_param_store from pyro.infer import Trace_ELBO, TraceEnum_ELBO @@ -160,6 +161,12 @@ def setup_anndata( adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) + def train_compiled(self, compile_mode=None, compile_dynamic=None, **kwargs): + self.train(**kwargs, max_steps=1) + self.module._model = torch.compile(self.module.model, mode=compile_mode, dynamic=compile_dynamic) + self.module._guide = torch.compile(self.module.guide, mode=compile_mode, dynamic=compile_dynamic) + self.train(**kwargs) + def train( self, max_epochs: int = 30000,