From b70f517d539fef4bef6ff5f7af1f22589063ff7f Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 12:29:52 -0700 Subject: [PATCH 1/3] fix slicing when trainer is None --- viscy/light/data.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/viscy/light/data.py b/viscy/light/data.py index b95d41c1..fcab063c 100644 --- a/viscy/light/data.py +++ b/viscy/light/data.py @@ -433,7 +433,11 @@ def _setup_predict(self, dataset_settings: dict): ) def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample: - if self.trainer.predicting or isinstance(batch, torch.Tensor): + predicting= False + if self.trainer: + if self.trainer.predicting: + predicting = True + if predicting or isinstance(batch, torch.Tensor): # skipping example input array return batch if self.target_2d: From 8fcc41b1054233bf98baad3b37ba7c533bfb9a96 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 12:33:40 -0700 Subject: [PATCH 2/3] do not persist workers number is 0 --- viscy/light/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/viscy/light/data.py b/viscy/light/data.py index fcab063c..06d8fd38 100644 --- a/viscy/light/data.py +++ b/viscy/light/data.py @@ -452,7 +452,7 @@ def train_dataloader(self): batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, - persistent_workers=True, + persistent_workers=bool(self.num_workers), ) def val_dataloader(self): @@ -461,7 +461,7 @@ def val_dataloader(self): batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, - persistent_workers=True, + persistent_workers=bool(self.num_workers), ) def test_dataloader(self): From 6edebc20e02eccd2a58a1c4759efb516d2ce27ab Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 12:34:25 -0700 Subject: [PATCH 3/3] style --- viscy/light/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/light/data.py b/viscy/light/data.py index 06d8fd38..9e0d60f2 100644 --- a/viscy/light/data.py +++ b/viscy/light/data.py @@ -433,7 +433,7 @@ def _setup_predict(self, dataset_settings: dict): ) def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample: - predicting= False + predicting = False if self.trainer: if self.trainer.predicting: predicting = True