Skip to content

Commit

Permalink
Merge pull request #28 from mehta-lab/target-center-slice-fix
Browse files Browse the repository at this point in the history
Fix datamodule
  • Loading branch information
mattersoflight authored Aug 12, 2023
2 parents f9b4e16 + 6edebc2 commit 5b7e0e8
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions viscy/light/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,11 @@ def _setup_predict(self, dataset_settings: dict):
)

def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample:
if self.trainer.predicting or isinstance(batch, torch.Tensor):
predicting = False
if self.trainer:
if self.trainer.predicting:
predicting = True
if predicting or isinstance(batch, torch.Tensor):
# skipping example input array
return batch
if self.target_2d:
Expand All @@ -465,7 +469,7 @@ def train_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
persistent_workers=True,
persistent_workers=bool(self.num_workers),
)

def val_dataloader(self):
Expand All @@ -474,7 +478,7 @@ def val_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
persistent_workers=True,
persistent_workers=bool(self.num_workers),
)

def test_dataloader(self):
Expand Down

0 comments on commit 5b7e0e8

Please sign in to comment.