diff --git a/Cell_BLAST/directi.py b/Cell_BLAST/directi.py
index b069982..46e6f7f 100644
--- a/Cell_BLAST/directi.py
+++ b/Cell_BLAST/directi.py
@@ -19,6 +19,7 @@
 from torch.utils.tensorboard import SummaryWriter
 
 from . import config, data, latent, prob, rmbatch, utils
+from .config import DEVICE
 from .rebuild import RMSprop
 
 _TRAIN = 1
@@ -184,7 +185,7 @@ def fit(
         )
 
         assert self._mode == _TRAIN
-        self.to(config.DEVICE)
+        self.to(DEVICE)
         self.ensure_reproducibility(self.random_seed)
         self.save_weights(self.path)
 
@@ -247,7 +248,7 @@ def train_epoch(self, train_dataloader, epoch, summarywriter):
 
         for feed_dict in train_dataloader:
             for key, value in feed_dict.items():
-                feed_dict[key] = value.to(config.DEVICE)
+                feed_dict[key] = value.to(DEVICE)
 
             exprs = feed_dict["exprs"]
             libs = feed_dict["library_size"]
@@ -326,7 +327,7 @@ def val_epoch(self, val_dataloader, epoch, summarywriter):
 
         for feed_dict in val_dataloader:
             for key, value in feed_dict.items():
-                feed_dict[key] = value.to(config.DEVICE)
+                feed_dict[key] = value.to(DEVICE)
 
             exprs = feed_dict["exprs"]
             libs = feed_dict["library_size"]
@@ -380,7 +381,7 @@ def save_weights(self, path: str, checkpoint: str = "checkpoint.pk"):
 
     def load_weights(self, path: str, checkpoint: str = "checkpoint.pk"):
         assert os.path.exists(path)
-        self.load_state_dict(torch.load(os.path.join(path, checkpoint)))
+        self.load_state_dict(torch.load(os.path.join(path, checkpoint), map_location=DEVICE))
 
     @classmethod
     def load_config(cls, configuration: typing.Mapping):
@@ -461,7 +462,7 @@ def load(
             )
 
         model = cls.load_config(configuration)
-        model.load_state_dict(torch.load(os.path.join(path, weights)), strict=False)
+        model.load_state_dict(torch.load(os.path.join(path, weights), map_location=DEVICE), strict=False)
 
         return model
 
@@ -510,7 +511,7 @@ def inference(
         """
 
         self.eval()
-        self.to(config.DEVICE)
+        self.to(DEVICE)
 
         random_seed = (
             config.RANDOM_SEED
@@ -593,7 +594,7 @@ def clustering(
         """
 
         self.eval()
-        self.to(config.DEVICE)
+        self.to(DEVICE)
 
         if not isinstance(self.latent_module, latent.CatGau):
             raise Exception("Model has no intrinsic clustering")
@@ -638,7 +639,7 @@ def gene_grad(
         """
 
         self.eval()
-        self.to(config.DEVICE)
+        self.to(DEVICE)
 
         x = data.select_vars(adata, self.genes).X
         if "__libsize__" not in adata.obs.columns:
@@ -669,7 +670,7 @@ def _fetch_latent(
             latents = []
             for feed_dict in dataloader:
                 for key, value in feed_dict.items():
-                    feed_dict[key] = value.to(config.DEVICE)
+                    feed_dict[key] = value.to(DEVICE)
                 exprs = feed_dict["exprs"]
                 libs = feed_dict["library_size"]
                 latents.append(
@@ -692,7 +693,7 @@ def _fetch_cat(
             cats = []
             for feed_dict in dataloader:
                 for key, value in feed_dict.items():
-                    feed_dict[key] = value.to(config.DEVICE)
+                    feed_dict[key] = value.to(DEVICE)
                 exprs = feed_dict["exprs"]
                 libs = feed_dict["library_size"]
                 cats.append(
@@ -714,7 +715,7 @@ def _fetch_grad(
         grads = []
         for feed_dict in dataloader:
             for key, value in feed_dict.items():
-                feed_dict[key] = value.to(config.DEVICE)
+                feed_dict[key] = value.to(DEVICE)
             exprs = feed_dict["exprs"]
             libs = feed_dict["library_size"]
             latent_grad = feed_dict["output_grad"]
@@ -955,7 +956,7 @@ def fit_DIRECTi(
     )
 
     if not reuse_weights is None:
-        model.load_state_dict(torch.load(reuse_weights))
+        model.load_state_dict(torch.load(reuse_weights, map_location=DEVICE))
 
     if optimizer != "RMSPropOptimizer":
         utils.logger.warning("Argument `optimizer` is no longer supported!")