Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

to_review flag #835

Merged
merged 14 commits into from
Jul 20, 2022
3 changes: 2 additions & 1 deletion docs/source/governance.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ Core Maintainers
----------------
- William Falcon (`williamFalcon <https://github.com/williamFalcon>`_) (Lightning founder)
- Jirka Borovec (`Borda <https://github.com/Borda>`_)
- Ananya Harsh Jha (`ananyahjha93 <https://github.com/ananyahjha93>`_)
- Ota Jašek (`otaj <https://github.com/otaj>`_)
- Akihiro Nitta (`akihironitta <https://github.com/akihironitta>`_)

Alumni
------
- Teddy Koker (`teddykoker <https://github.com/teddykoker>`_)
- Annika Brundyn (`annikabrundyn <https://github.com/annikabrundyn>`_)
- Ananya Harsh Jha (`ananyahjha93 <https://github.com/ananyahjha93>`_)
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ Lightning-Bolts documentation

CONTRIBUTING.md
governance.md
stability.md
CHANGELOG.md


Expand Down
2 changes: 2 additions & 0 deletions docs/source/introduction_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ Bolts is a Deep learning research and production toolbox of:

**The Main goal of Bolts is to enable trying new ideas as fast as possible!**

.. note:: Currently, Bolts is going through a major revision. For more information about it, see `this GitHub issue <https://github.com/Lightning-AI/lightning-bolts/issues/819>`_ and `stability section <https://lightning-bolts.readthedocs.io/en/latest/stability.html>`_

All models are tested (daily), benchmarked, documented and work on CPUs, TPUs, GPUs and 16-bit precision.

**some examples!**
Expand Down
37 changes: 37 additions & 0 deletions docs/source/stability.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
.. _stability:

Bolts stability
===============

Currently we are going through major revision of Bolts to ensure all of the code is stable and compatible with the rest of the Lightning ecosystem.
For this reason, all of our features are either marked as stable or in need of review. Stable features are implicit, features to be reviewed are explicitly marked.

At the beginning of the aforementioned revision, **ALL** of the features currently in the project have been marked as to be reviewed and will undergo rigorous review and testing before they can be marked as stable.

This document is intended to help you know what to expect and to outline our commitment to stability.

Stable
______

For stable features, all of the following are true:

- the API isn’t expected to change
- if anything does change, incorrect usage will give a deprecation warning for **one minor release** before the breaking change is made
- the API has been tested for compatibility with latest releases of PyTorch Lightning and Flash

To Review
_________

For features to be reviewed, any or all of the following may be true:

- the feature has unstable dependencies
- the API may change without notice in future versions
- the performance of the feature has not been verified
- the docs for this feature are under active development


Before a feature can be moved to Stable it needs to satisfy following conditions:

- Have appropriate tests, that will check not only correctness of the feature, but also compatibility with the current versions.
- Not have duplicate code accross Lightning ecosystem and more mature OSS projects.
- Pass a review process.
3 changes: 3 additions & 0 deletions pl_bolts/callbacks/byol_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from torch import Tensor
from torch.nn import Module

from pl_bolts.utils.stability import to_review


@to_review()
class BYOLMAWeightUpdate(Callback):
"""Weight update rule from BYOL.

Expand Down
5 changes: 5 additions & 0 deletions pl_bolts/callbacks/data_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch.utils.hooks import RemovableHandle

from pl_bolts.utils import _WANDB_AVAILABLE
from pl_bolts.utils.stability import to_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _WANDB_AVAILABLE:
Expand All @@ -19,6 +20,7 @@
warn_missing_pkg("wandb")


@to_review()
class DataMonitorBase(Callback):

supported_loggers = (
Expand Down Expand Up @@ -109,6 +111,7 @@ def _is_logger_available(self, logger: LightningLoggerBase) -> bool:
return available


@to_review()
class ModuleDataMonitor(DataMonitorBase):

GROUP_NAME_INPUT = "input"
Expand Down Expand Up @@ -194,6 +197,7 @@ def hook(_: Module, inp: Sequence, out: Sequence) -> None:
return handle


@to_review()
class TrainingDataMonitor(DataMonitorBase):

GROUP_NAME = "training_step"
Expand Down Expand Up @@ -257,6 +261,7 @@ def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name:
collect_and_name_tensors(item, output, parent_name=f"{parent_name}/{i:d}")


@to_review()
def shape2str(tensor: Tensor) -> str:
"""Returns the shape of a tensor in bracket notation as a string.

Expand Down
4 changes: 4 additions & 0 deletions pl_bolts/callbacks/knn_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from torch import Tensor
from torch.nn import functional as F

from pl_bolts.utils.stability import to_review


@to_review()
class KNNOnlineEvaluator(Callback):
"""Weighted KNN online evaluator for self-supervised learning.
The weighted KNN classifier matches sec 3.4 of https://arxiv.org/pdf/1805.01978.pdf.
Expand Down Expand Up @@ -138,5 +141,6 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule)
pl_module.log("online_knn_val_acc", total_top1 / total_num, on_step=False, on_epoch=True, sync_dist=True)


@to_review()
def concat_all_gather(tensor: Tensor, accelerator: Accelerator) -> Tensor:
return accelerator.all_gather(tensor).view(-1, *tensor.shape[1:])
4 changes: 4 additions & 0 deletions pl_bolts/callbacks/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_info

from pl_bolts.utils.stability import to_review


@to_review()
class PrintTableMetricsCallback(Callback):
"""Prints a table with the metrics in columns on every epoch end.

Expand Down Expand Up @@ -41,6 +44,7 @@ def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
rank_zero_info(dicts_to_table(self.metrics))


@to_review()
def dicts_to_table(
dicts: List[Dict],
keys: Optional[List[str]] = None,
Expand Down
3 changes: 3 additions & 0 deletions pl_bolts/callbacks/sparseml.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import ModuleExporter

from pl_bolts.utils.stability import to_review


@to_review()
class SparseMLCallback(Callback):
"""Enables SparseML aware training. Requires a recipe to run during training.

Expand Down
3 changes: 3 additions & 0 deletions pl_bolts/callbacks/ssl_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from torchmetrics.functional import accuracy

from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
from pl_bolts.utils.stability import to_review


@to_review()
class SSLOnlineEvaluator(Callback): # pragma: no cover
"""Attaches a MLP for fine-tuning using the standard self-supervised protocol.

Expand Down Expand Up @@ -173,6 +175,7 @@ def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule, callb
self._recovered_callback_state = callback_state


@to_review()
@contextmanager
def set_training(module: nn.Module, mode: bool):
"""Context manager to set training mode.
Expand Down
3 changes: 3 additions & 0 deletions pl_bolts/callbacks/torch_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
if _TORCH_ORT_AVAILABLE:
from torch_ort import ORTModule

from pl_bolts.utils.stability import to_review


@to_review()
class ORTCallback(Callback):
"""Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime.

Expand Down
2 changes: 2 additions & 0 deletions pl_bolts/callbacks/variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch import Tensor

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import to_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -15,6 +16,7 @@
warn_missing_pkg("torchvision")


@to_review()
class LatentDimInterpolator(Callback):
"""Interpolates the latent space for a model by setting all dims to zero and stepping through the first two
dims increasing one unit at a time.
Expand Down
4 changes: 4 additions & 0 deletions pl_bolts/callbacks/verification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

from pl_bolts.utils.stability import to_review


@to_review()
class VerificationBase:
"""Base class for model verification.

Expand Down Expand Up @@ -79,6 +82,7 @@ def _model_forward(self, input_array: Any) -> Any:
return self.model(input_array)


@to_review()
class VerificationCallbackBase(Callback):
"""Base class for model verification in form of a callback.

Expand Down
7 changes: 7 additions & 0 deletions pl_bolts/callbacks/verification/batch_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from torch import Tensor

from pl_bolts.callbacks.verification.base import VerificationBase, VerificationCallbackBase
from pl_bolts.utils.stability import to_review


@to_review()
class BatchGradientVerification(VerificationBase):
"""Checks if a model mixes data across the batch dimension.

Expand Down Expand Up @@ -82,6 +84,7 @@ def check(
return not any(has_grad_outside_sample) and all(has_grad_inside_sample)


@to_review()
class BatchGradientVerificationCallback(VerificationCallbackBase):
"""The callback version of the :class:`BatchGradientVerification` test.

Expand Down Expand Up @@ -130,6 +133,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
self._raise()


@to_review()
def default_input_mapping(data: Any) -> List[Tensor]:
"""Finds all tensors in a (nested) collection that have the same batch size.

Expand Down Expand Up @@ -157,6 +161,7 @@ def default_input_mapping(data: Any) -> List[Tensor]:
return batches


@to_review()
def default_output_mapping(data: Any) -> Tensor:
"""Pulls out all tensors in a output collection and combines them into one big batch for verification.

Expand Down Expand Up @@ -188,6 +193,7 @@ def default_output_mapping(data: Any) -> Tensor:
return combined


@to_review()
def collect_tensors(data: Any) -> List[Tensor]:
"""Filters all tensors in a collection and returns them in a list."""
tensors = []
Expand All @@ -200,6 +206,7 @@ def collect_batches(tensor: Tensor) -> Tensor:
return tensors


@to_review()
@contextmanager
def selective_eval(model: nn.Module, layer_types: Iterable[Type[nn.Module]]) -> None:
"""A context manager that sets all requested types of layers to eval mode. This method uses an ``isinstance``
Expand Down
2 changes: 2 additions & 0 deletions pl_bolts/callbacks/vision/confused_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch import Tensor, nn

from pl_bolts.utils import _MATPLOTLIB_AVAILABLE
from pl_bolts.utils.stability import to_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _MATPLOTLIB_AVAILABLE:
Expand All @@ -17,6 +18,7 @@
Figure = object


@to_review()
class ConfusedLogitCallback(Callback): # pragma: no cover
"""Takes the logit predictions of a model and when the probabilities of two classes are very close, the model
doesn't have high certainty that it should pick one vs the other class.
Expand Down
2 changes: 2 additions & 0 deletions pl_bolts/callbacks/vision/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pytorch_lightning import Callback, LightningModule, Trainer

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import to_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -12,6 +13,7 @@
warn_missing_pkg("torchvision")


@to_review()
class TensorboardGenerativeModelImageSampler(Callback):
"""Generates images and logs to tensorboard. Your model must implement the ``forward`` function for generation.

Expand Down
2 changes: 2 additions & 0 deletions pl_bolts/callbacks/vision/sr_image_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pytorch_lightning import Callback

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import to_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -14,6 +15,7 @@
warn_missing_pkg("torchvision")


@to_review()
class SRImageLoggerCallback(Callback):
"""Logs low-res, generated high-res, and ground truth high-res images to TensorBoard Your model must implement
the ``forward`` function for generation.
Expand Down
3 changes: 3 additions & 0 deletions pl_bolts/datamodules/async_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from torch._six import string_classes
from torch.utils.data import DataLoader, Dataset

from pl_bolts.utils.stability import to_review


@to_review()
class AsynchronousLoader:
"""Class for asynchronously loading from CPU memory to device memory with DataLoader.

Expand Down
2 changes: 2 additions & 0 deletions pl_bolts/datamodules/binary_emnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from pl_bolts.datamodules.emnist_datamodule import EMNISTDataModule
from pl_bolts.datasets import BinaryEMNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import to_review


@to_review()
class BinaryEMNISTDataModule(EMNISTDataModule):
"""
.. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png
Expand Down
2 changes: 2 additions & 0 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets import BinaryMNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import to_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -11,6 +12,7 @@
warn_missing_pkg("torchvision")


@to_review()
class BinaryMNISTDataModule(VisionDataModule):
"""
.. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
Expand Down
3 changes: 3 additions & 0 deletions pl_bolts/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pl_bolts.datasets import TrialCIFAR10
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import to_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -14,6 +15,7 @@
CIFAR10 = None


@to_review()
class CIFAR10DataModule(VisionDataModule):
"""
.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/
Expand Down Expand Up @@ -122,6 +124,7 @@ def default_transforms(self) -> Callable:
return cf10_transforms


@to_review()
class TinyCIFAR10DataModule(CIFAR10DataModule):
"""Standard CIFAR10, train, val, test splits and transforms.

Expand Down
Loading