Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Lightning Trainer/PyTorch Lightning 1.7.0 support + CI Fixes (JIT Tracing and Functions to Classes conversion) #1410

Merged
merged 33 commits into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
d22295e
Fix Flash CI: in case of converting functions to classes
krshrimali Aug 4, 2022
b725620
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 4, 2022
6d2c636
Attach Trainer to model for jit tracing
krshrimali Aug 4, 2022
f25eb1f
Merge branch 'flash-ci/functions-to-classes' of github.com:Lightning-…
krshrimali Aug 4, 2022
41fd845
on_<>_dataloader were deprecated, and were removed in v1.7 - remove
krshrimali Aug 4, 2022
df9f48b
Fix jit_script test
krshrimali Aug 4, 2022
4c8c0fb
fix
rohitgr7 Aug 8, 2022
76764e7
Try upgrading torchtext to 0.13.1 to support latest pytorch
krshrimali Aug 10, 2022
05c15ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2022
2fc3d82
Update requirements/datatype_text.txt
krshrimali Aug 10, 2022
18e7144
dont allow 1.7.0 until issues are fixed
krshrimali Aug 10, 2022
74e7345
Requirements: go back to PL 1.7
krshrimali Aug 10, 2022
f04a274
CI Fix: manually setattr for collate_fn after DataLoader is initialized
krshrimali Aug 10, 2022
cdbaefd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2022
b410d31
install effdet from latest commit hash, fix object_detection
krshrimali Aug 12, 2022
a8f57ad
Force install effdet 0.3.0
krshrimali Aug 12, 2022
598d978
Remove dataloader_idx from on_train_batch_end hooks
krshrimali Aug 12, 2022
02826a3
rename, attempt to support previous PL versions
krshrimali Aug 12, 2022
cf1d6e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2022
ea2679e
syntax error fix
krshrimali Aug 12, 2022
aa3e00e
Merge branch 'flash-ci/functions-to-classes' of github.com:Lightning-…
krshrimali Aug 12, 2022
4241ac9
Fix updating collate fn if input_transform is not None (icevision)
krshrimali Aug 12, 2022
d387308
Merge branch 'flash-ci/functions-to-classes' of github.com:Lightning-…
krshrimali Aug 12, 2022
802ab2b
pep8 fix
krshrimali Aug 12, 2022
23de0a4
Revert effdet changes, address reviews
krshrimali Aug 12, 2022
f609ef6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2022
7a24767
Apply suggestions from code review
krshrimali Aug 12, 2022
e58b1d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2022
4b32904
indentation fix
krshrimali Aug 12, 2022
2c89478
Update .azure-pipelines/testing-template.yml
krshrimali Aug 12, 2022
ccc7e24
Update .azure-pipelines/testing-template.yml
krshrimali Aug 12, 2022
ef9b0c0
Add CHANGELOG entries
krshrimali Aug 12, 2022
8e03fd7
Merge branch 'flash-ci/functions-to-classes' of github.com:Lightning-…
krshrimali Aug 12, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .azure-pipelines/testing-template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ jobs:

- bash: |
# python -m pip install "pip==20.1"
if [ "${{config}}" == "icevision" ]; then pip install '.[image]' icevision effdet icedata; elif [ "${{config}}" == "vissl" ]; then pip install '.[image]'; else pip install '.[${{config}}]'; fi
if [ "${{config}}" == "icevision" ]; then pip install '.[image]' icevision icedata; elif [ "${{config}}" == "vissl" ]; then pip install '.[image]'; else pip install '.[${{config}}]'; fi
krshrimali marked this conversation as resolved.
Show resolved Hide resolved
if [ "${{config}}" == "icevision" ]; then pip install git+https://github.com/rwightman/efficientdet-pytorch.git@79d26d8982b9f8e1f27d9f896e38012ac250fd26; fi
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
krshrimali marked this conversation as resolved.
Show resolved Hide resolved
pip install '.[test]' --upgrade-strategy only-if-needed
pip list
displayName: 'Install dependencies'
Expand Down
42 changes: 34 additions & 8 deletions flash/core/integrations/icevision/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ def _wrap_collate_fn(collate_fn, samples):
DataKeys.METADATA: metadata,
}

def _update_collate_fn_dataloader(self, new_collate_fn, data_loader):
# Starting PL 1.7.0 - changing attributes after the DataLoader is initialized - will not work
# So we manually update the collate_fn for the dataloader, for now.
new_kwargs = getattr(data_loader, "__pl_saved_kwargs", None)
if new_kwargs:
new_kwargs["collate_fn"] = new_collate_fn
setattr(data_loader, "__pl_saved_kwargs", new_kwargs)
data_loader.collate_fn = new_collate_fn
return data_loader

def process_train_dataset(
self,
dataset: InputBase,
Expand All @@ -134,12 +144,16 @@ def process_train_dataset(
persistent_workers=persistent_workers,
)

data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn)
data_loader = self._update_collate_fn_dataloader(
functools.partial(self._wrap_collate_fn, data_loader.collate_fn), data_loader
)

input_transform = input_transform or self.input_transform
if input_transform is not None:
input_transform.inject_collate_fn(data_loader.collate_fn)
data_loader.collate_fn = create_worker_input_transform_processor(RunningStage.TRAINING, input_transform)
data_loader = self._update_collate_fn_dataloader(
create_worker_input_transform_processor(RunningStage.TRAINING, input_transform), data_loader
)
return data_loader

def process_val_dataset(
Expand All @@ -166,12 +180,16 @@ def process_val_dataset(
persistent_workers=persistent_workers,
)

data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn)
data_loader = self._update_collate_fn_dataloader(
functools.partial(self._wrap_collate_fn, data_loader.collate_fn), data_loader
)

input_transform = input_transform or self.input_transform
if input_transform is not None:
input_transform.inject_collate_fn(data_loader.collate_fn)
data_loader.collate_fn = create_worker_input_transform_processor(RunningStage.VALIDATING, input_transform)
data_loader = self._update_collate_fn_dataloader(
create_worker_input_transform_processor(RunningStage.VALIDATING, input_transform), data_loader
)
return data_loader

def process_test_dataset(
Expand All @@ -198,12 +216,16 @@ def process_test_dataset(
persistent_workers=persistent_workers,
)

data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn)
data_loader = self._update_collate_fn_dataloader(
functools.partial(self._wrap_collate_fn, data_loader.collate_fn), data_loader
)

input_transform = input_transform or self.input_transform
if input_transform is not None:
input_transform.inject_collate_fn(data_loader.collate_fn)
data_loader.collate_fn = create_worker_input_transform_processor(RunningStage.TESTING, input_transform)
data_loader = self._update_collate_fn_dataloader(
create_worker_input_transform_processor(RunningStage.TESTING, input_transform), data_loader
)
return data_loader

def process_predict_dataset(
Expand All @@ -230,12 +252,16 @@ def process_predict_dataset(
persistent_workers=persistent_workers,
)

data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn)
data_loader = self._update_collate_fn_dataloader(
functools.partial(self._wrap_collate_fn, data_loader.collate_fn), data_loader
)

input_transform = input_transform or self.input_transform
if input_transform is not None:
input_transform.inject_collate_fn(data_loader.collate_fn)
data_loader.collate_fn = create_worker_input_transform_processor(RunningStage.PREDICTING, input_transform)
data_loader = self._update_collate_fn_dataloader(
create_worker_input_transform_processor(RunningStage.PREDICTING, input_transform), data_loader
)
return data_loader

def training_step(self, batch, batch_idx) -> Any:
Expand Down
3 changes: 2 additions & 1 deletion flash/core/utilities/lightning_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def add_lightning_class_args(
lightning_class = class_from_function(lightning_class)

if inspect.isclass(lightning_class) and issubclass(
cast(type, lightning_class), (Trainer, LightningModule, LightningDataModule, Callback)
cast(type, lightning_class),
(Trainer, LightningModule, LightningDataModule, Callback, ClassFromFunctionBase),
):
if issubclass(cast(type, lightning_class), Callback):
self.callback_keys.append(nested_key)
Expand Down
2 changes: 1 addition & 1 deletion flash/image/classification/integrations/baal/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def label(self, probabilities: List[Tensor] = None, indices=None):
raise MisconfigurationException(
"The `probabilities` and `indices` are mutually exclusive, pass only of one them."
)
if probabilities is not None:
if probabilities is not None and len(probabilities) != 0:
probabilities = torch.cat([p[0].unsqueeze(0) for p in probabilities], dim=0)
uncertainties = self.heuristic.get_uncertainties(probabilities)
indices = np.argsort(uncertainties)
Expand Down
5 changes: 1 addition & 4 deletions flash/image/classification/integrations/baal/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None:
assert isinstance(self.trainer.datamodule, ActiveLearningDataModule)
if self._datamodule_state_dict is not None:
self.trainer.datamodule.load_state_dict(self._datamodule_state_dict)
self.trainer.predict_loop._return_predictions = True
self.trainer.predict_loop.return_predictions = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this PR, but we should figure out what to do with the baal loop. It is so bound to a particular PL and Baal version that I'm not sure it makes sense to have it as part of a framework. Maybe it would be better as a tutorial or in bolts? cc @otaj

self._lightning_module = self.trainer.lightning_module
self._model_state_dict = deepcopy(self._lightning_module.state_dict())
self.inference_model = InferenceMCDropoutTask(self._lightning_module, self.inference_iteration)
Expand Down Expand Up @@ -165,21 +165,18 @@ def _connect(self, model: LightningModule):
def _reset_fitting(self):
self.trainer.state.fn = TrainerFn.FITTING
self.trainer.training = True
self.trainer.lightning_module.on_train_dataloader()
self._connect(self._lightning_module)
self.fit_loop.epoch_progress = Progress()

def _reset_predicting(self):
self.trainer.state.fn = TrainerFn.PREDICTING
self.trainer.predicting = True
self.trainer.lightning_module.on_predict_dataloader()
self._connect(self.inference_model)

def _reset_testing(self):
self.trainer.state.fn = TrainerFn.TESTING
self.trainer.state.status = TrainerStatus.RUNNING
self.trainer.testing = True
self.trainer.lightning_module.on_test_dataloader()
self._connect(self._lightning_module)

def _reset_dataloader_for_stage(self, running_state: RunningStage):
Expand Down
4 changes: 2 additions & 2 deletions flash/image/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def on_train_start(self) -> None:
def on_train_epoch_end(self) -> None:
self.adapter.on_train_epoch_end()

def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
self.adapter.on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, *args) -> None:
self.adapter.on_train_batch_end(outputs, batch, batch_idx, *args)

@classmethod
@requires("image", "vissl", "fairscale")
Expand Down
2 changes: 1 addition & 1 deletion flash/image/embedding/vissl/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def on_train_start(self) -> None:
for hook in self.hooks:
hook.on_start(self.task)

def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, *args) -> None:
self.task.iteration += 1

def on_train_epoch_end(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions requirements/datatype_image_extras.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ fairscale
# pin PL for testing, remove when fastface is updated
pytorch-lightning<1.5.0
torchmetrics<0.8.0 # pinned PL so we force a compatible TM version
# effdet had an issue with PL 1.12, and icevision doesn't support effdet's latest version yet (0.3.0)
torch<1.12
11 changes: 11 additions & 0 deletions tests/helpers/task_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,10 @@ def _test_jit_trace(self, tmpdir):
path = os.path.join(tmpdir, "test.pt")

model = self.instantiated_task
trainer = self.instantiated_trainer
model.eval()

model.trainer = trainer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, anything wrong with the current workaround?

model = torch.jit.trace(model, self.example_forward_input)

torch.jit.save(model, path)
Expand All @@ -117,8 +119,10 @@ def _test_jit_script(self, tmpdir):
path = os.path.join(tmpdir, "test.pt")

model = self.instantiated_task
trainer = self.instantiated_trainer
model.eval()

model.trainer = trainer
model = torch.jit.script(model)

torch.jit.save(model, path)
Expand Down Expand Up @@ -261,10 +265,17 @@ class TaskTester(metaclass=TaskTesterMeta):
"test_cli": [pytest.mark.parametrize("extra_args", [[]])],
}

trainer_args: Tuple = ()
trainer_kwargs: Dict = {}

@property
def instantiated_task(self):
return self.task(*self.task_args, **self.task_kwargs)

@property
def instantiated_trainer(self):
return flash.Trainer(*self.trainer_args, **self.trainer_kwargs)

@property
def example_forward_input(self):
pass
Expand Down