Skip to content

Commit

Permalink
update LossRecorder to LossOutput
Browse files Browse the repository at this point in the history
  • Loading branch information
rpeys committed Jul 11, 2023
1 parent d1934e4 commit efc5981
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions mrvi/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.distributions as db
import torch.nn as nn
from scvi import REGISTRY_KEYS
from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from scvi.nn import one_hot
from torch.distributions import kl_divergence as kl

Expand Down Expand Up @@ -206,9 +206,9 @@ def loss(

kl_local = torch.tensor(0.0)
kl_global = torch.tensor(0.0)
return LossRecorder(
loss,
reconstruction_loss,
kl_local,
kl_global,
return LossOutput(
loss = loss,
reconstruction_loss = reconstruction_loss,
kl_local = kl_local,
kl_global = kl_global
)

0 comments on commit efc5981

Please sign in to comment.