diff --git a/Cell_BLAST/directi.py b/Cell_BLAST/directi.py index b069982..46e6f7f 100644 --- a/Cell_BLAST/directi.py +++ b/Cell_BLAST/directi.py @@ -19,6 +19,7 @@ from torch.utils.tensorboard import SummaryWriter from . import config, data, latent, prob, rmbatch, utils +from .config import DEVICE from .rebuild import RMSprop _TRAIN = 1 @@ -184,7 +185,7 @@ def fit( ) assert self._mode == _TRAIN - self.to(config.DEVICE) + self.to(DEVICE) self.ensure_reproducibility(self.random_seed) self.save_weights(self.path) @@ -247,7 +248,7 @@ def train_epoch(self, train_dataloader, epoch, summarywriter): for feed_dict in train_dataloader: for key, value in feed_dict.items(): - feed_dict[key] = value.to(config.DEVICE) + feed_dict[key] = value.to(DEVICE) exprs = feed_dict["exprs"] libs = feed_dict["library_size"] @@ -326,7 +327,7 @@ def val_epoch(self, val_dataloader, epoch, summarywriter): for feed_dict in val_dataloader: for key, value in feed_dict.items(): - feed_dict[key] = value.to(config.DEVICE) + feed_dict[key] = value.to(DEVICE) exprs = feed_dict["exprs"] libs = feed_dict["library_size"] @@ -380,7 +381,7 @@ def save_weights(self, path: str, checkpoint: str = "checkpoint.pk"): def load_weights(self, path: str, checkpoint: str = "checkpoint.pk"): assert os.path.exists(path) - self.load_state_dict(torch.load(os.path.join(path, checkpoint))) + self.load_state_dict(torch.load(os.path.join(path, checkpoint), map_location=DEVICE)) @classmethod def load_config(cls, configuration: typing.Mapping): @@ -461,7 +462,7 @@ def load( ) model = cls.load_config(configuration) - model.load_state_dict(torch.load(os.path.join(path, weights)), strict=False) + model.load_state_dict(torch.load(os.path.join(path, weights), map_location=DEVICE), strict=False) return model @@ -510,7 +511,7 @@ def inference( """ self.eval() - self.to(config.DEVICE) + self.to(DEVICE) random_seed = ( config.RANDOM_SEED @@ -593,7 +594,7 @@ def clustering( """ self.eval() - self.to(config.DEVICE) + self.to(DEVICE) if not isinstance(self.latent_module, latent.CatGau): raise Exception("Model has no intrinsic clustering") @@ -638,7 +639,7 @@ def gene_grad( """ self.eval() - self.to(config.DEVICE) + self.to(DEVICE) x = data.select_vars(adata, self.genes).X if "__libsize__" not in adata.obs.columns: @@ -669,7 +670,7 @@ def _fetch_latent( latents = [] for feed_dict in dataloader: for key, value in feed_dict.items(): - feed_dict[key] = value.to(config.DEVICE) + feed_dict[key] = value.to(DEVICE) exprs = feed_dict["exprs"] libs = feed_dict["library_size"] latents.append( @@ -692,7 +693,7 @@ def _fetch_cat( cats = [] for feed_dict in dataloader: for key, value in feed_dict.items(): - feed_dict[key] = value.to(config.DEVICE) + feed_dict[key] = value.to(DEVICE) exprs = feed_dict["exprs"] libs = feed_dict["library_size"] cats.append( @@ -714,7 +715,7 @@ def _fetch_grad( grads = [] for feed_dict in dataloader: for key, value in feed_dict.items(): - feed_dict[key] = value.to(config.DEVICE) + feed_dict[key] = value.to(DEVICE) exprs = feed_dict["exprs"] libs = feed_dict["library_size"] latent_grad = feed_dict["output_grad"] @@ -955,7 +956,7 @@ def fit_DIRECTi( ) if not reuse_weights is None: - model.load_state_dict(torch.load(reuse_weights)) + model.load_state_dict(torch.load(reuse_weights, map_location=DEVICE)) if optimizer != "RMSPropOptimizer": utils.logger.warning("Argument `optimizer` is no longer supported!")