Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Lightning Trainer/PyTorch Lightning 1.7.0 support + CI Fixes (JIT Tra…
Browse files Browse the repository at this point in the history
…cing and Functions to Classes conversion) (#1410)

Co-authored-by: rohitgr7 <[email protected]>
  • Loading branch information
krshrimali and rohitgr7 authored Aug 12, 2022
1 parent 8737ee8 commit 2657315
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 17 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed JIT tracing tests where the model class was not attached to the `Trainer` class ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410))

- Fixed examples for BaaL integration by removing usage of `on_<stage>_dataloader` hooks (removed in PL 1.7.0) ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410))

- Fixed examples for BaaL integration for the case when `probabilities` list is empty ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410))

- Fixed a bug where collate functions were not being attached successfully after the `DataLoader` is initialized (in PL 1.7.0 changing attributes after initialization doesn't do anything) ([#1410](https://github.com/Lightning-AI/lightning-flash/pull/1410))

- Fixed a bug where grayscale images were not properly converted to RGB when loaded. ([#1394](https://github.com/PyTorchLightning/lightning-flash/pull/1394))

- Fixed a bug where size of mask for instance segmentation doesn't match to size of original image. ([#1353](https://github.com/PyTorchLightning/lightning-flash/pull/1353))
Expand Down
42 changes: 34 additions & 8 deletions flash/core/integrations/icevision/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ def _wrap_collate_fn(collate_fn, samples):
DataKeys.METADATA: metadata,
}

def _update_collate_fn_dataloader(self, new_collate_fn, data_loader):
# Starting PL 1.7.0 - changing attributes after the DataLoader is initialized - will not work
# So we manually update the collate_fn for the dataloader, for now.
new_kwargs = getattr(data_loader, "__pl_saved_kwargs", None)
if new_kwargs:
new_kwargs["collate_fn"] = new_collate_fn
setattr(data_loader, "__pl_saved_kwargs", new_kwargs)
data_loader.collate_fn = new_collate_fn
return data_loader

def process_train_dataset(
self,
dataset: InputBase,
Expand All @@ -134,12 +144,16 @@ def process_train_dataset(
persistent_workers=persistent_workers,
)

data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn)
data_loader = self._update_collate_fn_dataloader(
functools.partial(self._wrap_collate_fn, data_loader.collate_fn), data_loader
)

input_transform = input_transform or self.input_transform
if input_transform is not None:
input_transform.inject_collate_fn(data_loader.collate_fn)
data_loader.collate_fn = create_worker_input_transform_processor(RunningStage.TRAINING, input_transform)
data_loader = self._update_collate_fn_dataloader(
create_worker_input_transform_processor(RunningStage.TRAINING, input_transform), data_loader
)
return data_loader

def process_val_dataset(
Expand All @@ -166,12 +180,16 @@ def process_val_dataset(
persistent_workers=persistent_workers,
)

data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn)
data_loader = self._update_collate_fn_dataloader(
functools.partial(self._wrap_collate_fn, data_loader.collate_fn), data_loader
)

input_transform = input_transform or self.input_transform
if input_transform is not None:
input_transform.inject_collate_fn(data_loader.collate_fn)
data_loader.collate_fn = create_worker_input_transform_processor(RunningStage.VALIDATING, input_transform)
data_loader = self._update_collate_fn_dataloader(
create_worker_input_transform_processor(RunningStage.VALIDATING, input_transform), data_loader
)
return data_loader

def process_test_dataset(
Expand All @@ -198,12 +216,16 @@ def process_test_dataset(
persistent_workers=persistent_workers,
)

data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn)
data_loader = self._update_collate_fn_dataloader(
functools.partial(self._wrap_collate_fn, data_loader.collate_fn), data_loader
)

input_transform = input_transform or self.input_transform
if input_transform is not None:
input_transform.inject_collate_fn(data_loader.collate_fn)
data_loader.collate_fn = create_worker_input_transform_processor(RunningStage.TESTING, input_transform)
data_loader = self._update_collate_fn_dataloader(
create_worker_input_transform_processor(RunningStage.TESTING, input_transform), data_loader
)
return data_loader

def process_predict_dataset(
Expand All @@ -230,12 +252,16 @@ def process_predict_dataset(
persistent_workers=persistent_workers,
)

data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn)
data_loader = self._update_collate_fn_dataloader(
functools.partial(self._wrap_collate_fn, data_loader.collate_fn), data_loader
)

input_transform = input_transform or self.input_transform
if input_transform is not None:
input_transform.inject_collate_fn(data_loader.collate_fn)
data_loader.collate_fn = create_worker_input_transform_processor(RunningStage.PREDICTING, input_transform)
data_loader = self._update_collate_fn_dataloader(
create_worker_input_transform_processor(RunningStage.PREDICTING, input_transform), data_loader
)
return data_loader

def training_step(self, batch, batch_idx) -> Any:
Expand Down
3 changes: 2 additions & 1 deletion flash/core/utilities/lightning_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def add_lightning_class_args(
lightning_class = class_from_function(lightning_class)

if inspect.isclass(lightning_class) and issubclass(
cast(type, lightning_class), (Trainer, LightningModule, LightningDataModule, Callback)
cast(type, lightning_class),
(Trainer, LightningModule, LightningDataModule, Callback, ClassFromFunctionBase),
):
if issubclass(cast(type, lightning_class), Callback):
self.callback_keys.append(nested_key)
Expand Down
2 changes: 1 addition & 1 deletion flash/image/classification/integrations/baal/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def label(self, probabilities: List[Tensor] = None, indices=None):
raise MisconfigurationException(
"The `probabilities` and `indices` are mutually exclusive, pass only of one them."
)
if probabilities is not None:
if probabilities is not None and len(probabilities) != 0:
probabilities = torch.cat([p[0].unsqueeze(0) for p in probabilities], dim=0)
uncertainties = self.heuristic.get_uncertainties(probabilities)
indices = np.argsort(uncertainties)
Expand Down
5 changes: 1 addition & 4 deletions flash/image/classification/integrations/baal/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None:
assert isinstance(self.trainer.datamodule, ActiveLearningDataModule)
if self._datamodule_state_dict is not None:
self.trainer.datamodule.load_state_dict(self._datamodule_state_dict)
self.trainer.predict_loop._return_predictions = True
self.trainer.predict_loop.return_predictions = True
self._lightning_module = self.trainer.lightning_module
self._model_state_dict = deepcopy(self._lightning_module.state_dict())
self.inference_model = InferenceMCDropoutTask(self._lightning_module, self.inference_iteration)
Expand Down Expand Up @@ -165,21 +165,18 @@ def _connect(self, model: LightningModule):
def _reset_fitting(self):
self.trainer.state.fn = TrainerFn.FITTING
self.trainer.training = True
self.trainer.lightning_module.on_train_dataloader()
self._connect(self._lightning_module)
self.fit_loop.epoch_progress = Progress()

def _reset_predicting(self):
self.trainer.state.fn = TrainerFn.PREDICTING
self.trainer.predicting = True
self.trainer.lightning_module.on_predict_dataloader()
self._connect(self.inference_model)

def _reset_testing(self):
self.trainer.state.fn = TrainerFn.TESTING
self.trainer.state.status = TrainerStatus.RUNNING
self.trainer.testing = True
self.trainer.lightning_module.on_test_dataloader()
self._connect(self._lightning_module)

def _reset_dataloader_for_stage(self, running_state: RunningStage):
Expand Down
4 changes: 2 additions & 2 deletions flash/image/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def on_train_start(self) -> None:
def on_train_epoch_end(self) -> None:
self.adapter.on_train_epoch_end()

def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
self.adapter.on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, *args) -> None:
self.adapter.on_train_batch_end(outputs, batch, batch_idx, *args)

@classmethod
@requires("image", "vissl", "fairscale")
Expand Down
2 changes: 1 addition & 1 deletion flash/image/embedding/vissl/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def on_train_start(self) -> None:
for hook in self.hooks:
hook.on_start(self.task)

def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, *args) -> None:
self.task.iteration += 1

def on_train_epoch_end(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions requirements/datatype_image_extras.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ fairscale
# pin PL for testing, remove when fastface is updated
pytorch-lightning<1.5.0
torchmetrics<0.8.0 # pinned PL so we force a compatible TM version
# effdet had an issue with PL 1.12, and icevision doesn't support effdet's latest version yet (0.3.0)
torch<1.12
11 changes: 11 additions & 0 deletions tests/helpers/task_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,10 @@ def _test_jit_trace(self, tmpdir):
path = os.path.join(tmpdir, "test.pt")

model = self.instantiated_task
trainer = self.instantiated_trainer
model.eval()

model.trainer = trainer
model = torch.jit.trace(model, self.example_forward_input)

torch.jit.save(model, path)
Expand All @@ -117,8 +119,10 @@ def _test_jit_script(self, tmpdir):
path = os.path.join(tmpdir, "test.pt")

model = self.instantiated_task
trainer = self.instantiated_trainer
model.eval()

model.trainer = trainer
model = torch.jit.script(model)

torch.jit.save(model, path)
Expand Down Expand Up @@ -261,10 +265,17 @@ class TaskTester(metaclass=TaskTesterMeta):
"test_cli": [pytest.mark.parametrize("extra_args", [[]])],
}

trainer_args: Tuple = ()
trainer_kwargs: Dict = {}

@property
def instantiated_task(self):
return self.task(*self.task_args, **self.task_kwargs)

@property
def instantiated_trainer(self):
return flash.Trainer(*self.trainer_args, **self.trainer_kwargs)

@property
def example_forward_input(self):
pass
Expand Down

0 comments on commit 2657315

Please sign in to comment.