This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Lightning Trainer/PyTorch Lightning 1.7.0 support + CI Fixes (JIT Tracing and Functions to Classes conversion) #1410

Merged: 33 commits merged on Aug 12, 2022
Changes shown below are from 9 of the 33 commits.

Commits (33)
d22295e
Fix Flash CI: in case of converting functions to classes
krshrimali Aug 4, 2022
b725620
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 4, 2022
6d2c636
Attach Trainer to model for jit tracing
krshrimali Aug 4, 2022
f25eb1f
Merge branch 'flash-ci/functions-to-classes' of github.com:Lightning-…
krshrimali Aug 4, 2022
41fd845
on_<>_dataloader were deprecated, and were removed in v1.7 - remove
krshrimali Aug 4, 2022
df9f48b
Fix jit_script test
krshrimali Aug 4, 2022
4c8c0fb
fix
rohitgr7 Aug 8, 2022
76764e7
Try upgrading torchtext to 0.13.1 to support latest pytorch
krshrimali Aug 10, 2022
05c15ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2022
2fc3d82
Update requirements/datatype_text.txt
krshrimali Aug 10, 2022
18e7144
dont allow 1.7.0 until issues are fixed
krshrimali Aug 10, 2022
74e7345
Requirements: go back to PL 1.7
krshrimali Aug 10, 2022
f04a274
CI Fix: manually setattr for collate_fn after DataLoader is initialized
krshrimali Aug 10, 2022
cdbaefd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2022
b410d31
install effdet from latest commit hash, fix object_detection
krshrimali Aug 12, 2022
a8f57ad
Force install effdet 0.3.0
krshrimali Aug 12, 2022
598d978
Remove dataloader_idx from on_train_batch_end hooks
krshrimali Aug 12, 2022
02826a3
rename, attempt to support previous PL versions
krshrimali Aug 12, 2022
cf1d6e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2022
ea2679e
syntax error fix
krshrimali Aug 12, 2022
aa3e00e
Merge branch 'flash-ci/functions-to-classes' of github.com:Lightning-…
krshrimali Aug 12, 2022
4241ac9
Fix updating collate fn if input_transform is not None (icevision)
krshrimali Aug 12, 2022
d387308
Merge branch 'flash-ci/functions-to-classes' of github.com:Lightning-…
krshrimali Aug 12, 2022
802ab2b
pep8 fix
krshrimali Aug 12, 2022
23de0a4
Revert effdet changes, address reviews
krshrimali Aug 12, 2022
f609ef6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2022
7a24767
Apply suggestions from code review
krshrimali Aug 12, 2022
e58b1d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2022
4b32904
indentation fix
krshrimali Aug 12, 2022
2c89478
Update .azure-pipelines/testing-template.yml
krshrimali Aug 12, 2022
ccc7e24
Update .azure-pipelines/testing-template.yml
krshrimali Aug 12, 2022
ef9b0c0
Add CHANGELOG entries
krshrimali Aug 12, 2022
8e03fd7
Merge branch 'flash-ci/functions-to-classes' of github.com:Lightning-…
krshrimali Aug 12, 2022
3 changes: 2 additions & 1 deletion flash/core/utilities/lightning_cli.py
@@ -97,7 +97,8 @@ def add_lightning_class_args(
         lightning_class = class_from_function(lightning_class)

         if inspect.isclass(lightning_class) and issubclass(
-            cast(type, lightning_class), (Trainer, LightningModule, LightningDataModule, Callback)
+            cast(type, lightning_class),
+            (Trainer, LightningModule, LightningDataModule, Callback, ClassFromFunctionBase),
         ):
             if issubclass(cast(type, lightning_class), Callback):
                 self.callback_keys.append(nested_key)
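The extra `ClassFromFunctionBase` entry lets the CLI recognize classes that were generated from plain functions (the "functions to classes" conversion mentioned in the PR title). A rough, illustrative sketch of that pattern — not the library's actual implementation, names chosen only to mirror the identifiers used above:

```python
# Illustrative sketch only: a functions-to-classes helper builds a dynamic class per
# function and gives every generated class a common marker base, so a check like the
# issubclass(...) call above can treat them like regular Lightning classes.
class ClassFromFunctionBase:
    """Marker base class for classes generated from plain functions."""


def class_from_function(fn):
    def __init__(self, *args, **kwargs):
        # Keep whatever the wrapped function returns.
        self.wrapped_result = fn(*args, **kwargs)

    return type(fn.__name__.title(), (ClassFromFunctionBase,), {"__init__": __init__})


def build_backbone(hidden_dim: int = 128):
    return {"hidden_dim": hidden_dim}


BackboneClass = class_from_function(build_backbone)
assert issubclass(BackboneClass, ClassFromFunctionBase)  # the widened check now accepts it
backbone = BackboneClass(hidden_dim=64)
```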
2 changes: 1 addition & 1 deletion flash/image/classification/integrations/baal/data.py
@@ -180,7 +180,7 @@ def label(self, probabilities: List[Tensor] = None, indices=None):
             raise MisconfigurationException(
                 "The `probabilities` and `indices` are mutually exclusive, pass only of one them."
             )
-        if probabilities is not None:
+        if probabilities is not None and len(probabilities) != 0:
             probabilities = torch.cat([p[0].unsqueeze(0) for p in probabilities], dim=0)
             uncertainties = self.heuristic.get_uncertainties(probabilities)
             indices = np.argsort(uncertainties)
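The added length check matters because `torch.cat` raises when given an empty sequence. A minimal standalone sketch of the failure mode the guard avoids (not the actual loop code):

```python
import torch

probabilities = []  # e.g. the active-learning loop produced no predictions yet

# Without the length check, torch.cat([]) raises a RuntimeError.
if probabilities is not None and len(probabilities) != 0:
    stacked = torch.cat([p[0].unsqueeze(0) for p in probabilities], dim=0)
else:
    stacked = None  # nothing to label this round, so skip the uncertainty computation
```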
5 changes: 1 addition & 4 deletions flash/image/classification/integrations/baal/loop.py
@@ -91,7 +91,7 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None:
         assert isinstance(self.trainer.datamodule, ActiveLearningDataModule)
         if self._datamodule_state_dict is not None:
             self.trainer.datamodule.load_state_dict(self._datamodule_state_dict)
-        self.trainer.predict_loop._return_predictions = True
+        self.trainer.predict_loop.return_predictions = True
Review comment (Collaborator): Not for this PR, but we should figure out what to do with the baal loop. It is so bound to a particular PL and Baal version that I'm not sure it makes sense to have it as part of a framework. Maybe it would be better as a tutorial or in bolts? cc @otaj

         self._lightning_module = self.trainer.lightning_module
         self._model_state_dict = deepcopy(self._lightning_module.state_dict())
         self.inference_model = InferenceMCDropoutTask(self._lightning_module, self.inference_iteration)

@@ -165,21 +165,18 @@ def _connect(self, model: LightningModule):
     def _reset_fitting(self):
         self.trainer.state.fn = TrainerFn.FITTING
         self.trainer.training = True
-        self.trainer.lightning_module.on_train_dataloader()
         self._connect(self._lightning_module)
         self.fit_loop.epoch_progress = Progress()

     def _reset_predicting(self):
         self.trainer.state.fn = TrainerFn.PREDICTING
         self.trainer.predicting = True
-        self.trainer.lightning_module.on_predict_dataloader()
         self._connect(self.inference_model)

     def _reset_testing(self):
         self.trainer.state.fn = TrainerFn.TESTING
         self.trainer.state.status = TrainerStatus.RUNNING
         self.trainer.testing = True
-        self.trainer.lightning_module.on_test_dataloader()
         self._connect(self._lightning_module)

     def _reset_dataloader_for_stage(self, running_state: RunningStage):
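The deleted `on_<stage>_dataloader()` calls were deprecated in earlier PL releases and removed in 1.7, and the private `_return_predictions` flag is replaced by the public `return_predictions` attribute. If compatibility with older PL versions were still required, one possible (assumed, not part of this PR) approach would be to guard the removed hooks behind a version check:

```python
import pytorch_lightning as pl
from packaging.version import Version


def maybe_call_legacy_dataloader_hook(module: pl.LightningModule, stage: str) -> None:
    """Call the removed on_<stage>_dataloader hook only on PL versions that still provide it."""
    if Version(pl.__version__) < Version("1.7.0"):
        # stage is e.g. "train", "predict", or "test"
        getattr(module, f"on_{stage}_dataloader")()
```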
1 change: 1 addition & 0 deletions requirements/datatype_text.txt
@@ -4,3 +4,4 @@ transformers>=4.5
 torchmetrics[text]>=0.5.1
 datasets>=1.8
 sentence-transformers
+torchtext==0.13.1
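Note: torchtext wheels are built against a specific torch release (the 0.13.x line pairs with the torch 1.12.x series), so the exact pin keeps the text extra installable alongside the newer PyTorch release this PR upgrades to.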
11 changes: 11 additions & 0 deletions tests/helpers/task_tester.py
@@ -102,8 +102,10 @@ def _test_jit_trace(self, tmpdir):
     path = os.path.join(tmpdir, "test.pt")

     model = self.instantiated_task
+    trainer = self.instantiated_trainer
     model.eval()

+    model.trainer = trainer
Review comment (Contributor Author): Just curious, anything wrong with the current workaround?

     model = torch.jit.trace(model, self.example_forward_input)

     torch.jit.save(model, path)

@@ -117,8 +119,10 @@ def _test_jit_script(self, tmpdir):
     path = os.path.join(tmpdir, "test.pt")

     model = self.instantiated_task
+    trainer = self.instantiated_trainer
     model.eval()

+    model.trainer = trainer
     model = torch.jit.script(model)

     torch.jit.save(model, path)

@@ -261,10 +265,17 @@ class TaskTester(metaclass=TaskTesterMeta):
         "test_cli": [pytest.mark.parametrize("extra_args", [[]])],
     }

+    trainer_args: Tuple = ()
+    trainer_kwargs: Dict = {}
+
     @property
     def instantiated_task(self):
         return self.task(*self.task_args, **self.task_kwargs)

+    @property
+    def instantiated_trainer(self):
+        return flash.Trainer(*self.trainer_args, **self.trainer_kwargs)
+
     @property
     def example_forward_input(self):
         pass
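With the new `trainer_args`/`trainer_kwargs` hooks, the JIT tests attach a `flash.Trainer` to the task before tracing or scripting, so anything resolved through `model.trainer` during those tests has a real object behind it. A rough usage sketch of what a concrete tester subclass might look like — class, task, and parameter values here are hypothetical, not taken from this PR:

```python
# Hypothetical tester subclass, sketching how the new trainer hooks might be used.
from typing import Any, Dict

import torch

from flash.image import ImageClassifier  # assumed task, purely for illustration
from tests.helpers.task_tester import TaskTester  # assumed import path inside the test suite


class TestImageClassifier(TaskTester):
    task = ImageClassifier
    task_kwargs: Dict[str, Any] = {"num_classes": 2}
    # Consumed by instantiated_trainer, which builds flash.Trainer(*trainer_args, **trainer_kwargs).
    trainer_kwargs: Dict[str, Any] = {"fast_dev_run": True}

    @property
    def example_forward_input(self) -> torch.Tensor:
        # Input passed to torch.jit.trace in _test_jit_trace.
        return torch.rand(1, 3, 32, 32)
```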