diff --git a/docs/source/reference/question_answering.rst b/docs/source/reference/question_answering.rst index b264b83823..9afed358c0 100644 --- a/docs/source/reference/question_answering.rst +++ b/docs/source/reference/question_answering.rst @@ -1,7 +1,7 @@ .. customcarditem:: :header: Extractive Question Answering :card_description: Learn to answer questions pertaining to some known textual context. - :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/extractive_question_answering.svg :tags: NLP,Text .. _question_answering: diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 054efbfb6d..fdf3f22e48 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -288,6 +288,7 @@ def _train_dataloader(self) -> DataLoader: else: drop_last = len(train_ds) > self.batch_size pin_memory = True + persistent_workers = self.num_workers > 0 if self.sampler is None: sampler = None @@ -317,12 +318,14 @@ def _train_dataloader(self) -> DataLoader: pin_memory=pin_memory, drop_last=drop_last, collate_fn=collate_fn, + persistent_workers=persistent_workers, ) def _val_dataloader(self) -> DataLoader: val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds collate_fn = self._resolve_collate_fn(val_ds, RunningStage.VALIDATING) pin_memory = True + persistent_workers = self.num_workers > 0 if isinstance(getattr(self, "trainer", None), pl.Trainer): return self.trainer.lightning_module.process_val_dataset( @@ -340,12 +343,14 @@ def _val_dataloader(self) -> DataLoader: num_workers=self.num_workers, pin_memory=pin_memory, collate_fn=collate_fn, + persistent_workers=persistent_workers, ) def _test_dataloader(self) -> DataLoader: test_ds: Dataset = self._test_ds() if isinstance(self._test_ds, Callable) else self._test_ds collate_fn = self._resolve_collate_fn(test_ds, RunningStage.TESTING) pin_memory = True + persistent_workers = False if isinstance(getattr(self, "trainer", None), pl.Trainer): return self.trainer.lightning_module.process_test_dataset( @@ -363,6 +368,7 @@ def _test_dataloader(self) -> DataLoader: num_workers=self.num_workers, pin_memory=pin_memory, collate_fn=collate_fn, + persistent_workers=persistent_workers, ) def _predict_dataloader(self) -> DataLoader: @@ -375,6 +381,7 @@ def _predict_dataloader(self) -> DataLoader: collate_fn = self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING) pin_memory = True + persistent_workers = False if isinstance(getattr(self, "trainer", None), pl.Trainer): return self.trainer.lightning_module.process_predict_dataset( @@ -386,7 +393,12 @@ def _predict_dataloader(self) -> DataLoader: ) return DataLoader( - predict_ds, batch_size=batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=collate_fn + predict_ds, + batch_size=batch_size, + num_workers=self.num_workers, + pin_memory=True, + collate_fn=collate_fn, + persistent_workers=persistent_workers, ) @property