Skip to content

Commit

Permalink
Remap device on load
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff1995 committed Mar 16, 2023
1 parent 680c803 commit 3d8c422
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions Cell_BLAST/directi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torch.utils.tensorboard import SummaryWriter

from . import config, data, latent, prob, rmbatch, utils
+from .config import DEVICE
from .rebuild import RMSprop

_TRAIN = 1
Expand Down Expand Up @@ -184,7 +185,7 @@ def fit(
)

assert self._mode == _TRAIN
-        self.to(config.DEVICE)
+        self.to(DEVICE)
self.ensure_reproducibility(self.random_seed)
self.save_weights(self.path)

Expand Down Expand Up @@ -247,7 +248,7 @@ def train_epoch(self, train_dataloader, epoch, summarywriter):

for feed_dict in train_dataloader:
for key, value in feed_dict.items():
-                feed_dict[key] = value.to(config.DEVICE)
+                feed_dict[key] = value.to(DEVICE)

exprs = feed_dict["exprs"]
libs = feed_dict["library_size"]
Expand Down Expand Up @@ -326,7 +327,7 @@ def val_epoch(self, val_dataloader, epoch, summarywriter):

for feed_dict in val_dataloader:
for key, value in feed_dict.items():
-                feed_dict[key] = value.to(config.DEVICE)
+                feed_dict[key] = value.to(DEVICE)

exprs = feed_dict["exprs"]
libs = feed_dict["library_size"]
Expand Down Expand Up @@ -380,7 +381,7 @@ def save_weights(self, path: str, checkpoint: str = "checkpoint.pk"):

def load_weights(self, path: str, checkpoint: str = "checkpoint.pk"):
assert os.path.exists(path)
-        self.load_state_dict(torch.load(os.path.join(path, checkpoint)))
+        self.load_state_dict(torch.load(os.path.join(path, checkpoint), map_location=DEVICE))

@classmethod
def load_config(cls, configuration: typing.Mapping):
Expand Down Expand Up @@ -461,7 +462,7 @@ def load(
)

model = cls.load_config(configuration)
-        model.load_state_dict(torch.load(os.path.join(path, weights)), strict=False)
+        model.load_state_dict(torch.load(os.path.join(path, weights), map_location=DEVICE), strict=False)

return model

Expand Down Expand Up @@ -510,7 +511,7 @@ def inference(
"""

self.eval()
-        self.to(config.DEVICE)
+        self.to(DEVICE)

random_seed = (
config.RANDOM_SEED
Expand Down Expand Up @@ -593,7 +594,7 @@ def clustering(
"""

self.eval()
-        self.to(config.DEVICE)
+        self.to(DEVICE)

if not isinstance(self.latent_module, latent.CatGau):
raise Exception("Model has no intrinsic clustering")
Expand Down Expand Up @@ -638,7 +639,7 @@ def gene_grad(
"""

self.eval()
-        self.to(config.DEVICE)
+        self.to(DEVICE)

x = data.select_vars(adata, self.genes).X
if "__libsize__" not in adata.obs.columns:
Expand Down Expand Up @@ -669,7 +670,7 @@ def _fetch_latent(
latents = []
for feed_dict in dataloader:
for key, value in feed_dict.items():
-                    feed_dict[key] = value.to(config.DEVICE)
+                    feed_dict[key] = value.to(DEVICE)
exprs = feed_dict["exprs"]
libs = feed_dict["library_size"]
latents.append(
Expand All @@ -692,7 +693,7 @@ def _fetch_cat(
cats = []
for feed_dict in dataloader:
for key, value in feed_dict.items():
-                    feed_dict[key] = value.to(config.DEVICE)
+                    feed_dict[key] = value.to(DEVICE)
exprs = feed_dict["exprs"]
libs = feed_dict["library_size"]
cats.append(
Expand All @@ -714,7 +715,7 @@ def _fetch_grad(
grads = []
for feed_dict in dataloader:
for key, value in feed_dict.items():
-                    feed_dict[key] = value.to(config.DEVICE)
+                    feed_dict[key] = value.to(DEVICE)
exprs = feed_dict["exprs"]
libs = feed_dict["library_size"]
latent_grad = feed_dict["output_grad"]
Expand Down Expand Up @@ -955,7 +956,7 @@ def fit_DIRECTi(
)

if not reuse_weights is None:
-        model.load_state_dict(torch.load(reuse_weights))
+        model.load_state_dict(torch.load(reuse_weights, map_location=DEVICE))

if optimizer != "RMSPropOptimizer":
utils.logger.warning("Argument `optimizer` is no longer supported!")
Expand Down

0 comments on commit 3d8c422

Please sign in to comment.