From c09a74df7337517edbd610fc324daadd10df07dc Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 19 Jul 2022 10:07:21 +0200 Subject: [PATCH 01/11] first portion of experiment flag --- docs/source/governance.rst | 3 +- docs/source/index.rst | 1 + docs/source/stability.rst | 30 ++++++++ pl_bolts/callbacks/byol_updates.py | 3 + pl_bolts/callbacks/data_monitor.py | 5 ++ pl_bolts/callbacks/knn_online.py | 4 + pl_bolts/callbacks/verification/base.py | 4 + .../callbacks/verification/batch_gradient.py | 7 ++ pl_bolts/callbacks/vision/confused_logit.py | 2 + pl_bolts/callbacks/vision/image_generation.py | 2 + pl_bolts/callbacks/vision/sr_image_logger.py | 2 + pl_bolts/utils/stability.py | 76 +++++++++++++++++++ 12 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 docs/source/stability.rst create mode 100644 pl_bolts/utils/stability.py diff --git a/docs/source/governance.rst b/docs/source/governance.rst index 24a4b0c0ad..99c85bacf1 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -7,10 +7,11 @@ Core Maintainers ---------------- - William Falcon (`williamFalcon `_) (Lightning founder) - Jirka Borovec (`Borda `_) -- Ananya Harsh Jha (`ananyahjha93 `_) +- Ota Jašek (`otaj `_) - Akihiro Nitta (`akihironitta `_) Alumni ------ - Teddy Koker (`teddykoker `_) - Annika Brundyn (`annikabrundyn `_) +- Ananya Harsh Jha (`ananyahjha93 `_) diff --git a/docs/source/index.rst b/docs/source/index.rst index 61d3ee75c5..84b3fd7906 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -97,6 +97,7 @@ Lightning-Bolts documentation CONTRIBUTING.md governance.md + stability.md CHANGELOG.md diff --git a/docs/source/stability.rst b/docs/source/stability.rst new file mode 100644 index 0000000000..f149237db6 --- /dev/null +++ b/docs/source/stability.rst @@ -0,0 +1,30 @@ +.. _stability: + +Bolts stability +=============== + +Currently we are going through major revision of Bolts to ensure all of the code is stable and compatible with the rest of the Lightning ecosystem. +For this reason, all of our features are either marked as stable or experimental. Stable features are implicit, experimental features are explicit. + +At the beginning of the aforementioned revision, **ALL** of the features currently in the project have been marked as experimental and will undergo rigorous review and testing before they can be marked as stable. + +This document is intended to help you know what to expect and to outline our commitment to stability. + +Stable +______ + +For stable features, all of the following are true: + +- the API isn’t expected to change +- if anything does change, incorrect usage will give a deprecation warning for **one major release** before the breaking change is made +- the API has been tested for compatibility with latest releases of PyTorch Lightning and Flash + +Experimental +____________ + +For experimental features, any or all of the following may be true: + +- the feature has unstable dependencies +- the API may change without notice in future versions +- the performance of the feature has not been verified +- the docs for this feature are under active development diff --git a/pl_bolts/callbacks/byol_updates.py b/pl_bolts/callbacks/byol_updates.py index 7ba5a01cfd..15bbc6c4f9 100644 --- a/pl_bolts/callbacks/byol_updates.py +++ b/pl_bolts/callbacks/byol_updates.py @@ -5,7 +5,10 @@ from torch import Tensor from torch.nn import Module +from pl_bolts.utils.stability import experimental + +@experimental() class BYOLMAWeightUpdate(Callback): """Weight update rule from BYOL. diff --git a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py index 6bb2da8ea1..5f4656e8f2 100644 --- a/pl_bolts/callbacks/data_monitor.py +++ b/pl_bolts/callbacks/data_monitor.py @@ -11,6 +11,7 @@ from torch.utils.hooks import RemovableHandle from pl_bolts.utils import _WANDB_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _WANDB_AVAILABLE: @@ -19,6 +20,7 @@ warn_missing_pkg("wandb") +@experimental() class DataMonitorBase(Callback): supported_loggers = ( @@ -109,6 +111,7 @@ def _is_logger_available(self, logger: LightningLoggerBase) -> bool: return available +@experimental() class ModuleDataMonitor(DataMonitorBase): GROUP_NAME_INPUT = "input" @@ -194,6 +197,7 @@ def hook(_: Module, inp: Sequence, out: Sequence) -> None: return handle +@experimental() class TrainingDataMonitor(DataMonitorBase): GROUP_NAME = "training_step" @@ -257,6 +261,7 @@ def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name: collect_and_name_tensors(item, output, parent_name=f"{parent_name}/{i:d}") +@experimental() def shape2str(tensor: Tensor) -> str: """Returns the shape of a tensor in bracket notation as a string. diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index 258b3f938f..238eda1d30 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -6,7 +6,10 @@ from torch import Tensor from torch.nn import functional as F +from pl_bolts.utils.stability import experimental + +@experimental() class KNNOnlineEvaluator(Callback): """Weighted KNN online evaluator for self-supervised learning. The weighted KNN classifier matches sec 3.4 of https://arxiv.org/pdf/1805.01978.pdf. @@ -138,5 +141,6 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) pl_module.log("online_knn_val_acc", total_top1 / total_num, on_step=False, on_epoch=True, sync_dist=True) +@experimental() def concat_all_gather(tensor: Tensor, accelerator: Accelerator) -> Tensor: return accelerator.all_gather(tensor).view(-1, *tensor.shape[1:]) diff --git a/pl_bolts/callbacks/verification/base.py b/pl_bolts/callbacks/verification/base.py index 430b580fd3..25aebaec8b 100644 --- a/pl_bolts/callbacks/verification/base.py +++ b/pl_bolts/callbacks/verification/base.py @@ -9,7 +9,10 @@ from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pl_bolts.utils.stability import experimental + +@experimental() class VerificationBase: """Base class for model verification. @@ -79,6 +82,7 @@ def _model_forward(self, input_array: Any) -> Any: return self.model(input_array) +@experimental() class VerificationCallbackBase(Callback): """Base class for model verification in form of a callback. diff --git a/pl_bolts/callbacks/verification/batch_gradient.py b/pl_bolts/callbacks/verification/batch_gradient.py index 1c53f9a6f4..4c3b08d736 100644 --- a/pl_bolts/callbacks/verification/batch_gradient.py +++ b/pl_bolts/callbacks/verification/batch_gradient.py @@ -10,8 +10,10 @@ from torch import Tensor from pl_bolts.callbacks.verification.base import VerificationBase, VerificationCallbackBase +from pl_bolts.utils.stability import experimental +@experimental() class BatchGradientVerification(VerificationBase): """Checks if a model mixes data across the batch dimension. @@ -82,6 +84,7 @@ def check( return not any(has_grad_outside_sample) and all(has_grad_inside_sample) +@experimental() class BatchGradientVerificationCallback(VerificationCallbackBase): """The callback version of the :class:`BatchGradientVerification` test. @@ -130,6 +133,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: self._raise() +@experimental() def default_input_mapping(data: Any) -> List[Tensor]: """Finds all tensors in a (nested) collection that have the same batch size. @@ -157,6 +161,7 @@ def default_input_mapping(data: Any) -> List[Tensor]: return batches +@experimental() def default_output_mapping(data: Any) -> Tensor: """Pulls out all tensors in a output collection and combines them into one big batch for verification. @@ -188,6 +193,7 @@ def default_output_mapping(data: Any) -> Tensor: return combined +@experimental() def collect_tensors(data: Any) -> List[Tensor]: """Filters all tensors in a collection and returns them in a list.""" tensors = [] @@ -200,6 +206,7 @@ def collect_batches(tensor: Tensor) -> Tensor: return tensors +@experimental() @contextmanager def selective_eval(model: nn.Module, layer_types: Iterable[Type[nn.Module]]) -> None: """A context manager that sets all requested types of layers to eval mode. This method uses an ``isinstance`` diff --git a/pl_bolts/callbacks/vision/confused_logit.py b/pl_bolts/callbacks/vision/confused_logit.py index 568b56becc..3a4b52d2f4 100644 --- a/pl_bolts/callbacks/vision/confused_logit.py +++ b/pl_bolts/callbacks/vision/confused_logit.py @@ -5,6 +5,7 @@ from torch import Tensor, nn from pl_bolts.utils import _MATPLOTLIB_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _MATPLOTLIB_AVAILABLE: @@ -17,6 +18,7 @@ Figure = object +@experimental() class ConfusedLogitCallback(Callback): # pragma: no cover """Takes the logit predictions of a model and when the probabilities of two classes are very close, the model doesn't have high certainty that it should pick one vs the other class. diff --git a/pl_bolts/callbacks/vision/image_generation.py b/pl_bolts/callbacks/vision/image_generation.py index 6846b7d492..cea7beafb6 100644 --- a/pl_bolts/callbacks/vision/image_generation.py +++ b/pl_bolts/callbacks/vision/image_generation.py @@ -4,6 +4,7 @@ from pytorch_lightning import Callback, LightningModule, Trainer from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -12,6 +13,7 @@ warn_missing_pkg("torchvision") +@experimental() class TensorboardGenerativeModelImageSampler(Callback): """Generates images and logs to tensorboard. Your model must implement the ``forward`` function for generation. diff --git a/pl_bolts/callbacks/vision/sr_image_logger.py b/pl_bolts/callbacks/vision/sr_image_logger.py index 2767b04646..f94ecd7040 100644 --- a/pl_bolts/callbacks/vision/sr_image_logger.py +++ b/pl_bolts/callbacks/vision/sr_image_logger.py @@ -6,6 +6,7 @@ from pytorch_lightning import Callback from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -14,6 +15,7 @@ warn_missing_pkg("torchvision") +@experimental() class SRImageLoggerCallback(Callback): """Logs low-res, generated high-res, and ground truth high-res images to TensorBoard Your model must implement the ``forward`` function for generation. diff --git a/pl_bolts/utils/stability.py b/pl_bolts/utils/stability.py new file mode 100644 index 0000000000..e20c82e1ea --- /dev/null +++ b/pl_bolts/utils/stability.py @@ -0,0 +1,76 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import inspect +from typing import Callable, Type, Union + +from pytorch_lightning.utilities import rank_zero_warn + + +@functools.lru_cache() # Trick to only warn once for each message +def _raise_experimental_warning(message: str, stacklevel: int = 6): + rank_zero_warn( + f"{message} The compatibility with other Lightning projects is not guaranteed and API may change at any time." + "The API and functionality may change without warning in future releases. " + "More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html", + stacklevel=stacklevel, + category=UserWarning, + ) + + +def experimental( + message: str = "This feature is currently marked as experimental.", +): + """The experimental decorator is used to indicate that a particular feature is not properly reviewed and tested yet. + A callable or type that has been marked as experimental will give a ``UserWarning`` when it is called or + instantiated. This designation should be used following the description given in :ref:`stability`. + Args: + message: The message to include in the warning. + Examples + ________ + .. testsetup:: + >>> import pytest + .. doctest:: + >>> from pl_bolts.utils.stability import experimental + >>> @experimental() + ... class MyExperimentalFeature: + ... pass + ... + >>> with pytest.warns(UserWarning, match="This feature is currently marked as experimental."): + ... MyExperimentalFeature() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ... + <...> + >>> @experimental("This feature is currently marked as experimental with a message.") + ... class MyExperimentalFeatureWithCustomMessage: + ... pass + ... + >>> with pytest.warns(UserWarning, match="This feature is currently marked as experimental with a message."): + ... MyExperimentalFeatureWithCustomMessage() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ... + <...> + """ + + def decorator(callable: Union[Callable, Type]): + if inspect.isclass(callable): + callable.__init__ = decorator(callable.__init__) + return callable + + @functools.wraps(callable) + def wrapper(*args, **kwargs): + _raise_experimental_warning(message) + return callable(*args, **kwargs) + + return wrapper + + return decorator From 7c287b5408715565f253606801c22bb21b4eb0ee Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 19 Jul 2022 15:51:08 +0200 Subject: [PATCH 02/11] add experiment flag --- pl_bolts/callbacks/printing.py | 4 ++++ pl_bolts/callbacks/sparseml.py | 3 +++ pl_bolts/callbacks/ssl_online.py | 3 +++ pl_bolts/callbacks/torch_ort.py | 3 +++ pl_bolts/callbacks/variational.py | 2 ++ pl_bolts/datamodules/async_dataloader.py | 3 +++ .../datamodules/binary_emnist_datamodule.py | 2 ++ pl_bolts/datamodules/binary_mnist_datamodule.py | 2 ++ pl_bolts/datamodules/cifar10_datamodule.py | 3 +++ pl_bolts/datamodules/cityscapes_datamodule.py | 2 ++ pl_bolts/datamodules/emnist_datamodule.py | 2 ++ pl_bolts/datamodules/experience_source.py | 5 +++++ .../datamodules/fashion_mnist_datamodule.py | 2 ++ pl_bolts/datamodules/imagenet_datamodule.py | 2 ++ pl_bolts/datamodules/kitti_datamodule.py | 2 ++ pl_bolts/datamodules/mnist_datamodule.py | 2 ++ pl_bolts/datamodules/sklearn_datamodule.py | 4 ++++ pl_bolts/datamodules/sr_datamodule.py | 3 +++ pl_bolts/datamodules/ssl_imagenet_datamodule.py | 2 ++ pl_bolts/datamodules/stl10_datamodule.py | 2 ++ pl_bolts/datamodules/vision_datamodule.py | 3 +++ pl_bolts/datamodules/vocdetection_datamodule.py | 4 ++++ pl_bolts/datasets/base_dataset.py | 3 +++ pl_bolts/datasets/cifar10_dataset.py | 3 +++ pl_bolts/datasets/concat_dataset.py | 3 +++ pl_bolts/datasets/dummy_dataset.py | 7 +++++++ pl_bolts/datasets/emnist_dataset.py | 2 ++ pl_bolts/datasets/imagenet_dataset.py | 13 +++++++++++++ pl_bolts/datasets/kitti_dataset.py | 2 ++ pl_bolts/datasets/mnist_dataset.py | 2 ++ pl_bolts/datasets/sr_celeba_dataset.py | 2 ++ pl_bolts/datasets/sr_dataset_mixin.py | 2 ++ pl_bolts/datasets/sr_mnist_dataset.py | 2 ++ pl_bolts/datasets/sr_stl10_dataset.py | 2 ++ pl_bolts/datasets/ssl_amdim_datasets.py | 3 +++ pl_bolts/datasets/utils.py | 2 ++ pl_bolts/losses/object_detection.py | 3 +++ pl_bolts/losses/rl.py | 5 +++++ pl_bolts/losses/self_supervised_learning.py | 6 ++++++ pl_bolts/metrics/aggregation.py | 5 +++++ pl_bolts/metrics/object_detection.py | 4 ++++ .../autoencoders/basic_ae/basic_ae_module.py | 3 +++ .../autoencoders/basic_vae/basic_vae_module.py | 3 +++ pl_bolts/models/autoencoders/components.py | 17 +++++++++++++++++ .../components/torchvision_backbones.py | 5 +++++ .../models/detection/faster_rcnn/backbones.py | 2 ++ .../detection/faster_rcnn/faster_rcnn_module.py | 4 ++++ .../models/detection/retinanet/backbones.py | 2 ++ .../detection/retinanet/retinanet_module.py | 3 +++ pl_bolts/models/detection/yolo/yolo_config.py | 9 +++++++++ pl_bolts/models/detection/yolo/yolo_layers.py | 10 ++++++++++ pl_bolts/models/detection/yolo/yolo_module.py | 4 ++++ pl_bolts/models/gans/basic/basic_gan_module.py | 3 +++ pl_bolts/models/gans/basic/components.py | 4 ++++ pl_bolts/models/gans/dcgan/components.py | 4 ++++ pl_bolts/models/gans/dcgan/dcgan_module.py | 3 +++ pl_bolts/models/gans/pix2pix/components.py | 6 ++++++ pl_bolts/models/gans/pix2pix/pix2pix_module.py | 3 +++ pl_bolts/models/gans/srgan/components.py | 5 +++++ pl_bolts/models/gans/srgan/srgan_module.py | 3 +++ pl_bolts/models/gans/srgan/srresnet_module.py | 3 +++ pl_bolts/models/mnist_module.py | 3 +++ pl_bolts/models/regression/linear_regression.py | 4 ++++ .../models/regression/logistic_regression.py | 4 ++++ .../models/rl/advantage_actor_critic_model.py | 3 +++ pl_bolts/models/rl/common/agents.py | 7 +++++++ pl_bolts/models/rl/common/cli.py | 3 +++ pl_bolts/models/rl/common/distributions.py | 3 +++ pl_bolts/models/rl/common/gym_wrappers.py | 10 ++++++++++ pl_bolts/models/rl/common/memory.py | 7 +++++++ pl_bolts/models/rl/common/networks.py | 11 +++++++++++ pl_bolts/models/rl/double_dqn_model.py | 3 +++ pl_bolts/models/rl/dqn_model.py | 3 +++ pl_bolts/models/rl/dueling_dqn_model.py | 3 +++ pl_bolts/models/rl/noisy_dqn_model.py | 3 +++ pl_bolts/models/rl/per_dqn_model.py | 3 +++ pl_bolts/models/rl/ppo_model.py | 3 +++ pl_bolts/models/rl/reinforce_model.py | 3 +++ pl_bolts/models/rl/sac_model.py | 3 +++ .../models/rl/vanilla_policy_gradient_model.py | 3 +++ .../self_supervised/amdim/amdim_module.py | 4 ++++ .../models/self_supervised/amdim/datasets.py | 3 +++ .../models/self_supervised/amdim/networks.py | 9 +++++++++ .../models/self_supervised/amdim/transforms.py | 7 +++++++ .../models/self_supervised/byol/byol_module.py | 3 +++ pl_bolts/models/self_supervised/byol/models.py | 3 +++ .../models/self_supervised/cpc/cpc_finetuner.py | 2 ++ .../models/self_supervised/cpc/cpc_module.py | 3 +++ pl_bolts/models/self_supervised/cpc/networks.py | 8 ++++++++ .../models/self_supervised/cpc/transforms.py | 7 +++++++ pl_bolts/models/self_supervised/evaluator.py | 4 ++++ .../models/self_supervised/moco/callbacks.py | 3 +++ .../models/self_supervised/moco/moco2_module.py | 4 ++++ .../models/self_supervised/moco/transforms.py | 8 ++++++++ pl_bolts/models/self_supervised/resnets.py | 17 +++++++++++++++++ .../self_supervised/simclr/simclr_finetuner.py | 2 ++ .../self_supervised/simclr/simclr_module.py | 5 +++++ .../models/self_supervised/simclr/transforms.py | 5 +++++ .../models/self_supervised/simsiam/models.py | 3 +++ .../self_supervised/simsiam/simsiam_module.py | 3 +++ .../models/self_supervised/ssl_finetuner.py | 2 ++ .../self_supervised/swav/swav_finetuner.py | 2 ++ .../models/self_supervised/swav/swav_module.py | 3 +++ .../models/self_supervised/swav/swav_resnet.py | 13 +++++++++++++ .../models/self_supervised/swav/transforms.py | 5 +++++ pl_bolts/models/vision/image_gpt/gpt2.py | 4 ++++ pl_bolts/models/vision/image_gpt/igpt_module.py | 4 ++++ pl_bolts/models/vision/pixel_cnn.py | 3 +++ pl_bolts/models/vision/segmentation.py | 3 +++ pl_bolts/models/vision/unet.py | 6 ++++++ pl_bolts/optimizers/lars.py | 3 +++ pl_bolts/optimizers/lr_scheduler.py | 4 ++++ pl_bolts/transforms/dataset_normalizations.py | 5 +++++ .../self_supervised/ssl_transforms.py | 3 +++ 114 files changed, 472 insertions(+) diff --git a/pl_bolts/callbacks/printing.py b/pl_bolts/callbacks/printing.py index 72e41eab0d..eefb04ac52 100644 --- a/pl_bolts/callbacks/printing.py +++ b/pl_bolts/callbacks/printing.py @@ -6,7 +6,10 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities import rank_zero_info +from pl_bolts.utils.stability import experimental + +@experimental() class PrintTableMetricsCallback(Callback): """Prints a table with the metrics in columns on every epoch end. @@ -41,6 +44,7 @@ def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: rank_zero_info(dicts_to_table(self.metrics)) +@experimental() def dicts_to_table( dicts: List[Dict], keys: Optional[List[str]] = None, diff --git a/pl_bolts/callbacks/sparseml.py b/pl_bolts/callbacks/sparseml.py index 1d5c01cf83..58e8a17198 100644 --- a/pl_bolts/callbacks/sparseml.py +++ b/pl_bolts/callbacks/sparseml.py @@ -23,7 +23,10 @@ from sparseml.pytorch.optim import ScheduledModifierManager from sparseml.pytorch.utils import ModuleExporter +from pl_bolts.utils.stability import experimental + +@experimental() class SparseMLCallback(Callback): """Enables SparseML aware training. Requires a recipe to run during training. diff --git a/pl_bolts/callbacks/ssl_online.py b/pl_bolts/callbacks/ssl_online.py index c8188d7b5a..953d48fdbc 100644 --- a/pl_bolts/callbacks/ssl_online.py +++ b/pl_bolts/callbacks/ssl_online.py @@ -10,8 +10,10 @@ from torchmetrics.functional import accuracy from pl_bolts.models.self_supervised.evaluator import SSLEvaluator +from pl_bolts.utils.stability import experimental +@experimental() class SSLOnlineEvaluator(Callback): # pragma: no cover """Attaches a MLP for fine-tuning using the standard self-supervised protocol. @@ -173,6 +175,7 @@ def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule, callb self._recovered_callback_state = callback_state +@experimental() @contextmanager def set_training(module: nn.Module, mode: bool): """Context manager to set training mode. diff --git a/pl_bolts/callbacks/torch_ort.py b/pl_bolts/callbacks/torch_ort.py index 0f3fcb72a7..c4b64a1e10 100644 --- a/pl_bolts/callbacks/torch_ort.py +++ b/pl_bolts/callbacks/torch_ort.py @@ -19,7 +19,10 @@ if _TORCH_ORT_AVAILABLE: from torch_ort import ORTModule +from pl_bolts.utils.stability import experimental + +@experimental() class ORTCallback(Callback): """Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime. diff --git a/pl_bolts/callbacks/variational.py b/pl_bolts/callbacks/variational.py index 6101683e1d..15f4fe1b5f 100644 --- a/pl_bolts/callbacks/variational.py +++ b/pl_bolts/callbacks/variational.py @@ -7,6 +7,7 @@ from torch import Tensor from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -15,6 +16,7 @@ warn_missing_pkg("torchvision") +@experimental() class LatentDimInterpolator(Callback): """Interpolates the latent space for a model by setting all dims to zero and stepping through the first two dims increasing one unit at a time. diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index de2193e224..c126ad85dc 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -9,7 +9,10 @@ from torch._six import string_classes from torch.utils.data import DataLoader, Dataset +from pl_bolts.utils.stability import experimental + +@experimental() class AsynchronousLoader: """Class for asynchronously loading from CPU memory to device memory with DataLoader. diff --git a/pl_bolts/datamodules/binary_emnist_datamodule.py b/pl_bolts/datamodules/binary_emnist_datamodule.py index a6ea5217f0..5435b5722c 100644 --- a/pl_bolts/datamodules/binary_emnist_datamodule.py +++ b/pl_bolts/datamodules/binary_emnist_datamodule.py @@ -3,8 +3,10 @@ from pl_bolts.datamodules.emnist_datamodule import EMNISTDataModule from pl_bolts.datasets import BinaryEMNIST from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental +@experimental() class BinaryEMNISTDataModule(EMNISTDataModule): """ .. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 33fee8e360..ed65d82b26 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -3,6 +3,7 @@ from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.datasets import BinaryMNIST from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -11,6 +12,7 @@ warn_missing_pkg("torchvision") +@experimental() class BinaryMNISTDataModule(VisionDataModule): """ .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index ed7a063427..65b08e4f1f 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -4,6 +4,7 @@ from pl_bolts.datasets import TrialCIFAR10 from pl_bolts.transforms.dataset_normalizations import cifar10_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -14,6 +15,7 @@ CIFAR10 = None +@experimental() class CIFAR10DataModule(VisionDataModule): """ .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/ @@ -122,6 +124,7 @@ def default_transforms(self) -> Callable: return cf10_transforms +@experimental() class TinyCIFAR10DataModule(CIFAR10DataModule): """Standard CIFAR10, train, val, test splits and transforms. diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 4c8e84bd4e..6d1dab5733 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -4,6 +4,7 @@ from torch.utils.data import DataLoader from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -13,6 +14,7 @@ warn_missing_pkg("torchvision") +@experimental() class CityscapesDataModule(LightningDataModule): """ .. figure:: https://www.cityscapes-dataset.com/wordpress/wp-content/uploads/2015/07/muenster00-1024x510.png diff --git a/pl_bolts/datamodules/emnist_datamodule.py b/pl_bolts/datamodules/emnist_datamodule.py index 9fbd0c9605..90a4eb4b4a 100644 --- a/pl_bolts/datamodules/emnist_datamodule.py +++ b/pl_bolts/datamodules/emnist_datamodule.py @@ -3,6 +3,7 @@ from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.transforms.dataset_normalizations import emnist_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -13,6 +14,7 @@ EMNIST = object +@experimental() class EMNISTDataModule(VisionDataModule): """ .. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index ecc69724bc..3aaa8623a2 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -8,6 +8,7 @@ from torch.utils.data import IterableDataset from pl_bolts.utils import _GYM_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg try: @@ -23,6 +24,7 @@ Experience = namedtuple("Experience", field_names=["state", "action", "reward", "done", "new_state"]) +@experimental() class ExperienceSourceDataset(IterableDataset): """Basic experience source dataset. @@ -39,6 +41,7 @@ def __iter__(self) -> Iterator: # Experience Sources +@experimental() class BaseExperienceSource(ABC): """Simplest form of the experience source.""" @@ -56,6 +59,7 @@ def runner(self) -> Experience: raise NotImplementedError("ExperienceSource has no stepper method implemented") +@experimental() class ExperienceSource(BaseExperienceSource): """Experience source class handling single and multiple environment steps.""" @@ -231,6 +235,7 @@ def pop_rewards_steps(self): return res +@experimental() class DiscountedExperienceSource(ExperienceSource): """Outputs experiences with a discounted reward over N steps.""" diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 069515694f..d70402ac24 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -2,6 +2,7 @@ from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -12,6 +13,7 @@ FashionMNIST = None +@experimental() class FashionMNISTDataModule(VisionDataModule): """ .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/ diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 790d9217fc..4ef3d4c59b 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -7,6 +7,7 @@ from pl_bolts.datasets import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -15,6 +16,7 @@ warn_missing_pkg("torchvision") +@experimental() class ImagenetDataModule(LightningDataModule): """ .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/ diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 7e9307441e..9195f3b59d 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -8,6 +8,7 @@ from pl_bolts.datasets import KittiDataset from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -16,6 +17,7 @@ warn_missing_pkg("torchvision") +@experimental() class KittiDataModule(LightningDataModule): name = "kitti" diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index b953fb31f1..93dff65ec4 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -3,6 +3,7 @@ from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.datasets import MNIST from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -11,6 +12,7 @@ warn_missing_pkg("torchvision") +@experimental() class MNISTDataModule(VisionDataModule): """ .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index 66ab64ad94..49d5e3efa5 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -8,6 +8,7 @@ from torch.utils.data import DataLoader, Dataset from pl_bolts.utils import _SKLEARN_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _SKLEARN_AVAILABLE: @@ -16,6 +17,7 @@ warn_missing_pkg("sklearn") +@experimental() class SklearnDataset(Dataset): """Mapping between numpy (or sklearn) datasets to PyTorch datasets. @@ -63,6 +65,7 @@ def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]: return x, y +@experimental() class TensorDataset(Dataset): """Prepare PyTorch tensor dataset for data loaders. @@ -106,6 +109,7 @@ def __getitem__(self, idx) -> Tuple[Tensor, Tensor]: return x, y +@experimental() class SklearnDataModule(LightningDataModule): """Automatically generates the train, validation and test splits for a Numpy dataset. They are set up as dataloaders for convenience. Optionally, you can pass in your own validation and test splits. diff --git a/pl_bolts/datamodules/sr_datamodule.py b/pl_bolts/datamodules/sr_datamodule.py index 7a2c06dacf..4d5e798bc7 100644 --- a/pl_bolts/datamodules/sr_datamodule.py +++ b/pl_bolts/datamodules/sr_datamodule.py @@ -3,7 +3,10 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset +from pl_bolts.utils.stability import experimental + +@experimental() class TVTDataModule(LightningDataModule): """Simple DataModule creating train, val, and test dataloaders from given train, val, and test dataset. diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 0cd7408fc8..fa89b8357c 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -7,6 +7,7 @@ from pl_bolts.datasets import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -15,6 +16,7 @@ warn_missing_pkg("torchvision") +@experimental() class SSLImagenetDataModule(LightningDataModule): # pragma: no cover name = "imagenet" diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 345bfe7c91..87bee938e4 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -8,6 +8,7 @@ from pl_bolts.datasets import ConcatDataset from pl_bolts.transforms.dataset_normalizations import stl10_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -17,6 +18,7 @@ warn_missing_pkg("torchvision") +@experimental() class STL10DataModule(LightningDataModule): # pragma: no cover """ .. figure:: https://samyzaf.com/ML/cifar10/cifar1.jpg diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 372c973724..29b09a3b5b 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,7 +6,10 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split +from pl_bolts.utils.stability import experimental + +@experimental() class VisionDataModule(LightningDataModule): EXTRA_ARGS: dict = {} diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 15e3fc8a80..603f0c9815 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -7,6 +7,7 @@ from torch.utils.data import DataLoader, Dataset from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -16,6 +17,7 @@ warn_missing_pkg("torchvision") +@experimental() class Compose: """Like `torchvision.transforms.compose` but works for (image, target)""" @@ -60,6 +62,7 @@ def _collate_fn(batch: List[Tensor]) -> tuple: ) +@experimental() def _prepare_voc_instance(image: Any, target: Dict[str, Any]): """Prepares VOC dataset into appropriate target for fasterrcnn. @@ -101,6 +104,7 @@ def _prepare_voc_instance(image: Any, target: Dict[str, Any]): return image, target +@experimental() class VOCDetectionDataModule(LightningDataModule): """TODO(teddykoker) docstring.""" diff --git a/pl_bolts/datasets/base_dataset.py b/pl_bolts/datasets/base_dataset.py index c9e6ebf7b1..1acc0a9d81 100644 --- a/pl_bolts/datasets/base_dataset.py +++ b/pl_bolts/datasets/base_dataset.py @@ -8,7 +8,10 @@ from torch import Tensor from torch.utils.data import Dataset +from pl_bolts.utils.stability import experimental + +@experimental() class LightDataset(ABC, Dataset): data: Tensor diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py index c864f0da07..9bf89cfa41 100644 --- a/pl_bolts/datasets/cifar10_dataset.py +++ b/pl_bolts/datasets/cifar10_dataset.py @@ -8,6 +8,7 @@ from pl_bolts.datasets import LightDataset from pl_bolts.utils import _PIL_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -16,6 +17,7 @@ warn_missing_pkg("PIL", pypi_name="Pillow") +@experimental() class CIFAR10(LightDataset): """Customized `CIFAR10 `_ dataset for testing Pytorch Lightning without the torchvision dependency. @@ -153,6 +155,7 @@ def download(self, data_folder: str) -> None: self._download_from_url(self.BASE_URL, data_folder, self.FILE_NAME) +@experimental() class TrialCIFAR10(CIFAR10): """ Customized `CIFAR10 `_ dataset for testing Pytorch Lightning diff --git a/pl_bolts/datasets/concat_dataset.py b/pl_bolts/datasets/concat_dataset.py index ae09a37c7f..b43048b6f7 100644 --- a/pl_bolts/datasets/concat_dataset.py +++ b/pl_bolts/datasets/concat_dataset.py @@ -1,6 +1,9 @@ from torch.utils.data import Dataset +from pl_bolts.utils.stability import experimental + +@experimental() class ConcatDataset(Dataset): def __init__(self, *datasets): self.datasets = datasets diff --git a/pl_bolts/datasets/dummy_dataset.py b/pl_bolts/datasets/dummy_dataset.py index d3b27d2881..74054e6fc1 100644 --- a/pl_bolts/datasets/dummy_dataset.py +++ b/pl_bolts/datasets/dummy_dataset.py @@ -1,7 +1,10 @@ import torch from torch.utils.data import Dataset +from pl_bolts.utils.stability import experimental + +@experimental() class DummyDataset(Dataset): """Generate a dummy dataset. @@ -41,6 +44,7 @@ def __getitem__(self, idx: int): return sample +@experimental() class DummyDetectionDataset(Dataset): """Generate a dummy dataset for detection. @@ -81,6 +85,7 @@ def __getitem__(self, idx: int): return img, {"boxes": boxes, "labels": labels} +@experimental() class RandomDictDataset(Dataset): """Generate a dummy dataset with a dict structure. @@ -109,6 +114,7 @@ def __len__(self): return self.len +@experimental() class RandomDictStringDataset(Dataset): """Generate a dummy dataset with strings. @@ -135,6 +141,7 @@ def __len__(self): return self.len +@experimental() class RandomDataset(Dataset): """Generate a dummy dataset. diff --git a/pl_bolts/datasets/emnist_dataset.py b/pl_bolts/datasets/emnist_dataset.py index 65848804be..fb3628d2aa 100644 --- a/pl_bolts/datasets/emnist_dataset.py +++ b/pl_bolts/datasets/emnist_dataset.py @@ -1,4 +1,5 @@ from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -13,6 +14,7 @@ warn_missing_pkg("PIL", pypi_name="Pillow") +@experimental() class BinaryEMNIST(EMNIST): def __getitem__(self, idx): """ diff --git a/pl_bolts/datasets/imagenet_dataset.py b/pl_bolts/datasets/imagenet_dataset.py index d2cb9ffc77..9195a5a57b 100644 --- a/pl_bolts/datasets/imagenet_dataset.py +++ b/pl_bolts/datasets/imagenet_dataset.py @@ -12,6 +12,7 @@ import torch from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg PY3 = sys.version_info[0] == 3 @@ -24,6 +25,7 @@ ImageNet = object +@experimental() class UnlabeledImagenet(ImageNet): """Official train set gets split into train, val. (using nb_imgs_per_val_class for each class). Official validation becomes test set. @@ -158,6 +160,7 @@ def generate_meta_bins(cls, devkit_dir): print(f"meta.bin generated at {devkit_dir}/meta.bin") +@experimental() def _verify_archive(root, file, md5): if not _check_integrity(os.path.join(root, file), md5): raise RuntimeError( @@ -166,6 +169,7 @@ def _verify_archive(root, file, md5): ) +@experimental() def _check_integrity(fpath, md5=None): if not os.path.isfile(fpath): return False @@ -174,10 +178,12 @@ def _check_integrity(fpath, md5=None): return _check_md5(fpath, md5) +@experimental() def _check_md5(fpath, md5, **kwargs): return md5 == _calculate_md5(fpath, **kwargs) +@experimental() def _calculate_md5(fpath, chunk_size=1024 * 1024): md5 = hashlib.md5() with open(fpath, "rb") as f: @@ -186,6 +192,7 @@ def _calculate_md5(fpath, chunk_size=1024 * 1024): return md5.hexdigest() +@experimental() def parse_devkit_archive(root, file=None): """Parse the devkit archive of the ImageNet2012 classification dataset and save the meta information in a binary file. @@ -242,6 +249,7 @@ def get_tmp_dir(): torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE)) +@experimental() def extract_archive(from_path, to_path=None, remove_finished=False): if to_path is None: to_path = os.path.dirname(from_path) @@ -272,21 +280,26 @@ def extract_archive(from_path, to_path=None, remove_finished=False): os.remove(from_path) +@experimental() def _is_targz(filename): return filename.endswith(".tar.gz") +@experimental() def _is_tarxz(filename): return filename.endswith(".tar.xz") +@experimental() def _is_gzip(filename): return filename.endswith(".gz") and not filename.endswith(".tar.gz") +@experimental() def _is_tar(filename): return filename.endswith(".tar") +@experimental() def _is_zip(filename): return filename.endswith(".zip") diff --git a/pl_bolts/datasets/kitti_dataset.py b/pl_bolts/datasets/kitti_dataset.py index aaa559667b..88e86b1c6a 100644 --- a/pl_bolts/datasets/kitti_dataset.py +++ b/pl_bolts/datasets/kitti_dataset.py @@ -4,6 +4,7 @@ from torch.utils.data import Dataset from pl_bolts.utils import _PIL_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -15,6 +16,7 @@ DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) +@experimental() class KittiDataset(Dataset): """ Note: diff --git a/pl_bolts/datasets/mnist_dataset.py b/pl_bolts/datasets/mnist_dataset.py index d0078620be..8c0486d90a 100644 --- a/pl_bolts/datasets/mnist_dataset.py +++ b/pl_bolts/datasets/mnist_dataset.py @@ -1,4 +1,5 @@ from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1 +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -36,6 +37,7 @@ ] +@experimental() class BinaryMNIST(MNIST): def __getitem__(self, idx): """ diff --git a/pl_bolts/datasets/sr_celeba_dataset.py b/pl_bolts/datasets/sr_celeba_dataset.py index f912ced3b1..5f3eaaa19c 100644 --- a/pl_bolts/datasets/sr_celeba_dataset.py +++ b/pl_bolts/datasets/sr_celeba_dataset.py @@ -3,6 +3,7 @@ from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -17,6 +18,7 @@ CelebA = object +@experimental() class SRCelebA(SRDatasetMixin, CelebA): """CelebA dataset that can be used to train Super Resolution models. diff --git a/pl_bolts/datasets/sr_dataset_mixin.py b/pl_bolts/datasets/sr_dataset_mixin.py index 17bd176f92..a35b366dc9 100644 --- a/pl_bolts/datasets/sr_dataset_mixin.py +++ b/pl_bolts/datasets/sr_dataset_mixin.py @@ -4,6 +4,7 @@ import torch from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -17,6 +18,7 @@ warn_missing_pkg("torchvision") +@experimental() class SRDatasetMixin: """Mixin for Super Resolution datasets. diff --git a/pl_bolts/datasets/sr_mnist_dataset.py b/pl_bolts/datasets/sr_mnist_dataset.py index 70fb7c2c23..633d990be5 100644 --- a/pl_bolts/datasets/sr_mnist_dataset.py +++ b/pl_bolts/datasets/sr_mnist_dataset.py @@ -3,6 +3,7 @@ from pl_bolts.datasets.mnist_dataset import MNIST from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin from pl_bolts.utils import _PIL_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -11,6 +12,7 @@ warn_missing_pkg("PIL", pypi_name="Pillow") +@experimental() class SRMNIST(SRDatasetMixin, MNIST): """MNIST dataset that can be used to train Super Resolution models. diff --git a/pl_bolts/datasets/sr_stl10_dataset.py b/pl_bolts/datasets/sr_stl10_dataset.py index 868565fd22..1ecd360a56 100644 --- a/pl_bolts/datasets/sr_stl10_dataset.py +++ b/pl_bolts/datasets/sr_stl10_dataset.py @@ -4,6 +4,7 @@ from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -18,6 +19,7 @@ STL10 = object +@experimental() class SRSTL10(SRDatasetMixin, STL10): """STL10 dataset that can be used to train Super Resolution models. diff --git a/pl_bolts/datasets/ssl_amdim_datasets.py b/pl_bolts/datasets/ssl_amdim_datasets.py index 3558fe1e9d..5efa2d7dfb 100644 --- a/pl_bolts/datasets/ssl_amdim_datasets.py +++ b/pl_bolts/datasets/ssl_amdim_datasets.py @@ -4,6 +4,7 @@ import numpy as np from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -13,6 +14,7 @@ CIFAR10 = object +@experimental() class SSLDatasetMixin(ABC): @classmethod def generate_train_val_split(cls, examples, labels, pct_val): @@ -89,6 +91,7 @@ def deterministic_shuffle(cls, x, y): return x, y +@experimental() class CIFAR10Mixed(SSLDatasetMixin, CIFAR10): def __init__( self, diff --git a/pl_bolts/datasets/utils.py b/pl_bolts/datasets/utils.py index 77946beb5b..2dbb283f78 100644 --- a/pl_bolts/datasets/utils.py +++ b/pl_bolts/datasets/utils.py @@ -3,8 +3,10 @@ from pl_bolts.datasets.sr_celeba_dataset import SRCelebA from pl_bolts.datasets.sr_mnist_dataset import SRMNIST from pl_bolts.datasets.sr_stl10_dataset import SRSTL10 +from pl_bolts.utils.stability import experimental +@experimental() def prepare_sr_datasets(dataset: str, scale_factor: int, data_dir: str): """Creates train, val, and test datasets for training a Super Resolution GAN. diff --git a/pl_bolts/losses/object_detection.py b/pl_bolts/losses/object_detection.py index 7955e8d437..a7b8e15fc9 100644 --- a/pl_bolts/losses/object_detection.py +++ b/pl_bolts/losses/object_detection.py @@ -3,8 +3,10 @@ from torch import Tensor from pl_bolts.metrics.object_detection import giou, iou +from pl_bolts.utils.stability import experimental +@experimental() def iou_loss(preds: Tensor, target: Tensor) -> Tensor: """Calculates the intersection over union loss. @@ -28,6 +30,7 @@ def iou_loss(preds: Tensor, target: Tensor) -> Tensor: return loss +@experimental() def giou_loss(preds: Tensor, target: Tensor) -> Tensor: """Calculates the generalized intersection over union loss. diff --git a/pl_bolts/losses/rl.py b/pl_bolts/losses/rl.py index bc0fb7b36d..3e86bc0d71 100644 --- a/pl_bolts/losses/rl.py +++ b/pl_bolts/losses/rl.py @@ -6,7 +6,10 @@ import torch from torch import Tensor, nn +from pl_bolts.utils.stability import experimental + +@experimental() def dqn_loss(batch: Tuple[Tensor, Tensor], net: nn.Module, target_net: nn.Module, gamma: float = 0.99) -> Tensor: """Calculates the mse loss using a mini batch from the replay buffer. @@ -35,6 +38,7 @@ def dqn_loss(batch: Tuple[Tensor, Tensor], net: nn.Module, target_net: nn.Module return nn.MSELoss()(state_action_values, expected_state_action_values) +@experimental() def double_dqn_loss( batch: Tuple[Tensor, Tensor], net: nn.Module, @@ -80,6 +84,7 @@ def double_dqn_loss( return nn.MSELoss()(state_action_values, expected_state_action_values) +@experimental() def per_dqn_loss( batch: Tuple[Tensor, Tensor], batch_weights: List, diff --git a/pl_bolts/losses/self_supervised_learning.py b/pl_bolts/losses/self_supervised_learning.py index a46939d7ff..e0283f751a 100644 --- a/pl_bolts/losses/self_supervised_learning.py +++ b/pl_bolts/losses/self_supervised_learning.py @@ -3,8 +3,10 @@ from torch import nn from pl_bolts.models.vision.pixel_cnn import PixelCNN +from pl_bolts.utils.stability import experimental +@experimental() def nt_xent_loss(out_1, out_2, temperature): """Loss used in SimCLR.""" out = torch.cat([out_1, out_2], dim=0) @@ -26,6 +28,7 @@ def nt_xent_loss(out_1, out_2, temperature): return loss +@experimental() class CPCTask(nn.Module): """Loss used in CPC.""" @@ -87,6 +90,7 @@ def forward(self, Z): return loss +@experimental() class AmdimNCELoss(nn.Module): """Compute the NCE scores for predicting r_src->r_trg.""" @@ -181,6 +185,7 @@ def forward(self, anchor_representations, positive_representations, mask_mat): return nce_scores, lgt_reg +@experimental() class FeatureMapContrastiveTask(nn.Module): """Performs an anchor, positive negative pair comparison for each each tuple of feature maps passed. @@ -365,6 +370,7 @@ def forward(self, anchor_maps, positive_maps): return torch.stack(losses), regularizer +@experimental() def tanh_clip(x, clip_val=10.0): """soft clip values to the range [-clip_val, +clip_val]""" if clip_val is not None: diff --git a/pl_bolts/metrics/aggregation.py b/pl_bolts/metrics/aggregation.py index 336248cbf4..4e70d7fcb3 100644 --- a/pl_bolts/metrics/aggregation.py +++ b/pl_bolts/metrics/aggregation.py @@ -1,11 +1,15 @@ import torch +from pl_bolts.utils.stability import experimental + +@experimental() def mean(res, key): # recursive mean for multilevel dicts return torch.stack([x[key] if isinstance(x, dict) else mean(x, key) for x in res]).mean() +@experimental() def accuracy(preds, labels): preds = preds.float() max_lgt = torch.max(preds, 1)[1] @@ -16,6 +20,7 @@ def accuracy(preds, labels): return acc +@experimental() def precision_at_k(output, target, top_k=(1,)): """Computes the accuracy over the k top predictions for the specified values of k.""" with torch.no_grad(): diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index 106de2ec0d..bd0f1888ad 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -1,7 +1,10 @@ import torch from torch import Tensor +from pl_bolts.utils.stability import experimental + +@experimental() def iou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the intersection over union. @@ -34,6 +37,7 @@ def iou(preds: Tensor, target: Tensor) -> Tensor: return iou +@experimental() def giou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the generalized intersection over union. diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index 0a14d576ad..47350247b1 100644 --- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -13,8 +13,10 @@ resnet50_decoder, resnet50_encoder, ) +from pl_bolts.utils.stability import experimental +@experimental() class AE(LightningModule): """Standard AE. @@ -150,6 +152,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def cli_main(args=None): from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index 77eba1b82f..9114271361 100644 --- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -13,8 +13,10 @@ resnet50_decoder, resnet50_encoder, ) +from pl_bolts.utils.stability import experimental +@experimental() class VAE(LightningModule): """Standard VAE with Gaussian Prior and approx posterior. @@ -182,6 +184,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def cli_main(args=None): from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/autoencoders/components.py b/pl_bolts/models/autoencoders/components.py index b097cdb966..2abdc9f500 100644 --- a/pl_bolts/models/autoencoders/components.py +++ b/pl_bolts/models/autoencoders/components.py @@ -2,7 +2,10 @@ from torch import nn from torch.nn import functional as F +from pl_bolts.utils.stability import experimental + +@experimental() class Interpolate(nn.Module): """nn.Module wrapper for F.interpolate.""" @@ -14,16 +17,19 @@ def forward(self, x): return F.interpolate(x, size=self.size, scale_factor=self.scale_factor) +@experimental() def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding.""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) +@experimental() def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution.""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) +@experimental() def resize_conv3x3(in_planes, out_planes, scale=1): """upsample + 3x3 convolution with padding to avoid checkerboard artifact.""" if scale == 1: @@ -31,6 +37,7 @@ def resize_conv3x3(in_planes, out_planes, scale=1): return nn.Sequential(Interpolate(scale_factor=scale), conv3x3(in_planes, out_planes)) +@experimental() def resize_conv1x1(in_planes, out_planes, scale=1): """upsample + 1x1 convolution with padding to avoid checkerboard artifact.""" if scale == 1: @@ -38,6 +45,7 @@ def resize_conv1x1(in_planes, out_planes, scale=1): return nn.Sequential(Interpolate(scale_factor=scale), conv1x1(in_planes, out_planes)) +@experimental() class EncoderBlock(nn.Module): """ResNet block, copied from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L35.""" @@ -71,6 +79,7 @@ def forward(self, x): return out +@experimental() class EncoderBottleneck(nn.Module): """ResNet bottleneck, copied from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L75.""" @@ -112,6 +121,7 @@ def forward(self, x): return out +@experimental() class DecoderBlock(nn.Module): """ResNet block, but convs replaced with resize convs, and channel increase is in second conv, not first.""" @@ -145,6 +155,7 @@ def forward(self, x): return out +@experimental() class DecoderBottleneck(nn.Module): """ResNet bottleneck, but convs replaced with resize convs.""" @@ -185,6 +196,7 @@ def forward(self, x): return out +@experimental() class ResNetEncoder(nn.Module): def __init__(self, block, layers, first_conv=False, maxpool1=False): super().__init__() @@ -244,6 +256,7 @@ def forward(self, x): return x +@experimental() class ResNetDecoder(nn.Module): """Resnet in reverse order.""" @@ -316,17 +329,21 @@ def forward(self, x): return x +@experimental() def resnet18_encoder(first_conv, maxpool1): return ResNetEncoder(EncoderBlock, [2, 2, 2, 2], first_conv, maxpool1) +@experimental() def resnet18_decoder(latent_dim, input_height, first_conv, maxpool1): return ResNetDecoder(DecoderBlock, [2, 2, 2, 2], latent_dim, input_height, first_conv, maxpool1) +@experimental() def resnet50_encoder(first_conv, maxpool1): return ResNetEncoder(EncoderBottleneck, [3, 4, 6, 3], first_conv, maxpool1) +@experimental() def resnet50_decoder(latent_dim, input_height, first_conv, maxpool1): return ResNetDecoder(DecoderBottleneck, [3, 4, 6, 3], latent_dim, input_height, first_conv, maxpool1) diff --git a/pl_bolts/models/detection/components/torchvision_backbones.py b/pl_bolts/models/detection/components/torchvision_backbones.py index 1f16c24c6c..98fd2b2bd8 100644 --- a/pl_bolts/models/detection/components/torchvision_backbones.py +++ b/pl_bolts/models/detection/components/torchvision_backbones.py @@ -4,9 +4,11 @@ from pl_bolts.models.detection.components._supported_models import TORCHVISION_MODEL_ZOO from pl_bolts.utils import _TORCHVISION_AVAILABLE # noqa: F401 +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg # noqa: F401 +@experimental() def _create_backbone_generic(model: nn.Module, out_channels: int) -> nn.Module: """Generic Backbone creater. It removes the last linear layer. @@ -23,6 +25,7 @@ def _create_backbone_generic(model: nn.Module, out_channels: int) -> nn.Module: # Use this when you have Adaptive Pooling layer in End. # When Model.features is not applicable. +@experimental() def _create_backbone_adaptive(model: nn.Module, out_channels: Optional[int] = None) -> nn.Module: """Creates backbone by removing linear after Adaptive Pooling layer. @@ -36,6 +39,7 @@ def _create_backbone_adaptive(model: nn.Module, out_channels: Optional[int] = No return _create_backbone_generic(model, out_channels=out_channels) +@experimental() def _create_backbone_features(model: nn.Module, out_channels: int) -> nn.Module: """Creates backbone from feature sequential block. @@ -48,6 +52,7 @@ def _create_backbone_features(model: nn.Module, out_channels: int) -> nn.Module: return ft_backbone +@experimental() def create_torchvision_backbone(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: """Creates CNN backbone from Torchvision. diff --git a/pl_bolts/models/detection/faster_rcnn/backbones.py b/pl_bolts/models/detection/faster_rcnn/backbones.py index 8e20c4b81c..49254b81f2 100644 --- a/pl_bolts/models/detection/faster_rcnn/backbones.py +++ b/pl_bolts/models/detection/faster_rcnn/backbones.py @@ -4,6 +4,7 @@ from pl_bolts.models.detection.components import create_torchvision_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -12,6 +13,7 @@ warn_missing_pkg("torchvision") +@experimental() def create_fasterrcnn_backbone( backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any ) -> nn.Module: diff --git a/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py b/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py index 4231634832..ea49a1369e 100644 --- a/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py +++ b/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py @@ -6,6 +6,7 @@ from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -16,6 +17,7 @@ warn_missing_pkg("torchvision") +@experimental() def _evaluate_iou(target, pred): """Evaluate intersection over union (IOU) for target from dataset and output prediction from model.""" if not _TORCHVISION_AVAILABLE: # pragma: no cover @@ -27,6 +29,7 @@ def _evaluate_iou(target, pred): return box_iou(target["boxes"], pred["boxes"]).diag().mean() +@experimental() class FasterRCNN(LightningModule): """PyTorch Lightning implementation of `Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks `_. @@ -152,6 +155,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def run_cli(): from pl_bolts.datamodules import VOCDetectionDataModule diff --git a/pl_bolts/models/detection/retinanet/backbones.py b/pl_bolts/models/detection/retinanet/backbones.py index c039ea6ac3..ee761f3710 100644 --- a/pl_bolts/models/detection/retinanet/backbones.py +++ b/pl_bolts/models/detection/retinanet/backbones.py @@ -4,6 +4,7 @@ from pl_bolts.models.detection.components import create_torchvision_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -12,6 +13,7 @@ warn_missing_pkg("torchvision") +@experimental() def create_retinanet_backbone( backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any ) -> nn.Module: diff --git a/pl_bolts/models/detection/retinanet/retinanet_module.py b/pl_bolts/models/detection/retinanet/retinanet_module.py index 6708df1b75..18f0c24188 100644 --- a/pl_bolts/models/detection/retinanet/retinanet_module.py +++ b/pl_bolts/models/detection/retinanet/retinanet_module.py @@ -5,6 +5,7 @@ from pl_bolts.models.detection.retinanet import create_retinanet_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -15,6 +16,7 @@ warn_missing_pkg("torchvision") +@experimental() class RetinaNet(LightningModule): """PyTorch Lightning implementation of RetinaNet. @@ -118,6 +120,7 @@ def configure_optimizers(self): ) +@experimental() def cli_main(): from pytorch_lightning.utilities.cli import LightningCLI diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index ed6f30a734..5fb12ae16b 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -6,8 +6,10 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pl_bolts.models.detection.yolo import yolo_layers +from pl_bolts.utils.stability import experimental +@experimental() class YOLOConfiguration: """This class can be used to parse the configuration files of the Darknet YOLOv4 implementation. @@ -147,6 +149,7 @@ def convert(key, value): return sections +@experimental() def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]: """Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the layer config. @@ -170,6 +173,7 @@ def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]: return create_func[config["type"]](config, num_inputs) +@experimental() def _create_convolutional(config, num_inputs): module = nn.Sequential() @@ -206,12 +210,14 @@ def _create_convolutional(config, num_inputs): return module, config["filters"] +@experimental() def _create_maxpool(config, num_inputs): padding = (config["size"] - 1) // 2 module = nn.MaxPool2d(config["size"], config["stride"], padding) return module, num_inputs[-1] +@experimental() def _create_route(config, num_inputs): num_chunks = config.get("groups", 1) chunk_idx = config.get("group_id", 0) @@ -228,16 +234,19 @@ def _create_route(config, num_inputs): return module, num_outputs +@experimental() def _create_shortcut(config, num_inputs): module = yolo_layers.ShortcutLayer(config["from"]) return module, num_inputs[-1] +@experimental() def _create_upsample(config, num_inputs): module = nn.Upsample(scale_factor=config["stride"], mode="nearest") return module, num_inputs[-1] +@experimental() def _create_yolo(config, num_inputs): # The "anchors" list alternates width and height. anchor_dims = config["anchors"] diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 9b1ee891df..be207d7fb6 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -5,6 +5,7 @@ from torch import Tensor, nn from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -20,6 +21,7 @@ warn_missing_pkg("torchvision") +@experimental() def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor: """Converts box center points and sizes to corner coordinates. @@ -36,6 +38,7 @@ def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor: return torch.cat((top_left, bottom_right), -1) +@experimental() def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: """Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at the same coordinates. @@ -58,6 +61,7 @@ def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: return inter / union +@experimental() class SELoss(nn.MSELoss): def __init__(self): super().__init__(reduction="none") @@ -66,11 +70,13 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return super().forward(inputs, target).sum(1) +@experimental() class IoULoss(nn.Module): def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return 1.0 - box_iou(inputs, target).diagonal() +@experimental() class GIoULoss(nn.Module): def __init__(self) -> None: super().__init__() @@ -83,6 +89,7 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return 1.0 - generalized_box_iou(inputs, target).diagonal() +@experimental() class DetectionLayer(nn.Module): """A YOLO detection layer. @@ -461,6 +468,7 @@ def _calculate_losses( return losses, hits +@experimental() class Mish(nn.Module): """Mish activation.""" @@ -468,6 +476,7 @@ def forward(self, x): return x * torch.tanh(nn.functional.softplus(x)) +@experimental() class RouteLayer(nn.Module): """Route layer concatenates the output (or part of it) from given layers.""" @@ -488,6 +497,7 @@ def forward(self, x, outputs): return torch.cat(chunks, dim=1) +@experimental() class ShortcutLayer(nn.Module): """Shortcut layer adds a residual connection from the source layer.""" diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index ebb494f5ef..3ab036e814 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -11,6 +11,7 @@ from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -22,6 +23,7 @@ log = logging.getLogger(__name__) +@experimental() class YOLO(LightningModule): """PyTorch Lightning implementation of YOLOv3 and YOLOv4. @@ -453,6 +455,7 @@ def _filter_detections(self, detections: Dict[str, Tensor]) -> Dict[str, List[Te return {"boxes": out_boxes, "scores": out_scores, "classprobs": out_classprobs, "labels": out_labels} +@experimental() class Resize: """Rescales the image and target to given dimensions. @@ -483,6 +486,7 @@ def __call__(self, image: Tensor, target: Dict[str, Any]): return image, target +@experimental() def run_cli(): from argparse import ArgumentParser diff --git a/pl_bolts/models/gans/basic/basic_gan_module.py b/pl_bolts/models/gans/basic/basic_gan_module.py index 4a4fa20b25..55a29b0f43 100644 --- a/pl_bolts/models/gans/basic/basic_gan_module.py +++ b/pl_bolts/models/gans/basic/basic_gan_module.py @@ -5,8 +5,10 @@ from torch.nn import functional as F from pl_bolts.models.gans.basic.components import Discriminator, Generator +from pl_bolts.utils.stability import experimental +@experimental() class GAN(LightningModule): """Vanilla GAN implementation. @@ -164,6 +166,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def cli_main(args=None): from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, MNISTDataModule, STL10DataModule diff --git a/pl_bolts/models/gans/basic/components.py b/pl_bolts/models/gans/basic/components.py index 76cd519ad9..f63af4079f 100644 --- a/pl_bolts/models/gans/basic/components.py +++ b/pl_bolts/models/gans/basic/components.py @@ -3,7 +3,10 @@ from torch import nn from torch.nn import functional as F +from pl_bolts.utils.stability import experimental + +@experimental() class Generator(nn.Module): def __init__(self, latent_dim, img_shape, hidden_dim=256): super().__init__() @@ -24,6 +27,7 @@ def forward(self, z): return img +@experimental() class Discriminator(nn.Module): def __init__(self, img_shape, hidden_dim=1024): super().__init__() diff --git a/pl_bolts/models/gans/dcgan/components.py b/pl_bolts/models/gans/dcgan/components.py index 7432e7b69a..3f3facc9e5 100644 --- a/pl_bolts/models/gans/dcgan/components.py +++ b/pl_bolts/models/gans/dcgan/components.py @@ -1,7 +1,10 @@ # Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py from torch import Tensor, nn +from pl_bolts.utils.stability import experimental + +@experimental() class DCGANGenerator(nn.Module): def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None: """ @@ -47,6 +50,7 @@ def forward(self, noise: Tensor) -> Tensor: return self.gen(noise) +@experimental() class DCGANDiscriminator(nn.Module): def __init__(self, feature_maps: int, image_channels: int) -> None: """ diff --git a/pl_bolts/models/gans/dcgan/dcgan_module.py b/pl_bolts/models/gans/dcgan/dcgan_module.py index 242242fb8b..d9a54c2d60 100644 --- a/pl_bolts/models/gans/dcgan/dcgan_module.py +++ b/pl_bolts/models/gans/dcgan/dcgan_module.py @@ -9,6 +9,7 @@ from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -18,6 +19,7 @@ warn_missing_pkg("torchvision") +@experimental() class DCGAN(LightningModule): """DCGAN implementation. @@ -171,6 +173,7 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: return parser +@experimental() def cli_main(args=None): seed_everything(1234) diff --git a/pl_bolts/models/gans/pix2pix/components.py b/pl_bolts/models/gans/pix2pix/components.py index c67cf691c8..d865d06576 100644 --- a/pl_bolts/models/gans/pix2pix/components.py +++ b/pl_bolts/models/gans/pix2pix/components.py @@ -1,7 +1,10 @@ import torch from torch import nn +from pl_bolts.utils.stability import experimental + +@experimental() class UpSampleConv(nn.Module): def __init__( self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True, dropout=False @@ -32,6 +35,7 @@ def forward(self, x): return x +@experimental() class DownSampleConv(nn.Module): def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True): """Paper details: @@ -61,6 +65,7 @@ def forward(self, x): return x +@experimental() class Generator(nn.Module): def __init__(self, in_channels, out_channels): """Paper details: @@ -122,6 +127,7 @@ def forward(self, x): return self.tanh(x) +@experimental() class PatchGAN(nn.Module): def __init__(self, input_channels): super().__init__() diff --git a/pl_bolts/models/gans/pix2pix/pix2pix_module.py b/pl_bolts/models/gans/pix2pix/pix2pix_module.py index 1d00cde1f8..3aeb5b55d4 100644 --- a/pl_bolts/models/gans/pix2pix/pix2pix_module.py +++ b/pl_bolts/models/gans/pix2pix/pix2pix_module.py @@ -3,8 +3,10 @@ from torch import nn from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN +from pl_bolts.utils.stability import experimental +@experimental() def _weights_init(m): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): torch.nn.init.normal_(m.weight, 0.0, 0.02) @@ -13,6 +15,7 @@ def _weights_init(m): torch.nn.init.constant_(m.bias, 0) +@experimental() class Pix2Pix(LightningModule): def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=200): diff --git a/pl_bolts/models/gans/srgan/components.py b/pl_bolts/models/gans/srgan/components.py index 63a06006aa..3ea0fa60f2 100644 --- a/pl_bolts/models/gans/srgan/components.py +++ b/pl_bolts/models/gans/srgan/components.py @@ -3,6 +3,7 @@ import torch.nn as nn from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -11,6 +12,7 @@ warn_missing_pkg("torchvision") +@experimental() class ResidualBlock(nn.Module): def __init__(self, feature_maps: int = 64) -> None: super().__init__() @@ -27,6 +29,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + self.block(x) +@experimental() class SRGANGenerator(nn.Module): def __init__( self, @@ -78,6 +81,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +@experimental() class SRGANDiscriminator(nn.Module): def __init__(self, image_channels: int, feature_maps: int = 64) -> None: super().__init__() @@ -131,6 +135,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +@experimental() class VGG19FeatureExtractor(nn.Module): def __init__(self, image_channels: int = 3) -> None: super().__init__() diff --git a/pl_bolts/models/gans/srgan/srgan_module.py b/pl_bolts/models/gans/srgan/srgan_module.py index 434731c761..3847ddab2a 100644 --- a/pl_bolts/models/gans/srgan/srgan_module.py +++ b/pl_bolts/models/gans/srgan/srgan_module.py @@ -12,8 +12,10 @@ from pl_bolts.datamodules import TVTDataModule from pl_bolts.datasets.utils import prepare_sr_datasets from pl_bolts.models.gans.srgan.components import SRGANDiscriminator, SRGANGenerator, VGG19FeatureExtractor +from pl_bolts.utils.stability import experimental +@experimental() class SRGAN(pl.LightningModule): """SRGAN implementation from the paper `Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network `__. It uses a pretrained SRResNet model as the generator @@ -181,6 +183,7 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: return parser +@experimental() def cli_main(args=None): pl.seed_everything(1234) diff --git a/pl_bolts/models/gans/srgan/srresnet_module.py b/pl_bolts/models/gans/srgan/srresnet_module.py index e7fa02bbb1..49d298c4e4 100644 --- a/pl_bolts/models/gans/srgan/srresnet_module.py +++ b/pl_bolts/models/gans/srgan/srresnet_module.py @@ -10,8 +10,10 @@ from pl_bolts.datamodules import TVTDataModule from pl_bolts.datasets.utils import prepare_sr_datasets from pl_bolts.models.gans.srgan.components import SRGANGenerator +from pl_bolts.utils.stability import experimental +@experimental() class SRResNet(pl.LightningModule): """SRResNet implementation from the paper `Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network `__. A pretrained SRResNet model is used as the generator @@ -108,6 +110,7 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: return parser +@experimental() def cli_main(args=None): pl.seed_everything(1234) diff --git a/pl_bolts/models/mnist_module.py b/pl_bolts/models/mnist_module.py index 0ca65c6d9f..189e4b2309 100644 --- a/pl_bolts/models/mnist_module.py +++ b/pl_bolts/models/mnist_module.py @@ -7,6 +7,7 @@ from pl_bolts.datasets import MNIST from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -15,6 +16,7 @@ warn_missing_pkg("torchvision") +@experimental() class LitMNIST(LightningModule): def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, data_dir="", **kwargs): if not _TORCHVISION_AVAILABLE: # pragma: no cover @@ -88,6 +90,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def cli_main(): # args parser = ArgumentParser() diff --git a/pl_bolts/models/regression/linear_regression.py b/pl_bolts/models/regression/linear_regression.py index 6abf5227af..816cb254b2 100644 --- a/pl_bolts/models/regression/linear_regression.py +++ b/pl_bolts/models/regression/linear_regression.py @@ -8,7 +8,10 @@ from torch.optim import Adam from torch.optim.optimizer import Optimizer +from pl_bolts.utils.stability import experimental + +@experimental() class LinearRegression(LightningModule): """ Linear regression model implementing - with optional L1/L2 regularization @@ -109,6 +112,7 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: return parser +@experimental() def cli_main() -> None: from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule from pl_bolts.utils import _SKLEARN_AVAILABLE diff --git a/pl_bolts/models/regression/logistic_regression.py b/pl_bolts/models/regression/logistic_regression.py index 6796c84eb5..34fe809737 100644 --- a/pl_bolts/models/regression/logistic_regression.py +++ b/pl_bolts/models/regression/logistic_regression.py @@ -10,7 +10,10 @@ from torch.optim.optimizer import Optimizer from torchmetrics.functional import accuracy +from pl_bolts.utils.stability import experimental + +@experimental() class LogisticRegression(LightningModule): """Logistic regression model.""" @@ -115,6 +118,7 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: return parser +@experimental() def cli_main() -> None: from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule from pl_bolts.utils import _SKLEARN_AVAILABLE diff --git a/pl_bolts/models/rl/advantage_actor_critic_model.py b/pl_bolts/models/rl/advantage_actor_critic_model.py index a27aa9ac69..23323d03ee 100644 --- a/pl_bolts/models/rl/advantage_actor_critic_model.py +++ b/pl_bolts/models/rl/advantage_actor_critic_model.py @@ -15,6 +15,7 @@ from pl_bolts.models.rl.common.agents import ActorCriticAgent from pl_bolts.models.rl.common.networks import ActorCriticMLP from pl_bolts.utils import _GYM_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -23,6 +24,7 @@ warn_missing_pkg("gym") +@experimental() class AdvantageActorCritic(LightningModule): """PyTorch Lightning implementation of `Advantage Actor Critic `_. @@ -294,6 +296,7 @@ def add_model_specific_args(arg_parser: ArgumentParser) -> ArgumentParser: return arg_parser +@experimental() def cli_main() -> None: parser = ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/common/agents.py b/pl_bolts/models/rl/common/agents.py index cbefa5d635..c3b551bbc5 100644 --- a/pl_bolts/models/rl/common/agents.py +++ b/pl_bolts/models/rl/common/agents.py @@ -10,7 +10,10 @@ from torch import Tensor, nn from torch.nn import functional as F +from pl_bolts.utils.stability import experimental + +@experimental() class Agent(ABC): """Basic agent that always returns 0.""" @@ -30,6 +33,7 @@ def __call__(self, state: Tensor, device: str, *args, **kwargs) -> List[int]: return [0] +@experimental() class ValueAgent(Agent): """Value based agent that returns an action based on the Q values from the network.""" @@ -105,6 +109,7 @@ def update_epsilon(self, step: int) -> None: self.epsilon = max(self.eps_end, self.eps_start - (step + 1) / self.eps_frames) +@experimental() class PolicyAgent(Agent): """Policy based agent that returns an action based on the networks policy.""" @@ -134,6 +139,7 @@ def __call__(self, states: Tensor, device: str) -> List[int]: return actions +@experimental() class ActorCriticAgent(Agent): """Actor-Critic based agent that returns an action based on the networks policy.""" @@ -163,6 +169,7 @@ def __call__(self, states: Tensor, device: str) -> List[int]: return actions +@experimental() class SoftActorCriticAgent(Agent): """Actor-Critic based agent that returns a continuous action based on the policy.""" diff --git a/pl_bolts/models/rl/common/cli.py b/pl_bolts/models/rl/common/cli.py index 24c8970e0e..e7ab729bdb 100644 --- a/pl_bolts/models/rl/common/cli.py +++ b/pl_bolts/models/rl/common/cli.py @@ -2,7 +2,10 @@ import argparse +from pl_bolts.utils.stability import experimental + +@experimental() def add_base_args(parent) -> argparse.ArgumentParser: """Adds arguments for DQN model. diff --git a/pl_bolts/models/rl/common/distributions.py b/pl_bolts/models/rl/common/distributions.py index 2a7f945ecb..759f7e235d 100644 --- a/pl_bolts/models/rl/common/distributions.py +++ b/pl_bolts/models/rl/common/distributions.py @@ -1,7 +1,10 @@ """Distributions used in some continuous RL algorithms.""" import torch +from pl_bolts.utils.stability import experimental + +@experimental() class TanhMultivariateNormal(torch.distributions.MultivariateNormal): """The distribution of X is an affine of tanh applied on a normal distribution. diff --git a/pl_bolts/models/rl/common/gym_wrappers.py b/pl_bolts/models/rl/common/gym_wrappers.py index f646868a0c..8eb80e701d 100644 --- a/pl_bolts/models/rl/common/gym_wrappers.py +++ b/pl_bolts/models/rl/common/gym_wrappers.py @@ -6,6 +6,7 @@ import torch from pl_bolts.utils import _GYM_AVAILABLE, _OPENCV_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -23,6 +24,7 @@ warn_missing_pkg("cv2", pypi_name="opencv-python") +@experimental() class ToTensor(Wrapper): """For environments where the user need to press FIRE for the game to start.""" @@ -42,6 +44,7 @@ def reset(self): return torch.tensor(self.env.reset()) +@experimental() class FireResetEnv(Wrapper): """For environments where the user need to press FIRE for the game to start.""" @@ -69,6 +72,7 @@ def reset(self): return obs +@experimental() class MaxAndSkipEnv(Wrapper): """Return only every `skip`-th frame.""" @@ -105,6 +109,7 @@ def reset(self): return obs +@experimental() class ProcessFrame84(ObservationWrapper): """preprocessing images from env.""" @@ -135,6 +140,7 @@ def process(frame): return x_t.astype(np.uint8) +@experimental() class ImageToPyTorch(ObservationWrapper): """converts image to pytorch format.""" @@ -153,6 +159,7 @@ def observation(observation): return np.moveaxis(observation, 2, 0) +@experimental() class ScaledFloatFrame(ObservationWrapper): """scales the pixels.""" @@ -161,6 +168,7 @@ def observation(obs): return np.array(obs).astype(np.float32) / 255.0 +@experimental() class BufferWrapper(ObservationWrapper): """Wrapper for image stacking.""" @@ -187,6 +195,7 @@ def observation(self, observation): return self.buffer +@experimental() class DataAugmentation(ObservationWrapper): """Carries out basic data augmentation on the env observations. @@ -207,6 +216,7 @@ def observation(self, obs): return ProcessFrame84.process(obs) +@experimental() def make_environment(env_name): """Convert environment with wrappers.""" env = gym_make(env_name) diff --git a/pl_bolts/models/rl/common/memory.py b/pl_bolts/models/rl/common/memory.py index 8d18cee608..611c13fb65 100644 --- a/pl_bolts/models/rl/common/memory.py +++ b/pl_bolts/models/rl/common/memory.py @@ -7,9 +7,12 @@ import numpy as np +from pl_bolts.utils.stability import experimental + Experience = namedtuple("Experience", field_names=["state", "action", "reward", "done", "new_state"]) +@experimental() class Buffer: """Basic Buffer for storing a single experience at a time.""" @@ -51,6 +54,7 @@ def sample(self, *args) -> Union[Tuple, List[Tuple]]: ) +@experimental() class ReplayBuffer(Buffer): """Replay Buffer for storing past experiences allowing the agent to learn from them.""" @@ -76,6 +80,7 @@ def sample(self, batch_size: int) -> Tuple: ) +@experimental() class MultiStepBuffer(ReplayBuffer): """N Step Replay Buffer.""" @@ -184,6 +189,7 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float: return total_reward +@experimental() class MeanBuffer: """Stores a deque of items and calculates the mean.""" @@ -206,6 +212,7 @@ def mean(self) -> float: return self.sum / len(self.deque) +@experimental() class PERBuffer(ReplayBuffer): """simple list based Prioritized Experience Replay Buffer Based on implementation found here: diff --git a/pl_bolts/models/rl/common/networks.py b/pl_bolts/models/rl/common/networks.py index 476aae54ea..533a34e162 100644 --- a/pl_bolts/models/rl/common/networks.py +++ b/pl_bolts/models/rl/common/networks.py @@ -9,8 +9,10 @@ from torch.nn import functional as F from pl_bolts.models.rl.common.distributions import TanhMultivariateNormal +from pl_bolts.utils.stability import experimental +@experimental() class CNN(nn.Module): """Simple MLP network.""" @@ -57,6 +59,7 @@ def forward(self, input_x) -> Tensor: return self.head(conv_out) +@experimental() class MLP(nn.Module): """Simple MLP network.""" @@ -86,6 +89,7 @@ def forward(self, input_x): return self.net(input_x.float()) +@experimental() class ContinuousMLP(nn.Module): """MLP network that outputs continuous value via Gaussian distribution.""" @@ -144,6 +148,7 @@ def get_action(self, x: FloatTensor) -> Tensor: return self.action_scale * torch.tanh(batch_mean) + self.action_bias +@experimental() class ActorCriticMLP(nn.Module): """MLP network with heads for actor and critic.""" @@ -175,6 +180,7 @@ def forward(self, x) -> Tuple[Tensor, Tensor]: return a, c +@experimental() class DuelingMLP(nn.Module): """MLP network with duel heads for val and advantage.""" @@ -227,6 +233,7 @@ def adv_val(self, input_x) -> Tuple[Tensor, Tensor]: return self.fc_adv(base_out), self.fc_val(base_out) +@experimental() class DuelingCNN(nn.Module): """CNN network with duel heads for val and advantage.""" @@ -295,6 +302,7 @@ def adv_val(self, input_x): return self.head_adv(base_out), self.head_val(base_out) +@experimental() class NoisyCNN(nn.Module): """CNN with Noisy Linear layers for exploration.""" @@ -348,6 +356,7 @@ def forward(self, input_x) -> Tensor: ################### +@experimental() class NoisyLinear(nn.Linear): """Noisy Layer using Independent Gaussian Noise. @@ -404,6 +413,7 @@ def forward(self, input_x: Tensor) -> Tensor: return F.linear(input_x, noisy_weights, bias) +@experimental() class ActorCategorical(nn.Module): """Policy network, for discrete action spaces, which returns a distribution and an action given an observation.""" @@ -437,6 +447,7 @@ def get_log_prob(self, pi: Categorical, actions: Tensor): return pi.log_prob(actions) +@experimental() class ActorContinous(nn.Module): """Policy network, for continous action spaces, which returns a distribution and an action given an observation.""" diff --git a/pl_bolts/models/rl/double_dqn_model.py b/pl_bolts/models/rl/double_dqn_model.py index 267a590445..ad0f53b0bc 100644 --- a/pl_bolts/models/rl/double_dqn_model.py +++ b/pl_bolts/models/rl/double_dqn_model.py @@ -8,8 +8,10 @@ from pl_bolts.losses.rl import double_dqn_loss from pl_bolts.models.rl.dqn_model import DQN +from pl_bolts.utils.stability import experimental +@experimental() class DoubleDQN(DQN): """Double Deep Q-network (DDQN) PyTorch Lightning implementation of `Double DQN`_. @@ -79,6 +81,7 @@ def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict: ) +@experimental() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py index b4a93fae48..ce73ba9ccb 100644 --- a/pl_bolts/models/rl/dqn_model.py +++ b/pl_bolts/models/rl/dqn_model.py @@ -19,6 +19,7 @@ from pl_bolts.models.rl.common.memory import MultiStepBuffer from pl_bolts.models.rl.common.networks import CNN from pl_bolts.utils import _GYM_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -28,6 +29,7 @@ Env = object +@experimental() class DQN(LightningModule): """Basic DQN Model. @@ -408,6 +410,7 @@ def _use_dp_or_ddp2(trainer: Trainer) -> bool: return isinstance(trainer.training_type_plugin, (DataParallelPlugin, DDP2Plugin)) +@experimental() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/dueling_dqn_model.py b/pl_bolts/models/rl/dueling_dqn_model.py index 142990fc4b..018ba8d83f 100644 --- a/pl_bolts/models/rl/dueling_dqn_model.py +++ b/pl_bolts/models/rl/dueling_dqn_model.py @@ -5,8 +5,10 @@ from pl_bolts.models.rl.common.networks import DuelingCNN from pl_bolts.models.rl.dqn_model import DQN +from pl_bolts.utils.stability import experimental +@experimental() class DuelingDQN(DQN): """PyTorch Lightning implementation of `Dueling DQN `_ @@ -36,6 +38,7 @@ def build_networks(self) -> None: self.target_net = DuelingCNN(self.obs_shape, self.n_actions) +@experimental() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/noisy_dqn_model.py b/pl_bolts/models/rl/noisy_dqn_model.py index a7e88eacb0..b4aa5b6aed 100644 --- a/pl_bolts/models/rl/noisy_dqn_model.py +++ b/pl_bolts/models/rl/noisy_dqn_model.py @@ -9,8 +9,10 @@ from pl_bolts.datamodules.experience_source import Experience from pl_bolts.models.rl.common.networks import NoisyCNN from pl_bolts.models.rl.dqn_model import DQN +from pl_bolts.utils.stability import experimental +@experimental() class NoisyDQN(DQN): """PyTorch Lightning implementation of `Noisy DQN `_ @@ -89,6 +91,7 @@ def train_batch( break +@experimental() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/per_dqn_model.py b/pl_bolts/models/rl/per_dqn_model.py index 3e2f823898..34db230eb7 100644 --- a/pl_bolts/models/rl/per_dqn_model.py +++ b/pl_bolts/models/rl/per_dqn_model.py @@ -12,8 +12,10 @@ from pl_bolts.losses.rl import per_dqn_loss from pl_bolts.models.rl.common.memory import Experience, PERBuffer from pl_bolts.models.rl.dqn_model import DQN +from pl_bolts.utils.stability import experimental +@experimental() class PERDQN(DQN): """PyTorch Lightning implementation of `DQN With Prioritized Experience Replay`_. @@ -145,6 +147,7 @@ def _dataloader(self) -> DataLoader: return DataLoader(dataset=self.dataset, batch_size=self.batch_size) +@experimental() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/ppo_model.py b/pl_bolts/models/rl/ppo_model.py index 71585fa681..eb71903a50 100644 --- a/pl_bolts/models/rl/ppo_model.py +++ b/pl_bolts/models/rl/ppo_model.py @@ -10,6 +10,7 @@ from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.models.rl.common.networks import MLP, ActorCategorical, ActorContinous from pl_bolts.utils import _GYM_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -18,6 +19,7 @@ warn_missing_pkg("gym") +@experimental() class PPO(LightningModule): """PyTorch Lightning implementation of `Proximal Policy Optimization. @@ -356,6 +358,7 @@ def add_model_specific_args(parent_parser): # pragma: no cover return parser +@experimental() def cli_main() -> None: parent_parser = argparse.ArgumentParser(add_help=False) parent_parser = Trainer.add_argparse_args(parent_parser) diff --git a/pl_bolts/models/rl/reinforce_model.py b/pl_bolts/models/rl/reinforce_model.py index 3d8d2560b5..c288ad3ee8 100644 --- a/pl_bolts/models/rl/reinforce_model.py +++ b/pl_bolts/models/rl/reinforce_model.py @@ -15,6 +15,7 @@ from pl_bolts.models.rl.common.agents import PolicyAgent from pl_bolts.models.rl.common.networks import MLP from pl_bolts.utils import _GYM_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -23,6 +24,7 @@ warn_missing_pkg("gym") +@experimental() class Reinforce(LightningModule): r"""PyTorch Lightning implementation of REINFORCE_. @@ -302,6 +304,7 @@ def add_model_specific_args(arg_parser) -> argparse.ArgumentParser: return arg_parser +@experimental() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/sac_model.py b/pl_bolts/models/rl/sac_model.py index a65a6ff8ee..c6d65a4123 100644 --- a/pl_bolts/models/rl/sac_model.py +++ b/pl_bolts/models/rl/sac_model.py @@ -16,6 +16,7 @@ from pl_bolts.models.rl.common.memory import MultiStepBuffer from pl_bolts.models.rl.common.networks import MLP, ContinuousMLP from pl_bolts.utils import _GYM_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -25,6 +26,7 @@ Env = object +@experimental() class SAC(LightningModule): def __init__( self, @@ -384,6 +386,7 @@ def add_model_specific_args( return arg_parser +@experimental() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/vanilla_policy_gradient_model.py b/pl_bolts/models/rl/vanilla_policy_gradient_model.py index 4292f793cb..68a01acfa2 100644 --- a/pl_bolts/models/rl/vanilla_policy_gradient_model.py +++ b/pl_bolts/models/rl/vanilla_policy_gradient_model.py @@ -15,6 +15,7 @@ from pl_bolts.models.rl.common.agents import PolicyAgent from pl_bolts.models.rl.common.networks import MLP from pl_bolts.utils import _GYM_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -23,6 +24,7 @@ warn_missing_pkg("gym") +@experimental() class VanillaPolicyGradient(LightningModule): r"""PyTorch Lightning implementation of `Vanilla Policy Gradient`_. @@ -285,6 +287,7 @@ def add_model_specific_args(arg_parser) -> argparse.ArgumentParser: return arg_parser +@experimental() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/self_supervised/amdim/amdim_module.py b/pl_bolts/models/self_supervised/amdim/amdim_module.py index da91d4a67f..39dd05785f 100644 --- a/pl_bolts/models/self_supervised/amdim/amdim_module.py +++ b/pl_bolts/models/self_supervised/amdim/amdim_module.py @@ -11,8 +11,10 @@ from pl_bolts.models.self_supervised.amdim.datasets import AMDIMPretraining from pl_bolts.models.self_supervised.amdim.networks import AMDIMEncoder from pl_bolts.utils.self_supervised import torchvision_ssl_encoder +from pl_bolts.utils.stability import experimental +@experimental() def generate_power_seq(lr, nb): half = int(nb / 2) coefs = [2**pow for pow in range(half, -half - 1, -1)] @@ -59,6 +61,7 @@ def generate_power_seq(lr, nb): } +@experimental() class AMDIM(LightningModule): """PyTorch Lightning implementation of Augmented Multiscale Deep InfoMax (AMDIM_) @@ -318,6 +321,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def cli_main(): parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) diff --git a/pl_bolts/models/self_supervised/amdim/datasets.py b/pl_bolts/models/self_supervised/amdim/datasets.py index 1922b07630..9fe8c4b0a5 100644 --- a/pl_bolts/models/self_supervised/amdim/datasets.py +++ b/pl_bolts/models/self_supervised/amdim/datasets.py @@ -5,6 +5,7 @@ from pl_bolts.datasets import CIFAR10Mixed, UnlabeledImagenet from pl_bolts.models.self_supervised.amdim import transforms as amdim_transforms from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -13,6 +14,7 @@ warn_missing_pkg("torchvision") +@experimental() class AMDIMPretraining: """For pretraining we use the train transform for both train and val.""" @@ -74,6 +76,7 @@ def get_dataset(datamodule: str, data_dir, split: str = "train", **kwargs): return datasets[datamodule](dataset_root=data_dir, split=split, **kwargs) +@experimental() class AMDIMPatchesPretraining: """For pretraining we use the train transform for both train and val.""" diff --git a/pl_bolts/models/self_supervised/amdim/networks.py b/pl_bolts/models/self_supervised/amdim/networks.py index 2b2cfd8f85..70a6109019 100644 --- a/pl_bolts/models/self_supervised/amdim/networks.py +++ b/pl_bolts/models/self_supervised/amdim/networks.py @@ -5,7 +5,10 @@ from torch import nn from torch.nn import functional as F +from pl_bolts.utils.stability import experimental + +@experimental() class AMDIMEncoder(nn.Module): def __init__( self, @@ -146,6 +149,7 @@ def forward(self, x): return r1, r5, r7 +@experimental() class Conv3x3(nn.Module): def __init__(self, n_in, n_out, n_kern, n_stride, n_pad, use_bn=True, pad_mode="constant"): super().__init__() @@ -169,6 +173,7 @@ def forward(self, x): return out +@experimental() class ConvResBlock(nn.Module): def __init__(self, n_in, n_out, width, stride, pad, depth, use_bn): super().__init__() @@ -188,6 +193,7 @@ def forward(self, x): return x_out +@experimental() class ConvResNxN(nn.Module): def __init__(self, n_in, n_out, width, stride, pad, use_bn=False): super().__init__() @@ -232,6 +238,7 @@ def forward(self, x): return h23 +@experimental() class MaybeBatchNorm2d(nn.Module): def __init__(self, n_ftr, affine, use_bn): super().__init__() @@ -244,6 +251,7 @@ def forward(self, x): return x +@experimental() class NopNet(nn.Module): def __init__(self, norm_dim=None): super().__init__() @@ -257,6 +265,7 @@ def forward(self, x): return x +@experimental() class FakeRKHSConvNet(nn.Module): def __init__(self, n_input, n_output, use_bn=False): super().__init__() diff --git a/pl_bolts/models/self_supervised/amdim/transforms.py b/pl_bolts/models/self_supervised/amdim/transforms.py index 25b985e2ca..a31c67c606 100644 --- a/pl_bolts/models/self_supervised/amdim/transforms.py +++ b/pl_bolts/models/self_supervised/amdim/transforms.py @@ -1,5 +1,6 @@ from pl_bolts.transforms.self_supervised import RandomTranslateWithReflect from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -8,6 +9,7 @@ warn_missing_pkg("torchvision") +@experimental() class AMDIMTrainTransformsCIFAR10: """Transforms applied to AMDIM. @@ -52,6 +54,7 @@ def __call__(self, inp): return out1, out2 +@experimental() class AMDIMEvalTransformsCIFAR10: """Transforms applied to AMDIM. @@ -88,6 +91,7 @@ def __call__(self, inp): return out1 +@experimental() class AMDIMTrainTransformsSTL10: """Transforms applied to AMDIM. @@ -128,6 +132,7 @@ def __call__(self, inp): return out1, out2 +@experimental() class AMDIMEvalTransformsSTL10: """Transforms applied to AMDIM. @@ -170,6 +175,7 @@ def __call__(self, inp): return out1 +@experimental() class AMDIMTrainTransformsImageNet128: """Transforms applied to AMDIM. @@ -213,6 +219,7 @@ def __call__(self, inp): return out1, out2 +@experimental() class AMDIMEvalTransformsImageNet128: """Transforms applied to AMDIM. diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index 2d0d3c1fb0..8e75081a03 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -10,8 +10,10 @@ from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate from pl_bolts.models.self_supervised.byol.models import SiameseArm from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR +from pl_bolts.utils.stability import experimental +@experimental() class BYOL(LightningModule): """PyTorch Lightning implementation of Bootstrap Your Own Latent (BYOL_)_ @@ -176,6 +178,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/byol/models.py b/pl_bolts/models/self_supervised/byol/models.py index 6370d55113..7bbec2b0b0 100644 --- a/pl_bolts/models/self_supervised/byol/models.py +++ b/pl_bolts/models/self_supervised/byol/models.py @@ -1,8 +1,10 @@ from torch import nn from pl_bolts.utils.self_supervised import torchvision_ssl_encoder +from pl_bolts.utils.stability import experimental +@experimental() class MLP(nn.Module): def __init__(self, input_dim=2048, hidden_size=4096, output_dim=256): super().__init__() @@ -20,6 +22,7 @@ def forward(self, x): return x +@experimental() class SiameseArm(nn.Module): def __init__(self, encoder="resnet50", encoder_out_dim=2048, projector_hidden_size=4096, projector_out_dim=256): super().__init__() diff --git a/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py b/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py index e1ba0c640a..d5de41537c 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py @@ -10,8 +10,10 @@ CPCTrainTransformsCIFAR10, CPCTrainTransformsSTL10, ) +from pl_bolts.utils.stability import experimental +@experimental() def cli_main(): # pragma: no cover from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/cpc/cpc_module.py b/pl_bolts/models/self_supervised/cpc/cpc_module.py index d325fc55cf..4c3759a572 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -24,10 +24,12 @@ ) from pl_bolts.utils.pretrained_weights import load_pretrained from pl_bolts.utils.self_supervised import torchvision_ssl_encoder +from pl_bolts.utils.stability import experimental __all__ = ["CPC_v2"] +@experimental() class CPC_v2(LightningModule): def __init__( self, @@ -202,6 +204,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.datamodules import CIFAR10DataModule diff --git a/pl_bolts/models/self_supervised/cpc/networks.py b/pl_bolts/models/self_supervised/cpc/networks.py index 116d1d452b..283b701d56 100644 --- a/pl_bolts/models/self_supervised/cpc/networks.py +++ b/pl_bolts/models/self_supervised/cpc/networks.py @@ -1,7 +1,10 @@ from torch import nn from torch.nn import functional as F +from pl_bolts.utils.stability import experimental + +@experimental() class CPCResNet(nn.Module): def __init__( self, @@ -128,14 +131,17 @@ def forward(self, x): return x +@experimental() def cpc_resnet101(sample_batch, **kwargs): return CPCResNet(sample_batch, LNBottleneck, [3, 4, 46, 3], **kwargs) +@experimental() def cpc_resnet50(sample_batch, **kwargs): return CPCResNet(sample_batch, LNBottleneck, [3, 4, 6, 3], **kwargs) +@experimental() class LNBottleneck(nn.Module): def __init__( self, @@ -201,6 +207,7 @@ def forward(self, x): return out +@experimental() def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding.""" return nn.Conv2d( @@ -215,6 +222,7 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): ) +@experimental() def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution.""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) diff --git a/pl_bolts/models/self_supervised/cpc/transforms.py b/pl_bolts/models/self_supervised/cpc/transforms.py index 107db25393..70e930a787 100644 --- a/pl_bolts/models/self_supervised/cpc/transforms.py +++ b/pl_bolts/models/self_supervised/cpc/transforms.py @@ -1,5 +1,6 @@ from pl_bolts.transforms.self_supervised import Patchify, RandomTranslateWithReflect from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -8,6 +9,7 @@ warn_missing_pkg("torchvision") +@experimental() class CPCTrainTransformsCIFAR10: """Transforms used for CPC: @@ -69,6 +71,7 @@ def __call__(self, inp): return out1 +@experimental() class CPCEvalTransformsCIFAR10: """Transforms used for CPC: @@ -121,6 +124,7 @@ def __call__(self, inp): return out1 +@experimental() class CPCTrainTransformsSTL10: """Transforms used for CPC: @@ -181,6 +185,7 @@ def __call__(self, inp): return out1 +@experimental() class CPCEvalTransformsSTL10: """Transforms used for CPC: @@ -231,6 +236,7 @@ def __call__(self, inp): return out1 +@experimental() class CPCTrainTransformsImageNet128: """Transforms used for CPC: @@ -284,6 +290,7 @@ def __call__(self, inp): return out1 +@experimental() class CPCEvalTransformsImageNet128: """Transforms used for CPC: diff --git a/pl_bolts/models/self_supervised/evaluator.py b/pl_bolts/models/self_supervised/evaluator.py index 8dd79a1ef2..7fe737b5b7 100644 --- a/pl_bolts/models/self_supervised/evaluator.py +++ b/pl_bolts/models/self_supervised/evaluator.py @@ -1,6 +1,9 @@ from torch import nn +from pl_bolts.utils.stability import experimental + +@experimental() class SSLEvaluator(nn.Module): def __init__(self, n_input, n_classes, n_hidden=512, p=0.1): super().__init__() @@ -27,6 +30,7 @@ def forward(self, x): return logits +@experimental() class Flatten(nn.Module): def __init__(self): super().__init__() diff --git a/pl_bolts/models/self_supervised/moco/callbacks.py b/pl_bolts/models/self_supervised/moco/callbacks.py index 992c1a2361..0b799c7709 100644 --- a/pl_bolts/models/self_supervised/moco/callbacks.py +++ b/pl_bolts/models/self_supervised/moco/callbacks.py @@ -2,7 +2,10 @@ from pytorch_lightning import Callback +from pl_bolts.utils.stability import experimental + +@experimental() class MocoLRScheduler(Callback): def __init__(self, initial_lr=0.03, use_cosine_scheduler=False, schedule=(120, 160), max_epochs=200): super().__init__() diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index 5ab20d0d0f..92eaf4660c 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -27,6 +27,7 @@ Moco2TrainSTL10Transforms, ) from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -35,6 +36,7 @@ warn_missing_pkg("torchvision") +@experimental() class Moco_v2(LightningModule): """PyTorch Lightning implementation of `Moco `_ @@ -340,6 +342,7 @@ def _use_ddp_or_ddp2(trainer: Trainer) -> bool: # utils +@experimental() @torch.no_grad() def concat_all_gather(tensor): """Performs all_gather operation on the provided tensors. @@ -353,6 +356,7 @@ def concat_all_gather(tensor): return output +@experimental() def cli_main(): from pl_bolts.datamodules import CIFAR10DataModule, SSLImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/moco/transforms.py b/pl_bolts/models/self_supervised/moco/transforms.py index 60ffb5eb33..a7aa13a26b 100644 --- a/pl_bolts/models/self_supervised/moco/transforms.py +++ b/pl_bolts/models/self_supervised/moco/transforms.py @@ -6,6 +6,7 @@ stl10_normalization, ) from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -19,6 +20,7 @@ warn_missing_pkg("PIL", pypi_name="Pillow") +@experimental() class Moco2TrainCIFAR10Transforms: """Moco 2 augmentation: @@ -48,6 +50,7 @@ def __call__(self, inp): return q, k +@experimental() class Moco2EvalCIFAR10Transforms: """Moco 2 augmentation: @@ -73,6 +76,7 @@ def __call__(self, inp): return q, k +@experimental() class Moco2TrainSTL10Transforms: """Moco 2 augmentation: @@ -102,6 +106,7 @@ def __call__(self, inp): return q, k +@experimental() class Moco2EvalSTL10Transforms: """Moco 2 augmentation: @@ -127,6 +132,7 @@ def __call__(self, inp): return q, k +@experimental() class Moco2TrainImagenetTransforms: """Moco 2 augmentation: @@ -156,6 +162,7 @@ def __call__(self, inp): return q, k +@experimental() class Moco2EvalImagenetTransforms: """Moco 2 augmentation: @@ -181,6 +188,7 @@ def __call__(self, inp): return q, k +@experimental() class GaussianBlur: """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709.""" diff --git a/pl_bolts/models/self_supervised/resnets.py b/pl_bolts/models/self_supervised/resnets.py index 912f5b983c..8da862c84f 100644 --- a/pl_bolts/models/self_supervised/resnets.py +++ b/pl_bolts/models/self_supervised/resnets.py @@ -2,6 +2,8 @@ from torch import nn from torch.utils.model_zoo import load_url as load_state_dict_from_url +from pl_bolts.utils.stability import experimental + __all__ = [ "ResNet", "resnet18", @@ -28,6 +30,7 @@ } +@experimental() def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding.""" return nn.Conv2d( @@ -42,11 +45,13 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): ) +@experimental() def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution.""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) +@experimental() class BasicBlock(nn.Module): expansion = 1 @@ -88,6 +93,7 @@ def forward(self, x): return out +@experimental() class Bottleneck(nn.Module): expansion = 4 @@ -132,6 +138,7 @@ def forward(self, x): return out +@experimental() class ResNet(nn.Module): def __init__( self, @@ -269,6 +276,7 @@ def forward(self, x): return [x0] +@experimental() def _resnet(arch, block, layers, pretrained, progress, **kwargs): model = ResNet(block, layers, **kwargs) if pretrained: @@ -277,6 +285,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs): return model +@experimental() def resnet18(pretrained: bool = False, progress: bool = True, **kwargs): r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ @@ -288,6 +297,7 @@ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) +@experimental() def resnet34(pretrained=False, progress=True, **kwargs): r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_ @@ -299,6 +309,7 @@ def resnet34(pretrained=False, progress=True, **kwargs): return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) +@experimental() def resnet50(pretrained: bool = False, progress: bool = True, **kwargs): r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ @@ -310,6 +321,7 @@ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) +@experimental() def resnet101(pretrained: bool = False, progress: bool = True, **kwargs): r"""ResNet-101 model from `"Deep Residual Learning for Image Recognition" `_ @@ -321,6 +333,7 @@ def resnet101(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) +@experimental() def resnet152(pretrained: bool = False, progress: bool = True, **kwargs): r"""ResNet-152 model from `"Deep Residual Learning for Image Recognition" `_ @@ -332,6 +345,7 @@ def resnet152(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) +@experimental() def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs): r"""ResNeXt-50 32x4d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ @@ -345,6 +359,7 @@ def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) +@experimental() def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs): r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ @@ -358,6 +373,7 @@ def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) +@experimental() def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs): r"""Wide ResNet-50-2 model from `"Wide Residual Networks" `_ @@ -374,6 +390,7 @@ def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) +@experimental() def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs): r"""Wide ResNet-101-2 model from `"Wide Residual Networks" `_ diff --git a/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py b/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py index 2be64712b8..e5e50de55d 100644 --- a/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py +++ b/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py @@ -11,8 +11,10 @@ imagenet_normalization, stl10_normalization, ) +from pl_bolts.utils.stability import experimental +@experimental() def cli_main(): # pragma: no cover from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/simclr/simclr_module.py b/pl_bolts/models/self_supervised/simclr/simclr_module.py index 7171aa885f..1cb8308cc5 100644 --- a/pl_bolts/models/self_supervised/simclr/simclr_module.py +++ b/pl_bolts/models/self_supervised/simclr/simclr_module.py @@ -15,8 +15,10 @@ imagenet_normalization, stl10_normalization, ) +from pl_bolts.utils.stability import experimental +@experimental() class SyncFunction(torch.autograd.Function): @staticmethod def forward(ctx, tensor): @@ -39,6 +41,7 @@ def backward(ctx, grad_output): return grad_input[idx_from:idx_to] +@experimental() class Projection(nn.Module): def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128): super().__init__() @@ -58,6 +61,7 @@ def forward(self, x): return F.normalize(x, dim=1) +@experimental() class SimCLR(LightningModule): def __init__( self, @@ -300,6 +304,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/simclr/transforms.py b/pl_bolts/models/self_supervised/simclr/transforms.py index 64d8f6d1f9..c74ad6fc82 100644 --- a/pl_bolts/models/self_supervised/simclr/transforms.py +++ b/pl_bolts/models/self_supervised/simclr/transforms.py @@ -1,6 +1,7 @@ import numpy as np from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -14,6 +15,7 @@ warn_missing_pkg("cv2", pypi_name="opencv-python") +@experimental() class SimCLRTrainDataTransform: """Transforms for SimCLR. @@ -91,6 +93,7 @@ def __call__(self, sample): return xi, xj, self.online_transform(sample) +@experimental() class SimCLREvalDataTransform(SimCLRTrainDataTransform): """Transforms for SimCLR. @@ -126,6 +129,7 @@ def __init__( ) +@experimental() class SimCLRFinetuneTransform: def __init__( self, input_height: int = 224, jitter_strength: float = 1.0, normalize=None, eval_transform: bool = False @@ -167,6 +171,7 @@ def __call__(self, sample): return self.transform(sample) +@experimental() class GaussianBlur: # Implements Gaussian blur as described in the SimCLR paper def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0): diff --git a/pl_bolts/models/self_supervised/simsiam/models.py b/pl_bolts/models/self_supervised/simsiam/models.py index bc85724949..4ea855e952 100644 --- a/pl_bolts/models/self_supervised/simsiam/models.py +++ b/pl_bolts/models/self_supervised/simsiam/models.py @@ -3,8 +3,10 @@ from torch import Tensor, nn from pl_bolts.utils.self_supervised import torchvision_ssl_encoder +from pl_bolts.utils.stability import experimental +@experimental() class MLP(nn.Module): def __init__(self, input_dim: int = 2048, hidden_size: int = 4096, output_dim: int = 256) -> None: super().__init__() @@ -22,6 +24,7 @@ def forward(self, x: Tensor) -> Tensor: return x +@experimental() class SiameseArm(nn.Module): def __init__( self, diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index cf390b3b49..046615e11e 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -14,8 +14,10 @@ imagenet_normalization, stl10_normalization, ) +from pl_bolts.utils.stability import experimental +@experimental() class SimSiam(LightningModule): """PyTorch Lightning implementation of Exploring Simple Siamese Representation Learning (SimSiam_) @@ -262,6 +264,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/ssl_finetuner.py b/pl_bolts/models/self_supervised/ssl_finetuner.py index 49919c54cb..4a2d9300a5 100644 --- a/pl_bolts/models/self_supervised/ssl_finetuner.py +++ b/pl_bolts/models/self_supervised/ssl_finetuner.py @@ -6,8 +6,10 @@ from torchmetrics import Accuracy from pl_bolts.models.self_supervised import SSLEvaluator +from pl_bolts.utils.stability import experimental +@experimental() class SSLFineTuner(LightningModule): """Finetunes a self-supervised learning backbone using the standard evaluation protocol of a singler layer MLP with 1024 units. diff --git a/pl_bolts/models/self_supervised/swav/swav_finetuner.py b/pl_bolts/models/self_supervised/swav/swav_finetuner.py index 70e0570d84..5fc66da6e6 100644 --- a/pl_bolts/models/self_supervised/swav/swav_finetuner.py +++ b/pl_bolts/models/self_supervised/swav/swav_finetuner.py @@ -7,8 +7,10 @@ from pl_bolts.models.self_supervised.swav.swav_module import SwAV from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform from pl_bolts.transforms.dataset_normalizations import imagenet_normalization, stl10_normalization +from pl_bolts.utils.stability import experimental +@experimental() def cli_main(): # pragma: no cover from pl_bolts.datamodules import ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/swav/swav_module.py b/pl_bolts/models/self_supervised/swav/swav_module.py index 0b91308786..56d5c173d0 100644 --- a/pl_bolts/models/self_supervised/swav/swav_module.py +++ b/pl_bolts/models/self_supervised/swav/swav_module.py @@ -17,8 +17,10 @@ imagenet_normalization, stl10_normalization, ) +from pl_bolts.utils.stability import experimental +@experimental() class SwAV(LightningModule): def __init__( self, @@ -444,6 +446,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/swav/swav_resnet.py b/pl_bolts/models/self_supervised/swav/swav_resnet.py index fb24651a11..de9ae477b2 100644 --- a/pl_bolts/models/self_supervised/swav/swav_resnet.py +++ b/pl_bolts/models/self_supervised/swav/swav_resnet.py @@ -2,7 +2,10 @@ import torch from torch import nn +from pl_bolts.utils.stability import experimental + +@experimental() def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding.""" return nn.Conv2d( @@ -17,11 +20,13 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): ) +@experimental() def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution.""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) +@experimental() class BasicBlock(nn.Module): expansion = 1 __constants__ = ["downsample"] @@ -72,6 +77,7 @@ def forward(self, x): return out +@experimental() class Bottleneck(nn.Module): expansion = 4 __constants__ = ["downsample"] @@ -125,6 +131,7 @@ def forward(self, x): return out +@experimental() class ResNet(nn.Module): def __init__( self, @@ -336,6 +343,7 @@ def forward(self, inputs): return self.forward_head(output) +@experimental() class MultiPrototypes(nn.Module): def __init__(self, output_dim, nmb_prototypes): super().__init__() @@ -350,21 +358,26 @@ def forward(self, x): return out +@experimental() def resnet18(**kwargs): return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) +@experimental() def resnet50(**kwargs): return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) +@experimental() def resnet50w2(**kwargs): return ResNet(Bottleneck, [3, 4, 6, 3], widen=2, **kwargs) +@experimental() def resnet50w4(**kwargs): return ResNet(Bottleneck, [3, 4, 6, 3], widen=4, **kwargs) +@experimental() def resnet50w5(**kwargs): return ResNet(Bottleneck, [3, 4, 6, 3], widen=5, **kwargs) diff --git a/pl_bolts/models/self_supervised/swav/transforms.py b/pl_bolts/models/self_supervised/swav/transforms.py index ac963e645c..c5c211f395 100644 --- a/pl_bolts/models/self_supervised/swav/transforms.py +++ b/pl_bolts/models/self_supervised/swav/transforms.py @@ -3,6 +3,7 @@ import numpy as np from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -16,6 +17,7 @@ warn_missing_pkg("cv2", pypi_name="opencv-python") +@experimental() class SwAVTrainDataTransform: def __init__( self, @@ -98,6 +100,7 @@ def __call__(self, sample): return multi_crops +@experimental() class SwAVEvalDataTransform(SwAVTrainDataTransform): def __init__( self, @@ -132,6 +135,7 @@ def __init__( self.transform[-1] = test_transform +@experimental() class SwAVFinetuneTransform: def __init__( self, input_height: int = 224, jitter_strength: float = 1.0, normalize=None, eval_transform: bool = False @@ -173,6 +177,7 @@ def __call__(self, sample): return self.transform(sample) +@experimental() class GaussianBlur: # Implements Gaussian blur as described in the SimCLR paper def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0): diff --git a/pl_bolts/models/vision/image_gpt/gpt2.py b/pl_bolts/models/vision/image_gpt/gpt2.py index c9c05e02f0..4e2709bd47 100644 --- a/pl_bolts/models/vision/image_gpt/gpt2.py +++ b/pl_bolts/models/vision/image_gpt/gpt2.py @@ -2,7 +2,10 @@ from pytorch_lightning import LightningModule from torch import nn +from pl_bolts.utils.stability import experimental + +@experimental() class Block(nn.Module): def __init__(self, embed_dim, heads): super().__init__() @@ -27,6 +30,7 @@ def forward(self, x): return x +@experimental() class GPT2(LightningModule): """GPT-2 from `language Models are Unsupervised Multitask Learners `_ diff --git a/pl_bolts/models/vision/image_gpt/igpt_module.py b/pl_bolts/models/vision/image_gpt/igpt_module.py index 471121fdeb..8ae07bfff9 100644 --- a/pl_bolts/models/vision/image_gpt/igpt_module.py +++ b/pl_bolts/models/vision/image_gpt/igpt_module.py @@ -6,8 +6,10 @@ from torch import nn from pl_bolts.models.vision.image_gpt.gpt2 import GPT2 +from pl_bolts.utils.stability import experimental +@experimental() def _shape_input(x): """shape batch of images for input into GPT2 model.""" x = x.view(x.shape[0], -1) # flatten images into sequences @@ -15,6 +17,7 @@ def _shape_input(x): return x +@experimental() class ImageGPT(LightningModule): """ **Paper**: `Generative Pretraining from Pixels @@ -238,6 +241,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def cli_main(): from pl_bolts.datamodules import FashionMNISTDataModule, ImagenetDataModule diff --git a/pl_bolts/models/vision/pixel_cnn.py b/pl_bolts/models/vision/pixel_cnn.py index 7144023fe5..a93059f009 100644 --- a/pl_bolts/models/vision/pixel_cnn.py +++ b/pl_bolts/models/vision/pixel_cnn.py @@ -7,7 +7,10 @@ from torch import nn from torch.nn import functional as F +from pl_bolts.utils.stability import experimental + +@experimental() class PixelCNN(nn.Module): """Implementation of `Pixel CNN `_. diff --git a/pl_bolts/models/vision/segmentation.py b/pl_bolts/models/vision/segmentation.py index 9860280b18..1bd3be8e43 100644 --- a/pl_bolts/models/vision/segmentation.py +++ b/pl_bolts/models/vision/segmentation.py @@ -5,8 +5,10 @@ from torch.nn import functional as F from pl_bolts.models.vision.unet import UNet +from pl_bolts.utils.stability import experimental +@experimental() class SemSegment(LightningModule): def __init__( self, @@ -90,6 +92,7 @@ def add_model_specific_args(parent_parser): return parser +@experimental() def cli_main(): from pl_bolts.datamodules import KittiDataModule diff --git a/pl_bolts/models/vision/unet.py b/pl_bolts/models/vision/unet.py index 6a1b1556d1..5619c8bb3c 100644 --- a/pl_bolts/models/vision/unet.py +++ b/pl_bolts/models/vision/unet.py @@ -2,7 +2,10 @@ from torch import nn from torch.nn import functional as F +from pl_bolts.utils.stability import experimental + +@experimental() class UNet(nn.Module): """ Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation @@ -64,6 +67,7 @@ def forward(self, x): return self.layers[-1](xi[-1]) +@experimental() class DoubleConv(nn.Module): """[ Conv2d => BatchNorm (optional) => ReLU ] x 2.""" @@ -82,6 +86,7 @@ def forward(self, x): return self.net(x) +@experimental() class Down(nn.Module): """Downscale with MaxPool => DoubleConvolution block.""" @@ -93,6 +98,7 @@ def forward(self, x): return self.net(x) +@experimental() class Up(nn.Module): """Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature map from contracting path, followed by DoubleConv.""" diff --git a/pl_bolts/optimizers/lars.py b/pl_bolts/optimizers/lars.py index 2a4eb51be2..de03c6dcd6 100644 --- a/pl_bolts/optimizers/lars.py +++ b/pl_bolts/optimizers/lars.py @@ -6,7 +6,10 @@ import torch from torch.optim.optimizer import Optimizer, required +from pl_bolts.utils.stability import experimental + +@experimental() class LARS(Optimizer): """Extends SGD in PyTorch with LARS scaling from the paper `Large batch training of Convolutional Networks `_. diff --git a/pl_bolts/optimizers/lr_scheduler.py b/pl_bolts/optimizers/lr_scheduler.py index 44810f99d0..5e01bafcd3 100644 --- a/pl_bolts/optimizers/lr_scheduler.py +++ b/pl_bolts/optimizers/lr_scheduler.py @@ -6,7 +6,10 @@ from torch.optim import Adam, Optimizer from torch.optim.lr_scheduler import _LRScheduler +from pl_bolts.utils.stability import experimental + +@experimental() class LinearWarmupCosineAnnealingLR(_LRScheduler): """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr and base_lr followed by a cosine annealing schedule between base_lr and eta_min. @@ -121,6 +124,7 @@ def _get_closed_form_lr(self) -> List[float]: # warmup + decay as a function +@experimental() def linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False): """Linear warmup for warmup_steps, optionally with cosine annealing or linear decay to 0 at total_steps.""" assert not (linear and cosine) diff --git a/pl_bolts/transforms/dataset_normalizations.py b/pl_bolts/transforms/dataset_normalizations.py index d742b07444..3b1a859f1e 100644 --- a/pl_bolts/transforms/dataset_normalizations.py +++ b/pl_bolts/transforms/dataset_normalizations.py @@ -1,4 +1,5 @@ from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -7,6 +8,7 @@ warn_missing_pkg("torchvision") +@experimental() def imagenet_normalization(): if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( @@ -17,6 +19,7 @@ def imagenet_normalization(): return normalize +@experimental() def cifar10_normalization(): if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( @@ -30,6 +33,7 @@ def cifar10_normalization(): return normalize +@experimental() def stl10_normalization(): if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( @@ -40,6 +44,7 @@ def stl10_normalization(): return normalize +@experimental() def emnist_normalization(split: str): if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( diff --git a/pl_bolts/transforms/self_supervised/ssl_transforms.py b/pl_bolts/transforms/self_supervised/ssl_transforms.py index 4397b60c09..492a41a53b 100644 --- a/pl_bolts/transforms/self_supervised/ssl_transforms.py +++ b/pl_bolts/transforms/self_supervised/ssl_transforms.py @@ -2,6 +2,7 @@ from torch.nn import functional as F from pl_bolts.utils import _PIL_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -10,6 +11,7 @@ warn_missing_pkg("PIL", pypi_name="Pillow") +@experimental() class RandomTranslateWithReflect: """Translate image randomly Translate vertically and horizontally by n pixels where n is integer drawn uniformly independently for each axis from [-max_translation, max_translation]. @@ -53,6 +55,7 @@ def __call__(self, old_image): return new_image +@experimental() class Patchify: def __init__(self, patch_size, overlap_size): self.patch_size = patch_size From 61143d43822519b65c1e20f3c8f509533220d47f Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 19 Jul 2022 16:50:04 +0200 Subject: [PATCH 03/11] intro message --- docs/source/introduction_guide.rst | 2 ++ pl_bolts/models/self_supervised/moco/moco2_module.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/introduction_guide.rst b/docs/source/introduction_guide.rst index df31f17173..1a2d462468 100644 --- a/docs/source/introduction_guide.rst +++ b/docs/source/introduction_guide.rst @@ -12,6 +12,8 @@ Bolts is a Deep learning research and production toolbox of: **The Main goal of Bolts is to enable trying new ideas as fast as possible!** +.. note:: Currently, Bolts is going through a major revision. For more information about it, see `this GitHub issue `_ and `stability section `_ + All models are tested (daily), benchmarked, documented and work on CPUs, TPUs, GPUs and 16-bit precision. **some examples!** diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index 92eaf4660c..0b35b38027 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -342,8 +342,8 @@ def _use_ddp_or_ddp2(trainer: Trainer) -> bool: # utils -@experimental() @torch.no_grad() +@experimental() def concat_all_gather(tensor): """Performs all_gather operation on the provided tensors. From e0739f7f0ad3258c73c4dbba651069ad68b1d4d9 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 19 Jul 2022 17:15:25 +0200 Subject: [PATCH 04/11] make doctests happier --- pl_bolts/utils/stability.py | 38 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/pl_bolts/utils/stability.py b/pl_bolts/utils/stability.py index e20c82e1ea..3df3bbe359 100644 --- a/pl_bolts/utils/stability.py +++ b/pl_bolts/utils/stability.py @@ -39,26 +39,24 @@ def experimental( message: The message to include in the warning. Examples ________ - .. testsetup:: - >>> import pytest - .. doctest:: - >>> from pl_bolts.utils.stability import experimental - >>> @experimental() - ... class MyExperimentalFeature: - ... pass - ... - >>> with pytest.warns(UserWarning, match="This feature is currently marked as experimental."): - ... MyExperimentalFeature() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - ... - <...> - >>> @experimental("This feature is currently marked as experimental with a message.") - ... class MyExperimentalFeatureWithCustomMessage: - ... pass - ... - >>> with pytest.warns(UserWarning, match="This feature is currently marked as experimental with a message."): - ... MyExperimentalFeatureWithCustomMessage() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - ... - <...> + >>> import pytest + >>> from pl_bolts.utils.stability import experimental + >>> @experimental() + ... class MyExperimentalFeature: + ... pass + ... + >>> with pytest.warns(UserWarning, match="This feature is currently marked as experimental."): + ... MyExperimentalFeature() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ... + <...> + >>> @experimental("This feature is currently marked as experimental with a message.") + ... class MyExperimentalFeatureWithCustomMessage: + ... pass + ... + >>> with pytest.warns(UserWarning, match="This feature is currently marked as experimental with a message."): + ... MyExperimentalFeatureWithCustomMessage() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ... + <...> """ def decorator(callable: Union[Callable, Type]): From 9ddf0365b6629e8c2208ad4f01939b3bbdb3185d Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 19 Jul 2022 17:17:07 +0200 Subject: [PATCH 05/11] make mypy happier --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 08cf4b3a9e..5cda411322 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,5 +53,6 @@ module = [ "pl_bolts.models.vision.*", "pl_bolts.optimizers.*", "pl_bolts.transforms.*", + "pl_bolts.utils.stability" ] ignore_errors = "True" From 4be7eb1016643b45b48de63db3de6357773c1edd Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 19 Jul 2022 17:21:04 +0200 Subject: [PATCH 06/11] experimental utils too --- pl_bolts/utils/arguments.py | 5 +++++ pl_bolts/utils/pretrained_weights.py | 3 +++ pl_bolts/utils/self_supervised.py | 2 ++ pl_bolts/utils/semi_supervised.py | 4 ++++ pl_bolts/utils/shaping.py | 3 +++ pl_bolts/utils/warnings.py | 3 +++ 6 files changed, 20 insertions(+) diff --git a/pl_bolts/utils/arguments.py b/pl_bolts/utils/arguments.py index 0586c575de..c4b03aa16a 100644 --- a/pl_bolts/utils/arguments.py +++ b/pl_bolts/utils/arguments.py @@ -5,8 +5,11 @@ from pytorch_lightning import LightningDataModule, LightningModule +from pl_bolts.utils.stability import experimental + @dataclass(frozen=True) +@experimental() class LitArg: """Dataclass to represent init args of an object.""" @@ -17,6 +20,7 @@ class LitArg: context: Optional[str] = None +@experimental() class LightningArgumentParser(ArgumentParser): """Extension of argparse.ArgumentParser that lets you parse arbitrary object init args. @@ -72,6 +76,7 @@ def parse_lit_args(self, *args: Any, **kwargs: Any) -> Namespace: return lit_args +@experimental() def gather_lit_args(cls: Any, root_cls: Optional[Any] = None) -> List[LitArg]: if root_cls is None: diff --git a/pl_bolts/utils/pretrained_weights.py b/pl_bolts/utils/pretrained_weights.py index b54c9824d3..438570d18a 100644 --- a/pl_bolts/utils/pretrained_weights.py +++ b/pl_bolts/utils/pretrained_weights.py @@ -2,6 +2,8 @@ from pytorch_lightning import LightningModule +from pl_bolts.utils.stability import experimental + vae_imagenet2012 = ( "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/" "vae/imagenet_06_22_2019/checkpoints/epoch%3D63.ckpt" ) @@ -10,6 +12,7 @@ urls = {"vae-imagenet2012": vae_imagenet2012, "CPC_v2-resnet18": cpcv2_resnet18} +@experimental() def load_pretrained(model: LightningModule, class_name: Optional[str] = None) -> None: # pragma: no cover if class_name is None: class_name = model.__class__.__name__ diff --git a/pl_bolts/utils/self_supervised.py b/pl_bolts/utils/self_supervised.py index 3f64467a4f..1b1c9c6550 100644 --- a/pl_bolts/utils/self_supervised.py +++ b/pl_bolts/utils/self_supervised.py @@ -2,8 +2,10 @@ from pl_bolts.models.self_supervised import resnets from pl_bolts.utils.semi_supervised import Identity +from pl_bolts.utils.stability import experimental +@experimental() def torchvision_ssl_encoder( name: str, pretrained: bool = False, diff --git a/pl_bolts/utils/semi_supervised.py b/pl_bolts/utils/semi_supervised.py index 87930a62a4..68c4ada0c0 100644 --- a/pl_bolts/utils/semi_supervised.py +++ b/pl_bolts/utils/semi_supervised.py @@ -6,6 +6,7 @@ from torch import Tensor from pl_bolts.utils import _SKLEARN_AVAILABLE +from pl_bolts.utils.stability import experimental from pl_bolts.utils.warnings import warn_missing_pkg if _SKLEARN_AVAILABLE: @@ -14,6 +15,7 @@ warn_missing_pkg("sklearn", pypi_name="scikit-learn") +@experimental() class Identity(torch.nn.Module): """An identity class to replace arbitrary layers in pretrained models. @@ -32,6 +34,7 @@ def forward(self, x: Tensor) -> Tensor: return x +@experimental() def balance_classes( X: Union[Tensor, np.ndarray], Y: Union[Tensor, np.ndarray, Sequence[int]], batch_size: int ) -> Tuple[np.ndarray, np.ndarray]: @@ -95,6 +98,7 @@ def balance_classes( return final_batches_x, final_batches_y +@experimental() def generate_half_labeled_batches( smaller_set_X: np.ndarray, smaller_set_Y: np.ndarray, diff --git a/pl_bolts/utils/shaping.py b/pl_bolts/utils/shaping.py index a85be61373..dabb49df60 100644 --- a/pl_bolts/utils/shaping.py +++ b/pl_bolts/utils/shaping.py @@ -2,7 +2,10 @@ import torch from torch import Tensor +from pl_bolts.utils.stability import experimental + +@experimental() def tile(a: Tensor, dim: int, n_tile: int) -> Tensor: init_dim = a.size(dim) repeat_idx = [1] * a.dim() diff --git a/pl_bolts/utils/warnings.py b/pl_bolts/utils/warnings.py index 116ef2173a..d60c5e1b8c 100644 --- a/pl_bolts/utils/warnings.py +++ b/pl_bolts/utils/warnings.py @@ -2,11 +2,14 @@ import warnings from typing import Callable, Dict, Optional +from pl_bolts.utils.stability import experimental + MISSING_PACKAGE_WARNINGS: Dict[str, int] = {} WARN_MISSING_PACKAGE = int(os.environ.get("WARN_MISSING_PACKAGE", False)) +@experimental() def warn_missing_pkg( pkg_name: str, pypi_name: Optional[str] = None, From 11c5d00a9b43b67dea0e728f9e20f1899c6de03c Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 19 Jul 2022 17:37:03 +0200 Subject: [PATCH 07/11] make doctest happy --- pl_bolts/utils/stability.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pl_bolts/utils/stability.py b/pl_bolts/utils/stability.py index 3df3bbe359..ed66f9b3bb 100644 --- a/pl_bolts/utils/stability.py +++ b/pl_bolts/utils/stability.py @@ -18,7 +18,6 @@ from pytorch_lightning.utilities import rank_zero_warn -@functools.lru_cache() # Trick to only warn once for each message def _raise_experimental_warning(message: str, stacklevel: int = 6): rank_zero_warn( f"{message} The compatibility with other Lightning projects is not guaranteed and API may change at any time." From d0b3a8be39114676b422c40c7ff14703c949b4d4 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 19 Jul 2022 17:57:43 +0200 Subject: [PATCH 08/11] make dataclasses happier --- pl_bolts/utils/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/utils/arguments.py b/pl_bolts/utils/arguments.py index c4b03aa16a..ed16117e76 100644 --- a/pl_bolts/utils/arguments.py +++ b/pl_bolts/utils/arguments.py @@ -8,8 +8,8 @@ from pl_bolts.utils.stability import experimental -@dataclass(frozen=True) @experimental() +@dataclass(frozen=True) class LitArg: """Dataclass to represent init args of an object.""" From 7782d11ec4f869e48491bcdbf41e6501bd2b4449 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 19 Jul 2022 18:09:20 +0200 Subject: [PATCH 09/11] make tests happier --- pl_bolts/datasets/ssl_amdim_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datasets/ssl_amdim_datasets.py b/pl_bolts/datasets/ssl_amdim_datasets.py index 5efa2d7dfb..5475b3ce79 100644 --- a/pl_bolts/datasets/ssl_amdim_datasets.py +++ b/pl_bolts/datasets/ssl_amdim_datasets.py @@ -111,7 +111,7 @@ def __init__( # use train for all of these splits train = split in ("val", "train", "train+unlabeled") - super().__init__(root, train, transform, target_transform, download) + super(SSLDatasetMixin, self).__init__(root, train, transform, target_transform, download) # modify only for val, train if split != "test": From c7c2dbf01569233c1cbb231e71b24a0aae490d54 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 20 Jul 2022 11:17:05 +0200 Subject: [PATCH 10/11] minor Co-authored-by: Akihiro Nitta --- docs/source/stability.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/stability.rst b/docs/source/stability.rst index f149237db6..9d597127d6 100644 --- a/docs/source/stability.rst +++ b/docs/source/stability.rst @@ -16,7 +16,7 @@ ______ For stable features, all of the following are true: - the API isn’t expected to change -- if anything does change, incorrect usage will give a deprecation warning for **one major release** before the breaking change is made +- if anything does change, incorrect usage will give a deprecation warning for **one minor release** before the breaking change is made - the API has been tested for compatibility with latest releases of PyTorch Lightning and Flash Experimental From 817a1b859674882a8f8970fd7a2604dbf68fbc14 Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 20 Jul 2022 14:32:19 +0200 Subject: [PATCH 11/11] rename experimental -> to_review --- docs/source/stability.rst | 17 +++-- pl_bolts/callbacks/byol_updates.py | 4 +- pl_bolts/callbacks/data_monitor.py | 10 +-- pl_bolts/callbacks/knn_online.py | 6 +- pl_bolts/callbacks/printing.py | 6 +- pl_bolts/callbacks/sparseml.py | 4 +- pl_bolts/callbacks/ssl_online.py | 6 +- pl_bolts/callbacks/torch_ort.py | 4 +- pl_bolts/callbacks/variational.py | 4 +- pl_bolts/callbacks/verification/base.py | 6 +- .../callbacks/verification/batch_gradient.py | 14 ++-- pl_bolts/callbacks/vision/confused_logit.py | 4 +- pl_bolts/callbacks/vision/image_generation.py | 4 +- pl_bolts/callbacks/vision/sr_image_logger.py | 4 +- pl_bolts/datamodules/async_dataloader.py | 4 +- .../datamodules/binary_emnist_datamodule.py | 4 +- .../datamodules/binary_mnist_datamodule.py | 4 +- pl_bolts/datamodules/cifar10_datamodule.py | 6 +- pl_bolts/datamodules/cityscapes_datamodule.py | 4 +- pl_bolts/datamodules/emnist_datamodule.py | 4 +- pl_bolts/datamodules/experience_source.py | 10 +-- .../datamodules/fashion_mnist_datamodule.py | 4 +- pl_bolts/datamodules/imagenet_datamodule.py | 4 +- pl_bolts/datamodules/kitti_datamodule.py | 4 +- pl_bolts/datamodules/mnist_datamodule.py | 4 +- pl_bolts/datamodules/sklearn_datamodule.py | 8 +-- pl_bolts/datamodules/sr_datamodule.py | 4 +- .../datamodules/ssl_imagenet_datamodule.py | 4 +- pl_bolts/datamodules/stl10_datamodule.py | 4 +- pl_bolts/datamodules/vision_datamodule.py | 4 +- .../datamodules/vocdetection_datamodule.py | 8 +-- pl_bolts/datasets/base_dataset.py | 4 +- pl_bolts/datasets/cifar10_dataset.py | 6 +- pl_bolts/datasets/concat_dataset.py | 4 +- pl_bolts/datasets/dummy_dataset.py | 12 ++-- pl_bolts/datasets/emnist_dataset.py | 4 +- pl_bolts/datasets/imagenet_dataset.py | 26 +++---- pl_bolts/datasets/kitti_dataset.py | 4 +- pl_bolts/datasets/mnist_dataset.py | 4 +- pl_bolts/datasets/sr_celeba_dataset.py | 4 +- pl_bolts/datasets/sr_dataset_mixin.py | 4 +- pl_bolts/datasets/sr_mnist_dataset.py | 4 +- pl_bolts/datasets/sr_stl10_dataset.py | 4 +- pl_bolts/datasets/ssl_amdim_datasets.py | 6 +- pl_bolts/datasets/utils.py | 4 +- pl_bolts/losses/object_detection.py | 6 +- pl_bolts/losses/rl.py | 8 +-- pl_bolts/losses/self_supervised_learning.py | 12 ++-- pl_bolts/metrics/aggregation.py | 8 +-- pl_bolts/metrics/object_detection.py | 6 +- .../autoencoders/basic_ae/basic_ae_module.py | 6 +- .../basic_vae/basic_vae_module.py | 6 +- pl_bolts/models/autoencoders/components.py | 32 ++++----- .../components/torchvision_backbones.py | 10 +-- .../models/detection/faster_rcnn/backbones.py | 4 +- .../faster_rcnn/faster_rcnn_module.py | 8 +-- .../models/detection/retinanet/backbones.py | 4 +- .../detection/retinanet/retinanet_module.py | 6 +- pl_bolts/models/detection/yolo/yolo_config.py | 18 ++--- pl_bolts/models/detection/yolo/yolo_layers.py | 20 +++--- pl_bolts/models/detection/yolo/yolo_module.py | 8 +-- .../models/gans/basic/basic_gan_module.py | 6 +- pl_bolts/models/gans/basic/components.py | 6 +- pl_bolts/models/gans/dcgan/components.py | 6 +- pl_bolts/models/gans/dcgan/dcgan_module.py | 6 +- pl_bolts/models/gans/pix2pix/components.py | 10 +-- .../models/gans/pix2pix/pix2pix_module.py | 6 +- pl_bolts/models/gans/srgan/components.py | 10 +-- pl_bolts/models/gans/srgan/srgan_module.py | 6 +- pl_bolts/models/gans/srgan/srresnet_module.py | 6 +- pl_bolts/models/mnist_module.py | 6 +- .../models/regression/linear_regression.py | 6 +- .../models/regression/logistic_regression.py | 6 +- .../models/rl/advantage_actor_critic_model.py | 6 +- pl_bolts/models/rl/common/agents.py | 12 ++-- pl_bolts/models/rl/common/cli.py | 4 +- pl_bolts/models/rl/common/distributions.py | 4 +- pl_bolts/models/rl/common/gym_wrappers.py | 20 +++--- pl_bolts/models/rl/common/memory.py | 12 ++-- pl_bolts/models/rl/common/networks.py | 22 +++--- pl_bolts/models/rl/double_dqn_model.py | 6 +- pl_bolts/models/rl/dqn_model.py | 6 +- pl_bolts/models/rl/dueling_dqn_model.py | 6 +- pl_bolts/models/rl/noisy_dqn_model.py | 6 +- pl_bolts/models/rl/per_dqn_model.py | 6 +- pl_bolts/models/rl/ppo_model.py | 6 +- pl_bolts/models/rl/reinforce_model.py | 6 +- pl_bolts/models/rl/sac_model.py | 6 +- .../rl/vanilla_policy_gradient_model.py | 6 +- .../self_supervised/amdim/amdim_module.py | 8 +-- .../models/self_supervised/amdim/datasets.py | 6 +- .../models/self_supervised/amdim/networks.py | 16 ++--- .../self_supervised/amdim/transforms.py | 14 ++-- .../self_supervised/byol/byol_module.py | 6 +- .../models/self_supervised/byol/models.py | 6 +- .../self_supervised/cpc/cpc_finetuner.py | 4 +- .../models/self_supervised/cpc/cpc_module.py | 6 +- .../models/self_supervised/cpc/networks.py | 14 ++-- .../models/self_supervised/cpc/transforms.py | 14 ++-- pl_bolts/models/self_supervised/evaluator.py | 6 +- .../models/self_supervised/moco/callbacks.py | 4 +- .../self_supervised/moco/moco2_module.py | 8 +-- .../models/self_supervised/moco/transforms.py | 16 ++--- pl_bolts/models/self_supervised/resnets.py | 32 ++++----- .../simclr/simclr_finetuner.py | 4 +- .../self_supervised/simclr/simclr_module.py | 10 +-- .../self_supervised/simclr/transforms.py | 10 +-- .../models/self_supervised/simsiam/models.py | 6 +- .../self_supervised/simsiam/simsiam_module.py | 6 +- .../models/self_supervised/ssl_finetuner.py | 4 +- .../self_supervised/swav/swav_finetuner.py | 4 +- .../self_supervised/swav/swav_module.py | 6 +- .../self_supervised/swav/swav_resnet.py | 24 +++---- .../models/self_supervised/swav/transforms.py | 10 +-- pl_bolts/models/vision/image_gpt/gpt2.py | 6 +- .../models/vision/image_gpt/igpt_module.py | 8 +-- pl_bolts/models/vision/pixel_cnn.py | 4 +- pl_bolts/models/vision/segmentation.py | 6 +- pl_bolts/models/vision/unet.py | 10 +-- pl_bolts/optimizers/lars.py | 4 +- pl_bolts/optimizers/lr_scheduler.py | 6 +- pl_bolts/transforms/dataset_normalizations.py | 10 +-- .../self_supervised/ssl_transforms.py | 6 +- pl_bolts/utils/arguments.py | 8 +-- pl_bolts/utils/pretrained_weights.py | 4 +- pl_bolts/utils/self_supervised.py | 4 +- pl_bolts/utils/semi_supervised.py | 8 +-- pl_bolts/utils/shaping.py | 4 +- pl_bolts/utils/stability.py | 72 +++++++++++-------- pl_bolts/utils/warnings.py | 4 +- 130 files changed, 537 insertions(+), 514 deletions(-) diff --git a/docs/source/stability.rst b/docs/source/stability.rst index f149237db6..2cf0153bb5 100644 --- a/docs/source/stability.rst +++ b/docs/source/stability.rst @@ -4,9 +4,9 @@ Bolts stability =============== Currently we are going through major revision of Bolts to ensure all of the code is stable and compatible with the rest of the Lightning ecosystem. -For this reason, all of our features are either marked as stable or experimental. Stable features are implicit, experimental features are explicit. +For this reason, all of our features are either marked as stable or in need of review. Stable features are implicit, features to be reviewed are explicitly marked. -At the beginning of the aforementioned revision, **ALL** of the features currently in the project have been marked as experimental and will undergo rigorous review and testing before they can be marked as stable. +At the beginning of the aforementioned revision, **ALL** of the features currently in the project have been marked as to be reviewed and will undergo rigorous review and testing before they can be marked as stable. This document is intended to help you know what to expect and to outline our commitment to stability. @@ -19,12 +19,19 @@ For stable features, all of the following are true: - if anything does change, incorrect usage will give a deprecation warning for **one major release** before the breaking change is made - the API has been tested for compatibility with latest releases of PyTorch Lightning and Flash -Experimental -____________ +To Review +_________ -For experimental features, any or all of the following may be true: +For features to be reviewed, any or all of the following may be true: - the feature has unstable dependencies - the API may change without notice in future versions - the performance of the feature has not been verified - the docs for this feature are under active development + + +Before a feature can be moved to Stable it needs to satisfy following conditions: + +- Have appropriate tests, that will check not only correctness of the feature, but also compatibility with the current versions. +- Not have duplicate code accross Lightning ecosystem and more mature OSS projects. +- Pass a review process. diff --git a/pl_bolts/callbacks/byol_updates.py b/pl_bolts/callbacks/byol_updates.py index 15bbc6c4f9..83155f429d 100644 --- a/pl_bolts/callbacks/byol_updates.py +++ b/pl_bolts/callbacks/byol_updates.py @@ -5,10 +5,10 @@ from torch import Tensor from torch.nn import Module -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class BYOLMAWeightUpdate(Callback): """Weight update rule from BYOL. diff --git a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py index 5f4656e8f2..a788664e2e 100644 --- a/pl_bolts/callbacks/data_monitor.py +++ b/pl_bolts/callbacks/data_monitor.py @@ -11,7 +11,7 @@ from torch.utils.hooks import RemovableHandle from pl_bolts.utils import _WANDB_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _WANDB_AVAILABLE: @@ -20,7 +20,7 @@ warn_missing_pkg("wandb") -@experimental() +@to_review() class DataMonitorBase(Callback): supported_loggers = ( @@ -111,7 +111,7 @@ def _is_logger_available(self, logger: LightningLoggerBase) -> bool: return available -@experimental() +@to_review() class ModuleDataMonitor(DataMonitorBase): GROUP_NAME_INPUT = "input" @@ -197,7 +197,7 @@ def hook(_: Module, inp: Sequence, out: Sequence) -> None: return handle -@experimental() +@to_review() class TrainingDataMonitor(DataMonitorBase): GROUP_NAME = "training_step" @@ -261,7 +261,7 @@ def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name: collect_and_name_tensors(item, output, parent_name=f"{parent_name}/{i:d}") -@experimental() +@to_review() def shape2str(tensor: Tensor) -> str: """Returns the shape of a tensor in bracket notation as a string. diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index 238eda1d30..35ecb67d88 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -6,10 +6,10 @@ from torch import Tensor from torch.nn import functional as F -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class KNNOnlineEvaluator(Callback): """Weighted KNN online evaluator for self-supervised learning. The weighted KNN classifier matches sec 3.4 of https://arxiv.org/pdf/1805.01978.pdf. @@ -141,6 +141,6 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) pl_module.log("online_knn_val_acc", total_top1 / total_num, on_step=False, on_epoch=True, sync_dist=True) -@experimental() +@to_review() def concat_all_gather(tensor: Tensor, accelerator: Accelerator) -> Tensor: return accelerator.all_gather(tensor).view(-1, *tensor.shape[1:]) diff --git a/pl_bolts/callbacks/printing.py b/pl_bolts/callbacks/printing.py index eefb04ac52..dfeea93602 100644 --- a/pl_bolts/callbacks/printing.py +++ b/pl_bolts/callbacks/printing.py @@ -6,10 +6,10 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities import rank_zero_info -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class PrintTableMetricsCallback(Callback): """Prints a table with the metrics in columns on every epoch end. @@ -44,7 +44,7 @@ def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: rank_zero_info(dicts_to_table(self.metrics)) -@experimental() +@to_review() def dicts_to_table( dicts: List[Dict], keys: Optional[List[str]] = None, diff --git a/pl_bolts/callbacks/sparseml.py b/pl_bolts/callbacks/sparseml.py index 58e8a17198..97b9b4a8ef 100644 --- a/pl_bolts/callbacks/sparseml.py +++ b/pl_bolts/callbacks/sparseml.py @@ -23,10 +23,10 @@ from sparseml.pytorch.optim import ScheduledModifierManager from sparseml.pytorch.utils import ModuleExporter -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class SparseMLCallback(Callback): """Enables SparseML aware training. Requires a recipe to run during training. diff --git a/pl_bolts/callbacks/ssl_online.py b/pl_bolts/callbacks/ssl_online.py index 953d48fdbc..f0df310e4b 100644 --- a/pl_bolts/callbacks/ssl_online.py +++ b/pl_bolts/callbacks/ssl_online.py @@ -10,10 +10,10 @@ from torchmetrics.functional import accuracy from pl_bolts.models.self_supervised.evaluator import SSLEvaluator -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class SSLOnlineEvaluator(Callback): # pragma: no cover """Attaches a MLP for fine-tuning using the standard self-supervised protocol. @@ -175,7 +175,7 @@ def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule, callb self._recovered_callback_state = callback_state -@experimental() +@to_review() @contextmanager def set_training(module: nn.Module, mode: bool): """Context manager to set training mode. diff --git a/pl_bolts/callbacks/torch_ort.py b/pl_bolts/callbacks/torch_ort.py index c4b64a1e10..e414c37158 100644 --- a/pl_bolts/callbacks/torch_ort.py +++ b/pl_bolts/callbacks/torch_ort.py @@ -19,10 +19,10 @@ if _TORCH_ORT_AVAILABLE: from torch_ort import ORTModule -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class ORTCallback(Callback): """Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime. diff --git a/pl_bolts/callbacks/variational.py b/pl_bolts/callbacks/variational.py index 15f4fe1b5f..3d83c6a35e 100644 --- a/pl_bolts/callbacks/variational.py +++ b/pl_bolts/callbacks/variational.py @@ -7,7 +7,7 @@ from torch import Tensor from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -16,7 +16,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class LatentDimInterpolator(Callback): """Interpolates the latent space for a model by setting all dims to zero and stepping through the first two dims increasing one unit at a time. diff --git a/pl_bolts/callbacks/verification/base.py b/pl_bolts/callbacks/verification/base.py index 25aebaec8b..5b0879156e 100644 --- a/pl_bolts/callbacks/verification/base.py +++ b/pl_bolts/callbacks/verification/base.py @@ -9,10 +9,10 @@ from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class VerificationBase: """Base class for model verification. @@ -82,7 +82,7 @@ def _model_forward(self, input_array: Any) -> Any: return self.model(input_array) -@experimental() +@to_review() class VerificationCallbackBase(Callback): """Base class for model verification in form of a callback. diff --git a/pl_bolts/callbacks/verification/batch_gradient.py b/pl_bolts/callbacks/verification/batch_gradient.py index 4c3b08d736..47f3244686 100644 --- a/pl_bolts/callbacks/verification/batch_gradient.py +++ b/pl_bolts/callbacks/verification/batch_gradient.py @@ -10,10 +10,10 @@ from torch import Tensor from pl_bolts.callbacks.verification.base import VerificationBase, VerificationCallbackBase -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class BatchGradientVerification(VerificationBase): """Checks if a model mixes data across the batch dimension. @@ -84,7 +84,7 @@ def check( return not any(has_grad_outside_sample) and all(has_grad_inside_sample) -@experimental() +@to_review() class BatchGradientVerificationCallback(VerificationCallbackBase): """The callback version of the :class:`BatchGradientVerification` test. @@ -133,7 +133,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: self._raise() -@experimental() +@to_review() def default_input_mapping(data: Any) -> List[Tensor]: """Finds all tensors in a (nested) collection that have the same batch size. @@ -161,7 +161,7 @@ def default_input_mapping(data: Any) -> List[Tensor]: return batches -@experimental() +@to_review() def default_output_mapping(data: Any) -> Tensor: """Pulls out all tensors in a output collection and combines them into one big batch for verification. @@ -193,7 +193,7 @@ def default_output_mapping(data: Any) -> Tensor: return combined -@experimental() +@to_review() def collect_tensors(data: Any) -> List[Tensor]: """Filters all tensors in a collection and returns them in a list.""" tensors = [] @@ -206,7 +206,7 @@ def collect_batches(tensor: Tensor) -> Tensor: return tensors -@experimental() +@to_review() @contextmanager def selective_eval(model: nn.Module, layer_types: Iterable[Type[nn.Module]]) -> None: """A context manager that sets all requested types of layers to eval mode. This method uses an ``isinstance`` diff --git a/pl_bolts/callbacks/vision/confused_logit.py b/pl_bolts/callbacks/vision/confused_logit.py index 3a4b52d2f4..2c2216312b 100644 --- a/pl_bolts/callbacks/vision/confused_logit.py +++ b/pl_bolts/callbacks/vision/confused_logit.py @@ -5,7 +5,7 @@ from torch import Tensor, nn from pl_bolts.utils import _MATPLOTLIB_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _MATPLOTLIB_AVAILABLE: @@ -18,7 +18,7 @@ Figure = object -@experimental() +@to_review() class ConfusedLogitCallback(Callback): # pragma: no cover """Takes the logit predictions of a model and when the probabilities of two classes are very close, the model doesn't have high certainty that it should pick one vs the other class. diff --git a/pl_bolts/callbacks/vision/image_generation.py b/pl_bolts/callbacks/vision/image_generation.py index cea7beafb6..153e534614 100644 --- a/pl_bolts/callbacks/vision/image_generation.py +++ b/pl_bolts/callbacks/vision/image_generation.py @@ -4,7 +4,7 @@ from pytorch_lightning import Callback, LightningModule, Trainer from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -13,7 +13,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class TensorboardGenerativeModelImageSampler(Callback): """Generates images and logs to tensorboard. Your model must implement the ``forward`` function for generation. diff --git a/pl_bolts/callbacks/vision/sr_image_logger.py b/pl_bolts/callbacks/vision/sr_image_logger.py index f94ecd7040..e1726f8921 100644 --- a/pl_bolts/callbacks/vision/sr_image_logger.py +++ b/pl_bolts/callbacks/vision/sr_image_logger.py @@ -6,7 +6,7 @@ from pytorch_lightning import Callback from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -15,7 +15,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class SRImageLoggerCallback(Callback): """Logs low-res, generated high-res, and ground truth high-res images to TensorBoard Your model must implement the ``forward`` function for generation. diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index c126ad85dc..833a0dc784 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -9,10 +9,10 @@ from torch._six import string_classes from torch.utils.data import DataLoader, Dataset -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class AsynchronousLoader: """Class for asynchronously loading from CPU memory to device memory with DataLoader. diff --git a/pl_bolts/datamodules/binary_emnist_datamodule.py b/pl_bolts/datamodules/binary_emnist_datamodule.py index 5435b5722c..c59b9584a5 100644 --- a/pl_bolts/datamodules/binary_emnist_datamodule.py +++ b/pl_bolts/datamodules/binary_emnist_datamodule.py @@ -3,10 +3,10 @@ from pl_bolts.datamodules.emnist_datamodule import EMNISTDataModule from pl_bolts.datasets import BinaryEMNIST from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class BinaryEMNISTDataModule(EMNISTDataModule): """ .. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index ed65d82b26..5e11092656 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -3,7 +3,7 @@ from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.datasets import BinaryMNIST from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -12,7 +12,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class BinaryMNISTDataModule(VisionDataModule): """ .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 65b08e4f1f..86db708973 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -4,7 +4,7 @@ from pl_bolts.datasets import TrialCIFAR10 from pl_bolts.transforms.dataset_normalizations import cifar10_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -15,7 +15,7 @@ CIFAR10 = None -@experimental() +@to_review() class CIFAR10DataModule(VisionDataModule): """ .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/ @@ -124,7 +124,7 @@ def default_transforms(self) -> Callable: return cf10_transforms -@experimental() +@to_review() class TinyCIFAR10DataModule(CIFAR10DataModule): """Standard CIFAR10, train, val, test splits and transforms. diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 6d1dab5733..a74f364556 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -4,7 +4,7 @@ from torch.utils.data import DataLoader from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -14,7 +14,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class CityscapesDataModule(LightningDataModule): """ .. figure:: https://www.cityscapes-dataset.com/wordpress/wp-content/uploads/2015/07/muenster00-1024x510.png diff --git a/pl_bolts/datamodules/emnist_datamodule.py b/pl_bolts/datamodules/emnist_datamodule.py index 90a4eb4b4a..ecde3e605b 100644 --- a/pl_bolts/datamodules/emnist_datamodule.py +++ b/pl_bolts/datamodules/emnist_datamodule.py @@ -3,7 +3,7 @@ from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.transforms.dataset_normalizations import emnist_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -14,7 +14,7 @@ EMNIST = object -@experimental() +@to_review() class EMNISTDataModule(VisionDataModule): """ .. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index 3aaa8623a2..171f43b997 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -8,7 +8,7 @@ from torch.utils.data import IterableDataset from pl_bolts.utils import _GYM_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg try: @@ -24,7 +24,7 @@ Experience = namedtuple("Experience", field_names=["state", "action", "reward", "done", "new_state"]) -@experimental() +@to_review() class ExperienceSourceDataset(IterableDataset): """Basic experience source dataset. @@ -41,7 +41,7 @@ def __iter__(self) -> Iterator: # Experience Sources -@experimental() +@to_review() class BaseExperienceSource(ABC): """Simplest form of the experience source.""" @@ -59,7 +59,7 @@ def runner(self) -> Experience: raise NotImplementedError("ExperienceSource has no stepper method implemented") -@experimental() +@to_review() class ExperienceSource(BaseExperienceSource): """Experience source class handling single and multiple environment steps.""" @@ -235,7 +235,7 @@ def pop_rewards_steps(self): return res -@experimental() +@to_review() class DiscountedExperienceSource(ExperienceSource): """Outputs experiences with a discounted reward over N steps.""" diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index d70402ac24..f13a083e77 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -2,7 +2,7 @@ from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -13,7 +13,7 @@ FashionMNIST = None -@experimental() +@to_review() class FashionMNISTDataModule(VisionDataModule): """ .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/ diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 4ef3d4c59b..64417ad928 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -7,7 +7,7 @@ from pl_bolts.datasets import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -16,7 +16,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class ImagenetDataModule(LightningDataModule): """ .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/ diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 9195f3b59d..458e596a89 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -8,7 +8,7 @@ from pl_bolts.datasets import KittiDataset from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -17,7 +17,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class KittiDataModule(LightningDataModule): name = "kitti" diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 93dff65ec4..752aa51177 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -3,7 +3,7 @@ from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.datasets import MNIST from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -12,7 +12,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class MNISTDataModule(VisionDataModule): """ .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index 49d5e3efa5..4c9c7e638b 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -8,7 +8,7 @@ from torch.utils.data import DataLoader, Dataset from pl_bolts.utils import _SKLEARN_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _SKLEARN_AVAILABLE: @@ -17,7 +17,7 @@ warn_missing_pkg("sklearn") -@experimental() +@to_review() class SklearnDataset(Dataset): """Mapping between numpy (or sklearn) datasets to PyTorch datasets. @@ -65,7 +65,7 @@ def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]: return x, y -@experimental() +@to_review() class TensorDataset(Dataset): """Prepare PyTorch tensor dataset for data loaders. @@ -109,7 +109,7 @@ def __getitem__(self, idx) -> Tuple[Tensor, Tensor]: return x, y -@experimental() +@to_review() class SklearnDataModule(LightningDataModule): """Automatically generates the train, validation and test splits for a Numpy dataset. They are set up as dataloaders for convenience. Optionally, you can pass in your own validation and test splits. diff --git a/pl_bolts/datamodules/sr_datamodule.py b/pl_bolts/datamodules/sr_datamodule.py index 4d5e798bc7..a58249188d 100644 --- a/pl_bolts/datamodules/sr_datamodule.py +++ b/pl_bolts/datamodules/sr_datamodule.py @@ -3,10 +3,10 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class TVTDataModule(LightningDataModule): """Simple DataModule creating train, val, and test dataloaders from given train, val, and test dataset. diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index fa89b8357c..132528b1b9 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -7,7 +7,7 @@ from pl_bolts.datasets import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -16,7 +16,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class SSLImagenetDataModule(LightningDataModule): # pragma: no cover name = "imagenet" diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 87bee938e4..c61cc55f64 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -8,7 +8,7 @@ from pl_bolts.datasets import ConcatDataset from pl_bolts.transforms.dataset_normalizations import stl10_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -18,7 +18,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class STL10DataModule(LightningDataModule): # pragma: no cover """ .. figure:: https://samyzaf.com/ML/cifar10/cifar1.jpg diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 29b09a3b5b..82ca2bf8fd 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,10 +6,10 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class VisionDataModule(LightningDataModule): EXTRA_ARGS: dict = {} diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 603f0c9815..f45aecbbe5 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader, Dataset from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -17,7 +17,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class Compose: """Like `torchvision.transforms.compose` but works for (image, target)""" @@ -62,7 +62,7 @@ def _collate_fn(batch: List[Tensor]) -> tuple: ) -@experimental() +@to_review() def _prepare_voc_instance(image: Any, target: Dict[str, Any]): """Prepares VOC dataset into appropriate target for fasterrcnn. @@ -104,7 +104,7 @@ def _prepare_voc_instance(image: Any, target: Dict[str, Any]): return image, target -@experimental() +@to_review() class VOCDetectionDataModule(LightningDataModule): """TODO(teddykoker) docstring.""" diff --git a/pl_bolts/datasets/base_dataset.py b/pl_bolts/datasets/base_dataset.py index 1acc0a9d81..c90389c2f1 100644 --- a/pl_bolts/datasets/base_dataset.py +++ b/pl_bolts/datasets/base_dataset.py @@ -8,10 +8,10 @@ from torch import Tensor from torch.utils.data import Dataset -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class LightDataset(ABC, Dataset): data: Tensor diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py index 9bf89cfa41..6baf3fef62 100644 --- a/pl_bolts/datasets/cifar10_dataset.py +++ b/pl_bolts/datasets/cifar10_dataset.py @@ -8,7 +8,7 @@ from pl_bolts.datasets import LightDataset from pl_bolts.utils import _PIL_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -17,7 +17,7 @@ warn_missing_pkg("PIL", pypi_name="Pillow") -@experimental() +@to_review() class CIFAR10(LightDataset): """Customized `CIFAR10 `_ dataset for testing Pytorch Lightning without the torchvision dependency. @@ -155,7 +155,7 @@ def download(self, data_folder: str) -> None: self._download_from_url(self.BASE_URL, data_folder, self.FILE_NAME) -@experimental() +@to_review() class TrialCIFAR10(CIFAR10): """ Customized `CIFAR10 `_ dataset for testing Pytorch Lightning diff --git a/pl_bolts/datasets/concat_dataset.py b/pl_bolts/datasets/concat_dataset.py index b43048b6f7..08413def62 100644 --- a/pl_bolts/datasets/concat_dataset.py +++ b/pl_bolts/datasets/concat_dataset.py @@ -1,9 +1,9 @@ from torch.utils.data import Dataset -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class ConcatDataset(Dataset): def __init__(self, *datasets): self.datasets = datasets diff --git a/pl_bolts/datasets/dummy_dataset.py b/pl_bolts/datasets/dummy_dataset.py index 74054e6fc1..dd5723437c 100644 --- a/pl_bolts/datasets/dummy_dataset.py +++ b/pl_bolts/datasets/dummy_dataset.py @@ -1,10 +1,10 @@ import torch from torch.utils.data import Dataset -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class DummyDataset(Dataset): """Generate a dummy dataset. @@ -44,7 +44,7 @@ def __getitem__(self, idx: int): return sample -@experimental() +@to_review() class DummyDetectionDataset(Dataset): """Generate a dummy dataset for detection. @@ -85,7 +85,7 @@ def __getitem__(self, idx: int): return img, {"boxes": boxes, "labels": labels} -@experimental() +@to_review() class RandomDictDataset(Dataset): """Generate a dummy dataset with a dict structure. @@ -114,7 +114,7 @@ def __len__(self): return self.len -@experimental() +@to_review() class RandomDictStringDataset(Dataset): """Generate a dummy dataset with strings. @@ -141,7 +141,7 @@ def __len__(self): return self.len -@experimental() +@to_review() class RandomDataset(Dataset): """Generate a dummy dataset. diff --git a/pl_bolts/datasets/emnist_dataset.py b/pl_bolts/datasets/emnist_dataset.py index fb3628d2aa..0370f0da7c 100644 --- a/pl_bolts/datasets/emnist_dataset.py +++ b/pl_bolts/datasets/emnist_dataset.py @@ -1,5 +1,5 @@ from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -14,7 +14,7 @@ warn_missing_pkg("PIL", pypi_name="Pillow") -@experimental() +@to_review() class BinaryEMNIST(EMNIST): def __getitem__(self, idx): """ diff --git a/pl_bolts/datasets/imagenet_dataset.py b/pl_bolts/datasets/imagenet_dataset.py index 9195a5a57b..42dcbbd799 100644 --- a/pl_bolts/datasets/imagenet_dataset.py +++ b/pl_bolts/datasets/imagenet_dataset.py @@ -12,7 +12,7 @@ import torch from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg PY3 = sys.version_info[0] == 3 @@ -25,7 +25,7 @@ ImageNet = object -@experimental() +@to_review() class UnlabeledImagenet(ImageNet): """Official train set gets split into train, val. (using nb_imgs_per_val_class for each class). Official validation becomes test set. @@ -160,7 +160,7 @@ def generate_meta_bins(cls, devkit_dir): print(f"meta.bin generated at {devkit_dir}/meta.bin") -@experimental() +@to_review() def _verify_archive(root, file, md5): if not _check_integrity(os.path.join(root, file), md5): raise RuntimeError( @@ -169,7 +169,7 @@ def _verify_archive(root, file, md5): ) -@experimental() +@to_review() def _check_integrity(fpath, md5=None): if not os.path.isfile(fpath): return False @@ -178,12 +178,12 @@ def _check_integrity(fpath, md5=None): return _check_md5(fpath, md5) -@experimental() +@to_review() def _check_md5(fpath, md5, **kwargs): return md5 == _calculate_md5(fpath, **kwargs) -@experimental() +@to_review() def _calculate_md5(fpath, chunk_size=1024 * 1024): md5 = hashlib.md5() with open(fpath, "rb") as f: @@ -192,7 +192,7 @@ def _calculate_md5(fpath, chunk_size=1024 * 1024): return md5.hexdigest() -@experimental() +@to_review() def parse_devkit_archive(root, file=None): """Parse the devkit archive of the ImageNet2012 classification dataset and save the meta information in a binary file. @@ -249,7 +249,7 @@ def get_tmp_dir(): torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE)) -@experimental() +@to_review() def extract_archive(from_path, to_path=None, remove_finished=False): if to_path is None: to_path = os.path.dirname(from_path) @@ -280,26 +280,26 @@ def extract_archive(from_path, to_path=None, remove_finished=False): os.remove(from_path) -@experimental() +@to_review() def _is_targz(filename): return filename.endswith(".tar.gz") -@experimental() +@to_review() def _is_tarxz(filename): return filename.endswith(".tar.xz") -@experimental() +@to_review() def _is_gzip(filename): return filename.endswith(".gz") and not filename.endswith(".tar.gz") -@experimental() +@to_review() def _is_tar(filename): return filename.endswith(".tar") -@experimental() +@to_review() def _is_zip(filename): return filename.endswith(".zip") diff --git a/pl_bolts/datasets/kitti_dataset.py b/pl_bolts/datasets/kitti_dataset.py index 88e86b1c6a..d9f81eaba3 100644 --- a/pl_bolts/datasets/kitti_dataset.py +++ b/pl_bolts/datasets/kitti_dataset.py @@ -4,7 +4,7 @@ from torch.utils.data import Dataset from pl_bolts.utils import _PIL_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -16,7 +16,7 @@ DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) -@experimental() +@to_review() class KittiDataset(Dataset): """ Note: diff --git a/pl_bolts/datasets/mnist_dataset.py b/pl_bolts/datasets/mnist_dataset.py index 8c0486d90a..df5c877545 100644 --- a/pl_bolts/datasets/mnist_dataset.py +++ b/pl_bolts/datasets/mnist_dataset.py @@ -1,5 +1,5 @@ from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1 -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -37,7 +37,7 @@ ] -@experimental() +@to_review() class BinaryMNIST(MNIST): def __getitem__(self, idx): """ diff --git a/pl_bolts/datasets/sr_celeba_dataset.py b/pl_bolts/datasets/sr_celeba_dataset.py index 5f3eaaa19c..db8e88b903 100644 --- a/pl_bolts/datasets/sr_celeba_dataset.py +++ b/pl_bolts/datasets/sr_celeba_dataset.py @@ -3,7 +3,7 @@ from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -18,7 +18,7 @@ CelebA = object -@experimental() +@to_review() class SRCelebA(SRDatasetMixin, CelebA): """CelebA dataset that can be used to train Super Resolution models. diff --git a/pl_bolts/datasets/sr_dataset_mixin.py b/pl_bolts/datasets/sr_dataset_mixin.py index a35b366dc9..3d02425bc2 100644 --- a/pl_bolts/datasets/sr_dataset_mixin.py +++ b/pl_bolts/datasets/sr_dataset_mixin.py @@ -4,7 +4,7 @@ import torch from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -18,7 +18,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class SRDatasetMixin: """Mixin for Super Resolution datasets. diff --git a/pl_bolts/datasets/sr_mnist_dataset.py b/pl_bolts/datasets/sr_mnist_dataset.py index 633d990be5..bdecd6d32d 100644 --- a/pl_bolts/datasets/sr_mnist_dataset.py +++ b/pl_bolts/datasets/sr_mnist_dataset.py @@ -3,7 +3,7 @@ from pl_bolts.datasets.mnist_dataset import MNIST from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin from pl_bolts.utils import _PIL_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -12,7 +12,7 @@ warn_missing_pkg("PIL", pypi_name="Pillow") -@experimental() +@to_review() class SRMNIST(SRDatasetMixin, MNIST): """MNIST dataset that can be used to train Super Resolution models. diff --git a/pl_bolts/datasets/sr_stl10_dataset.py b/pl_bolts/datasets/sr_stl10_dataset.py index 1ecd360a56..908f3b6023 100644 --- a/pl_bolts/datasets/sr_stl10_dataset.py +++ b/pl_bolts/datasets/sr_stl10_dataset.py @@ -4,7 +4,7 @@ from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -19,7 +19,7 @@ STL10 = object -@experimental() +@to_review() class SRSTL10(SRDatasetMixin, STL10): """STL10 dataset that can be used to train Super Resolution models. diff --git a/pl_bolts/datasets/ssl_amdim_datasets.py b/pl_bolts/datasets/ssl_amdim_datasets.py index 5475b3ce79..e525aa5842 100644 --- a/pl_bolts/datasets/ssl_amdim_datasets.py +++ b/pl_bolts/datasets/ssl_amdim_datasets.py @@ -4,7 +4,7 @@ import numpy as np from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -14,7 +14,7 @@ CIFAR10 = object -@experimental() +@to_review() class SSLDatasetMixin(ABC): @classmethod def generate_train_val_split(cls, examples, labels, pct_val): @@ -91,7 +91,7 @@ def deterministic_shuffle(cls, x, y): return x, y -@experimental() +@to_review() class CIFAR10Mixed(SSLDatasetMixin, CIFAR10): def __init__( self, diff --git a/pl_bolts/datasets/utils.py b/pl_bolts/datasets/utils.py index 2dbb283f78..3a4583ae61 100644 --- a/pl_bolts/datasets/utils.py +++ b/pl_bolts/datasets/utils.py @@ -3,10 +3,10 @@ from pl_bolts.datasets.sr_celeba_dataset import SRCelebA from pl_bolts.datasets.sr_mnist_dataset import SRMNIST from pl_bolts.datasets.sr_stl10_dataset import SRSTL10 -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def prepare_sr_datasets(dataset: str, scale_factor: int, data_dir: str): """Creates train, val, and test datasets for training a Super Resolution GAN. diff --git a/pl_bolts/losses/object_detection.py b/pl_bolts/losses/object_detection.py index a7b8e15fc9..72c6cae269 100644 --- a/pl_bolts/losses/object_detection.py +++ b/pl_bolts/losses/object_detection.py @@ -3,10 +3,10 @@ from torch import Tensor from pl_bolts.metrics.object_detection import giou, iou -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def iou_loss(preds: Tensor, target: Tensor) -> Tensor: """Calculates the intersection over union loss. @@ -30,7 +30,7 @@ def iou_loss(preds: Tensor, target: Tensor) -> Tensor: return loss -@experimental() +@to_review() def giou_loss(preds: Tensor, target: Tensor) -> Tensor: """Calculates the generalized intersection over union loss. diff --git a/pl_bolts/losses/rl.py b/pl_bolts/losses/rl.py index 3e86bc0d71..5fdcb1f41f 100644 --- a/pl_bolts/losses/rl.py +++ b/pl_bolts/losses/rl.py @@ -6,10 +6,10 @@ import torch from torch import Tensor, nn -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def dqn_loss(batch: Tuple[Tensor, Tensor], net: nn.Module, target_net: nn.Module, gamma: float = 0.99) -> Tensor: """Calculates the mse loss using a mini batch from the replay buffer. @@ -38,7 +38,7 @@ def dqn_loss(batch: Tuple[Tensor, Tensor], net: nn.Module, target_net: nn.Module return nn.MSELoss()(state_action_values, expected_state_action_values) -@experimental() +@to_review() def double_dqn_loss( batch: Tuple[Tensor, Tensor], net: nn.Module, @@ -84,7 +84,7 @@ def double_dqn_loss( return nn.MSELoss()(state_action_values, expected_state_action_values) -@experimental() +@to_review() def per_dqn_loss( batch: Tuple[Tensor, Tensor], batch_weights: List, diff --git a/pl_bolts/losses/self_supervised_learning.py b/pl_bolts/losses/self_supervised_learning.py index e0283f751a..57b5914809 100644 --- a/pl_bolts/losses/self_supervised_learning.py +++ b/pl_bolts/losses/self_supervised_learning.py @@ -3,10 +3,10 @@ from torch import nn from pl_bolts.models.vision.pixel_cnn import PixelCNN -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def nt_xent_loss(out_1, out_2, temperature): """Loss used in SimCLR.""" out = torch.cat([out_1, out_2], dim=0) @@ -28,7 +28,7 @@ def nt_xent_loss(out_1, out_2, temperature): return loss -@experimental() +@to_review() class CPCTask(nn.Module): """Loss used in CPC.""" @@ -90,7 +90,7 @@ def forward(self, Z): return loss -@experimental() +@to_review() class AmdimNCELoss(nn.Module): """Compute the NCE scores for predicting r_src->r_trg.""" @@ -185,7 +185,7 @@ def forward(self, anchor_representations, positive_representations, mask_mat): return nce_scores, lgt_reg -@experimental() +@to_review() class FeatureMapContrastiveTask(nn.Module): """Performs an anchor, positive negative pair comparison for each each tuple of feature maps passed. @@ -370,7 +370,7 @@ def forward(self, anchor_maps, positive_maps): return torch.stack(losses), regularizer -@experimental() +@to_review() def tanh_clip(x, clip_val=10.0): """soft clip values to the range [-clip_val, +clip_val]""" if clip_val is not None: diff --git a/pl_bolts/metrics/aggregation.py b/pl_bolts/metrics/aggregation.py index 4e70d7fcb3..0f2e67319e 100644 --- a/pl_bolts/metrics/aggregation.py +++ b/pl_bolts/metrics/aggregation.py @@ -1,15 +1,15 @@ import torch -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def mean(res, key): # recursive mean for multilevel dicts return torch.stack([x[key] if isinstance(x, dict) else mean(x, key) for x in res]).mean() -@experimental() +@to_review() def accuracy(preds, labels): preds = preds.float() max_lgt = torch.max(preds, 1)[1] @@ -20,7 +20,7 @@ def accuracy(preds, labels): return acc -@experimental() +@to_review() def precision_at_k(output, target, top_k=(1,)): """Computes the accuracy over the k top predictions for the specified values of k.""" with torch.no_grad(): diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index bd0f1888ad..715ef6052b 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -1,10 +1,10 @@ import torch from torch import Tensor -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def iou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the intersection over union. @@ -37,7 +37,7 @@ def iou(preds: Tensor, target: Tensor) -> Tensor: return iou -@experimental() +@to_review() def giou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the generalized intersection over union. diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index 47350247b1..4eceb8b9fa 100644 --- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -13,10 +13,10 @@ resnet50_decoder, resnet50_encoder, ) -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class AE(LightningModule): """Standard AE. @@ -152,7 +152,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def cli_main(args=None): from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index 9114271361..876a620732 100644 --- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -13,10 +13,10 @@ resnet50_decoder, resnet50_encoder, ) -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class VAE(LightningModule): """Standard VAE with Gaussian Prior and approx posterior. @@ -184,7 +184,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def cli_main(args=None): from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/autoencoders/components.py b/pl_bolts/models/autoencoders/components.py index 2abdc9f500..be58357352 100644 --- a/pl_bolts/models/autoencoders/components.py +++ b/pl_bolts/models/autoencoders/components.py @@ -2,10 +2,10 @@ from torch import nn from torch.nn import functional as F -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class Interpolate(nn.Module): """nn.Module wrapper for F.interpolate.""" @@ -17,19 +17,19 @@ def forward(self, x): return F.interpolate(x, size=self.size, scale_factor=self.scale_factor) -@experimental() +@to_review() def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding.""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) -@experimental() +@to_review() def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution.""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) -@experimental() +@to_review() def resize_conv3x3(in_planes, out_planes, scale=1): """upsample + 3x3 convolution with padding to avoid checkerboard artifact.""" if scale == 1: @@ -37,7 +37,7 @@ def resize_conv3x3(in_planes, out_planes, scale=1): return nn.Sequential(Interpolate(scale_factor=scale), conv3x3(in_planes, out_planes)) -@experimental() +@to_review() def resize_conv1x1(in_planes, out_planes, scale=1): """upsample + 1x1 convolution with padding to avoid checkerboard artifact.""" if scale == 1: @@ -45,7 +45,7 @@ def resize_conv1x1(in_planes, out_planes, scale=1): return nn.Sequential(Interpolate(scale_factor=scale), conv1x1(in_planes, out_planes)) -@experimental() +@to_review() class EncoderBlock(nn.Module): """ResNet block, copied from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L35.""" @@ -79,7 +79,7 @@ def forward(self, x): return out -@experimental() +@to_review() class EncoderBottleneck(nn.Module): """ResNet bottleneck, copied from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L75.""" @@ -121,7 +121,7 @@ def forward(self, x): return out -@experimental() +@to_review() class DecoderBlock(nn.Module): """ResNet block, but convs replaced with resize convs, and channel increase is in second conv, not first.""" @@ -155,7 +155,7 @@ def forward(self, x): return out -@experimental() +@to_review() class DecoderBottleneck(nn.Module): """ResNet bottleneck, but convs replaced with resize convs.""" @@ -196,7 +196,7 @@ def forward(self, x): return out -@experimental() +@to_review() class ResNetEncoder(nn.Module): def __init__(self, block, layers, first_conv=False, maxpool1=False): super().__init__() @@ -256,7 +256,7 @@ def forward(self, x): return x -@experimental() +@to_review() class ResNetDecoder(nn.Module): """Resnet in reverse order.""" @@ -329,21 +329,21 @@ def forward(self, x): return x -@experimental() +@to_review() def resnet18_encoder(first_conv, maxpool1): return ResNetEncoder(EncoderBlock, [2, 2, 2, 2], first_conv, maxpool1) -@experimental() +@to_review() def resnet18_decoder(latent_dim, input_height, first_conv, maxpool1): return ResNetDecoder(DecoderBlock, [2, 2, 2, 2], latent_dim, input_height, first_conv, maxpool1) -@experimental() +@to_review() def resnet50_encoder(first_conv, maxpool1): return ResNetEncoder(EncoderBottleneck, [3, 4, 6, 3], first_conv, maxpool1) -@experimental() +@to_review() def resnet50_decoder(latent_dim, input_height, first_conv, maxpool1): return ResNetDecoder(DecoderBottleneck, [3, 4, 6, 3], latent_dim, input_height, first_conv, maxpool1) diff --git a/pl_bolts/models/detection/components/torchvision_backbones.py b/pl_bolts/models/detection/components/torchvision_backbones.py index 98fd2b2bd8..fc43392884 100644 --- a/pl_bolts/models/detection/components/torchvision_backbones.py +++ b/pl_bolts/models/detection/components/torchvision_backbones.py @@ -4,11 +4,11 @@ from pl_bolts.models.detection.components._supported_models import TORCHVISION_MODEL_ZOO from pl_bolts.utils import _TORCHVISION_AVAILABLE # noqa: F401 -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg # noqa: F401 -@experimental() +@to_review() def _create_backbone_generic(model: nn.Module, out_channels: int) -> nn.Module: """Generic Backbone creater. It removes the last linear layer. @@ -25,7 +25,7 @@ def _create_backbone_generic(model: nn.Module, out_channels: int) -> nn.Module: # Use this when you have Adaptive Pooling layer in End. # When Model.features is not applicable. -@experimental() +@to_review() def _create_backbone_adaptive(model: nn.Module, out_channels: Optional[int] = None) -> nn.Module: """Creates backbone by removing linear after Adaptive Pooling layer. @@ -39,7 +39,7 @@ def _create_backbone_adaptive(model: nn.Module, out_channels: Optional[int] = No return _create_backbone_generic(model, out_channels=out_channels) -@experimental() +@to_review() def _create_backbone_features(model: nn.Module, out_channels: int) -> nn.Module: """Creates backbone from feature sequential block. @@ -52,7 +52,7 @@ def _create_backbone_features(model: nn.Module, out_channels: int) -> nn.Module: return ft_backbone -@experimental() +@to_review() def create_torchvision_backbone(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: """Creates CNN backbone from Torchvision. diff --git a/pl_bolts/models/detection/faster_rcnn/backbones.py b/pl_bolts/models/detection/faster_rcnn/backbones.py index 49254b81f2..d51085d037 100644 --- a/pl_bolts/models/detection/faster_rcnn/backbones.py +++ b/pl_bolts/models/detection/faster_rcnn/backbones.py @@ -4,7 +4,7 @@ from pl_bolts.models.detection.components import create_torchvision_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -13,7 +13,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() def create_fasterrcnn_backbone( backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any ) -> nn.Module: diff --git a/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py b/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py index ea49a1369e..d98c55cfd1 100644 --- a/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py +++ b/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py @@ -6,7 +6,7 @@ from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -17,7 +17,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() def _evaluate_iou(target, pred): """Evaluate intersection over union (IOU) for target from dataset and output prediction from model.""" if not _TORCHVISION_AVAILABLE: # pragma: no cover @@ -29,7 +29,7 @@ def _evaluate_iou(target, pred): return box_iou(target["boxes"], pred["boxes"]).diag().mean() -@experimental() +@to_review() class FasterRCNN(LightningModule): """PyTorch Lightning implementation of `Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks `_. @@ -155,7 +155,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def run_cli(): from pl_bolts.datamodules import VOCDetectionDataModule diff --git a/pl_bolts/models/detection/retinanet/backbones.py b/pl_bolts/models/detection/retinanet/backbones.py index ee761f3710..d9dca39c8d 100644 --- a/pl_bolts/models/detection/retinanet/backbones.py +++ b/pl_bolts/models/detection/retinanet/backbones.py @@ -4,7 +4,7 @@ from pl_bolts.models.detection.components import create_torchvision_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -13,7 +13,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() def create_retinanet_backbone( backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any ) -> nn.Module: diff --git a/pl_bolts/models/detection/retinanet/retinanet_module.py b/pl_bolts/models/detection/retinanet/retinanet_module.py index 18f0c24188..caf2432eae 100644 --- a/pl_bolts/models/detection/retinanet/retinanet_module.py +++ b/pl_bolts/models/detection/retinanet/retinanet_module.py @@ -5,7 +5,7 @@ from pl_bolts.models.detection.retinanet import create_retinanet_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -16,7 +16,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class RetinaNet(LightningModule): """PyTorch Lightning implementation of RetinaNet. @@ -120,7 +120,7 @@ def configure_optimizers(self): ) -@experimental() +@to_review() def cli_main(): from pytorch_lightning.utilities.cli import LightningCLI diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 5fb12ae16b..25548e5c34 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -6,10 +6,10 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pl_bolts.models.detection.yolo import yolo_layers -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class YOLOConfiguration: """This class can be used to parse the configuration files of the Darknet YOLOv4 implementation. @@ -149,7 +149,7 @@ def convert(key, value): return sections -@experimental() +@to_review() def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]: """Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the layer config. @@ -173,7 +173,7 @@ def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]: return create_func[config["type"]](config, num_inputs) -@experimental() +@to_review() def _create_convolutional(config, num_inputs): module = nn.Sequential() @@ -210,14 +210,14 @@ def _create_convolutional(config, num_inputs): return module, config["filters"] -@experimental() +@to_review() def _create_maxpool(config, num_inputs): padding = (config["size"] - 1) // 2 module = nn.MaxPool2d(config["size"], config["stride"], padding) return module, num_inputs[-1] -@experimental() +@to_review() def _create_route(config, num_inputs): num_chunks = config.get("groups", 1) chunk_idx = config.get("group_id", 0) @@ -234,19 +234,19 @@ def _create_route(config, num_inputs): return module, num_outputs -@experimental() +@to_review() def _create_shortcut(config, num_inputs): module = yolo_layers.ShortcutLayer(config["from"]) return module, num_inputs[-1] -@experimental() +@to_review() def _create_upsample(config, num_inputs): module = nn.Upsample(scale_factor=config["stride"], mode="nearest") return module, num_inputs[-1] -@experimental() +@to_review() def _create_yolo(config, num_inputs): # The "anchors" list alternates width and height. anchor_dims = config["anchors"] diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index be207d7fb6..41f9703683 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -5,7 +5,7 @@ from torch import Tensor, nn from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -21,7 +21,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor: """Converts box center points and sizes to corner coordinates. @@ -38,7 +38,7 @@ def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor: return torch.cat((top_left, bottom_right), -1) -@experimental() +@to_review() def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: """Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at the same coordinates. @@ -61,7 +61,7 @@ def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: return inter / union -@experimental() +@to_review() class SELoss(nn.MSELoss): def __init__(self): super().__init__(reduction="none") @@ -70,13 +70,13 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return super().forward(inputs, target).sum(1) -@experimental() +@to_review() class IoULoss(nn.Module): def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return 1.0 - box_iou(inputs, target).diagonal() -@experimental() +@to_review() class GIoULoss(nn.Module): def __init__(self) -> None: super().__init__() @@ -89,7 +89,7 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return 1.0 - generalized_box_iou(inputs, target).diagonal() -@experimental() +@to_review() class DetectionLayer(nn.Module): """A YOLO detection layer. @@ -468,7 +468,7 @@ def _calculate_losses( return losses, hits -@experimental() +@to_review() class Mish(nn.Module): """Mish activation.""" @@ -476,7 +476,7 @@ def forward(self, x): return x * torch.tanh(nn.functional.softplus(x)) -@experimental() +@to_review() class RouteLayer(nn.Module): """Route layer concatenates the output (or part of it) from given layers.""" @@ -497,7 +497,7 @@ def forward(self, x, outputs): return torch.cat(chunks, dim=1) -@experimental() +@to_review() class ShortcutLayer(nn.Module): """Shortcut layer adds a residual connection from the source layer.""" diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 3ab036e814..e742e2c5c6 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -11,7 +11,7 @@ from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -23,7 +23,7 @@ log = logging.getLogger(__name__) -@experimental() +@to_review() class YOLO(LightningModule): """PyTorch Lightning implementation of YOLOv3 and YOLOv4. @@ -455,7 +455,7 @@ def _filter_detections(self, detections: Dict[str, Tensor]) -> Dict[str, List[Te return {"boxes": out_boxes, "scores": out_scores, "classprobs": out_classprobs, "labels": out_labels} -@experimental() +@to_review() class Resize: """Rescales the image and target to given dimensions. @@ -486,7 +486,7 @@ def __call__(self, image: Tensor, target: Dict[str, Any]): return image, target -@experimental() +@to_review() def run_cli(): from argparse import ArgumentParser diff --git a/pl_bolts/models/gans/basic/basic_gan_module.py b/pl_bolts/models/gans/basic/basic_gan_module.py index 55a29b0f43..d217c20249 100644 --- a/pl_bolts/models/gans/basic/basic_gan_module.py +++ b/pl_bolts/models/gans/basic/basic_gan_module.py @@ -5,10 +5,10 @@ from torch.nn import functional as F from pl_bolts.models.gans.basic.components import Discriminator, Generator -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class GAN(LightningModule): """Vanilla GAN implementation. @@ -166,7 +166,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def cli_main(args=None): from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, MNISTDataModule, STL10DataModule diff --git a/pl_bolts/models/gans/basic/components.py b/pl_bolts/models/gans/basic/components.py index f63af4079f..1648e78667 100644 --- a/pl_bolts/models/gans/basic/components.py +++ b/pl_bolts/models/gans/basic/components.py @@ -3,10 +3,10 @@ from torch import nn from torch.nn import functional as F -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class Generator(nn.Module): def __init__(self, latent_dim, img_shape, hidden_dim=256): super().__init__() @@ -27,7 +27,7 @@ def forward(self, z): return img -@experimental() +@to_review() class Discriminator(nn.Module): def __init__(self, img_shape, hidden_dim=1024): super().__init__() diff --git a/pl_bolts/models/gans/dcgan/components.py b/pl_bolts/models/gans/dcgan/components.py index 3f3facc9e5..55055bf498 100644 --- a/pl_bolts/models/gans/dcgan/components.py +++ b/pl_bolts/models/gans/dcgan/components.py @@ -1,10 +1,10 @@ # Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py from torch import Tensor, nn -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class DCGANGenerator(nn.Module): def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None: """ @@ -50,7 +50,7 @@ def forward(self, noise: Tensor) -> Tensor: return self.gen(noise) -@experimental() +@to_review() class DCGANDiscriminator(nn.Module): def __init__(self, feature_maps: int, image_channels: int) -> None: """ diff --git a/pl_bolts/models/gans/dcgan/dcgan_module.py b/pl_bolts/models/gans/dcgan/dcgan_module.py index d9a54c2d60..1ca3522dda 100644 --- a/pl_bolts/models/gans/dcgan/dcgan_module.py +++ b/pl_bolts/models/gans/dcgan/dcgan_module.py @@ -9,7 +9,7 @@ from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -19,7 +19,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class DCGAN(LightningModule): """DCGAN implementation. @@ -173,7 +173,7 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: return parser -@experimental() +@to_review() def cli_main(args=None): seed_everything(1234) diff --git a/pl_bolts/models/gans/pix2pix/components.py b/pl_bolts/models/gans/pix2pix/components.py index d865d06576..088ef31912 100644 --- a/pl_bolts/models/gans/pix2pix/components.py +++ b/pl_bolts/models/gans/pix2pix/components.py @@ -1,10 +1,10 @@ import torch from torch import nn -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class UpSampleConv(nn.Module): def __init__( self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True, dropout=False @@ -35,7 +35,7 @@ def forward(self, x): return x -@experimental() +@to_review() class DownSampleConv(nn.Module): def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True): """Paper details: @@ -65,7 +65,7 @@ def forward(self, x): return x -@experimental() +@to_review() class Generator(nn.Module): def __init__(self, in_channels, out_channels): """Paper details: @@ -127,7 +127,7 @@ def forward(self, x): return self.tanh(x) -@experimental() +@to_review() class PatchGAN(nn.Module): def __init__(self, input_channels): super().__init__() diff --git a/pl_bolts/models/gans/pix2pix/pix2pix_module.py b/pl_bolts/models/gans/pix2pix/pix2pix_module.py index 3aeb5b55d4..97da62e1cc 100644 --- a/pl_bolts/models/gans/pix2pix/pix2pix_module.py +++ b/pl_bolts/models/gans/pix2pix/pix2pix_module.py @@ -3,10 +3,10 @@ from torch import nn from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def _weights_init(m): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): torch.nn.init.normal_(m.weight, 0.0, 0.02) @@ -15,7 +15,7 @@ def _weights_init(m): torch.nn.init.constant_(m.bias, 0) -@experimental() +@to_review() class Pix2Pix(LightningModule): def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=200): diff --git a/pl_bolts/models/gans/srgan/components.py b/pl_bolts/models/gans/srgan/components.py index 3ea0fa60f2..0f17dcb375 100644 --- a/pl_bolts/models/gans/srgan/components.py +++ b/pl_bolts/models/gans/srgan/components.py @@ -3,7 +3,7 @@ import torch.nn as nn from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -12,7 +12,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class ResidualBlock(nn.Module): def __init__(self, feature_maps: int = 64) -> None: super().__init__() @@ -29,7 +29,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + self.block(x) -@experimental() +@to_review() class SRGANGenerator(nn.Module): def __init__( self, @@ -81,7 +81,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -@experimental() +@to_review() class SRGANDiscriminator(nn.Module): def __init__(self, image_channels: int, feature_maps: int = 64) -> None: super().__init__() @@ -135,7 +135,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -@experimental() +@to_review() class VGG19FeatureExtractor(nn.Module): def __init__(self, image_channels: int = 3) -> None: super().__init__() diff --git a/pl_bolts/models/gans/srgan/srgan_module.py b/pl_bolts/models/gans/srgan/srgan_module.py index 3847ddab2a..afab580994 100644 --- a/pl_bolts/models/gans/srgan/srgan_module.py +++ b/pl_bolts/models/gans/srgan/srgan_module.py @@ -12,10 +12,10 @@ from pl_bolts.datamodules import TVTDataModule from pl_bolts.datasets.utils import prepare_sr_datasets from pl_bolts.models.gans.srgan.components import SRGANDiscriminator, SRGANGenerator, VGG19FeatureExtractor -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class SRGAN(pl.LightningModule): """SRGAN implementation from the paper `Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network `__. It uses a pretrained SRResNet model as the generator @@ -183,7 +183,7 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: return parser -@experimental() +@to_review() def cli_main(args=None): pl.seed_everything(1234) diff --git a/pl_bolts/models/gans/srgan/srresnet_module.py b/pl_bolts/models/gans/srgan/srresnet_module.py index 49d298c4e4..78a61c8767 100644 --- a/pl_bolts/models/gans/srgan/srresnet_module.py +++ b/pl_bolts/models/gans/srgan/srresnet_module.py @@ -10,10 +10,10 @@ from pl_bolts.datamodules import TVTDataModule from pl_bolts.datasets.utils import prepare_sr_datasets from pl_bolts.models.gans.srgan.components import SRGANGenerator -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class SRResNet(pl.LightningModule): """SRResNet implementation from the paper `Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network `__. A pretrained SRResNet model is used as the generator @@ -110,7 +110,7 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: return parser -@experimental() +@to_review() def cli_main(args=None): pl.seed_everything(1234) diff --git a/pl_bolts/models/mnist_module.py b/pl_bolts/models/mnist_module.py index 189e4b2309..457b30744d 100644 --- a/pl_bolts/models/mnist_module.py +++ b/pl_bolts/models/mnist_module.py @@ -7,7 +7,7 @@ from pl_bolts.datasets import MNIST from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -16,7 +16,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class LitMNIST(LightningModule): def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, data_dir="", **kwargs): if not _TORCHVISION_AVAILABLE: # pragma: no cover @@ -90,7 +90,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def cli_main(): # args parser = ArgumentParser() diff --git a/pl_bolts/models/regression/linear_regression.py b/pl_bolts/models/regression/linear_regression.py index 816cb254b2..9f5546742d 100644 --- a/pl_bolts/models/regression/linear_regression.py +++ b/pl_bolts/models/regression/linear_regression.py @@ -8,10 +8,10 @@ from torch.optim import Adam from torch.optim.optimizer import Optimizer -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class LinearRegression(LightningModule): """ Linear regression model implementing - with optional L1/L2 regularization @@ -112,7 +112,7 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: return parser -@experimental() +@to_review() def cli_main() -> None: from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule from pl_bolts.utils import _SKLEARN_AVAILABLE diff --git a/pl_bolts/models/regression/logistic_regression.py b/pl_bolts/models/regression/logistic_regression.py index 34fe809737..41a09a92d7 100644 --- a/pl_bolts/models/regression/logistic_regression.py +++ b/pl_bolts/models/regression/logistic_regression.py @@ -10,10 +10,10 @@ from torch.optim.optimizer import Optimizer from torchmetrics.functional import accuracy -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class LogisticRegression(LightningModule): """Logistic regression model.""" @@ -118,7 +118,7 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: return parser -@experimental() +@to_review() def cli_main() -> None: from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule from pl_bolts.utils import _SKLEARN_AVAILABLE diff --git a/pl_bolts/models/rl/advantage_actor_critic_model.py b/pl_bolts/models/rl/advantage_actor_critic_model.py index 23323d03ee..e3b32206fc 100644 --- a/pl_bolts/models/rl/advantage_actor_critic_model.py +++ b/pl_bolts/models/rl/advantage_actor_critic_model.py @@ -15,7 +15,7 @@ from pl_bolts.models.rl.common.agents import ActorCriticAgent from pl_bolts.models.rl.common.networks import ActorCriticMLP from pl_bolts.utils import _GYM_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -24,7 +24,7 @@ warn_missing_pkg("gym") -@experimental() +@to_review() class AdvantageActorCritic(LightningModule): """PyTorch Lightning implementation of `Advantage Actor Critic `_. @@ -296,7 +296,7 @@ def add_model_specific_args(arg_parser: ArgumentParser) -> ArgumentParser: return arg_parser -@experimental() +@to_review() def cli_main() -> None: parser = ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/common/agents.py b/pl_bolts/models/rl/common/agents.py index c3b551bbc5..6f29837556 100644 --- a/pl_bolts/models/rl/common/agents.py +++ b/pl_bolts/models/rl/common/agents.py @@ -10,10 +10,10 @@ from torch import Tensor, nn from torch.nn import functional as F -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class Agent(ABC): """Basic agent that always returns 0.""" @@ -33,7 +33,7 @@ def __call__(self, state: Tensor, device: str, *args, **kwargs) -> List[int]: return [0] -@experimental() +@to_review() class ValueAgent(Agent): """Value based agent that returns an action based on the Q values from the network.""" @@ -109,7 +109,7 @@ def update_epsilon(self, step: int) -> None: self.epsilon = max(self.eps_end, self.eps_start - (step + 1) / self.eps_frames) -@experimental() +@to_review() class PolicyAgent(Agent): """Policy based agent that returns an action based on the networks policy.""" @@ -139,7 +139,7 @@ def __call__(self, states: Tensor, device: str) -> List[int]: return actions -@experimental() +@to_review() class ActorCriticAgent(Agent): """Actor-Critic based agent that returns an action based on the networks policy.""" @@ -169,7 +169,7 @@ def __call__(self, states: Tensor, device: str) -> List[int]: return actions -@experimental() +@to_review() class SoftActorCriticAgent(Agent): """Actor-Critic based agent that returns a continuous action based on the policy.""" diff --git a/pl_bolts/models/rl/common/cli.py b/pl_bolts/models/rl/common/cli.py index e7ab729bdb..2b8b5d3e20 100644 --- a/pl_bolts/models/rl/common/cli.py +++ b/pl_bolts/models/rl/common/cli.py @@ -2,10 +2,10 @@ import argparse -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def add_base_args(parent) -> argparse.ArgumentParser: """Adds arguments for DQN model. diff --git a/pl_bolts/models/rl/common/distributions.py b/pl_bolts/models/rl/common/distributions.py index 759f7e235d..fb348e179d 100644 --- a/pl_bolts/models/rl/common/distributions.py +++ b/pl_bolts/models/rl/common/distributions.py @@ -1,10 +1,10 @@ """Distributions used in some continuous RL algorithms.""" import torch -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class TanhMultivariateNormal(torch.distributions.MultivariateNormal): """The distribution of X is an affine of tanh applied on a normal distribution. diff --git a/pl_bolts/models/rl/common/gym_wrappers.py b/pl_bolts/models/rl/common/gym_wrappers.py index 8eb80e701d..5e526e6eec 100644 --- a/pl_bolts/models/rl/common/gym_wrappers.py +++ b/pl_bolts/models/rl/common/gym_wrappers.py @@ -6,7 +6,7 @@ import torch from pl_bolts.utils import _GYM_AVAILABLE, _OPENCV_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -24,7 +24,7 @@ warn_missing_pkg("cv2", pypi_name="opencv-python") -@experimental() +@to_review() class ToTensor(Wrapper): """For environments where the user need to press FIRE for the game to start.""" @@ -44,7 +44,7 @@ def reset(self): return torch.tensor(self.env.reset()) -@experimental() +@to_review() class FireResetEnv(Wrapper): """For environments where the user need to press FIRE for the game to start.""" @@ -72,7 +72,7 @@ def reset(self): return obs -@experimental() +@to_review() class MaxAndSkipEnv(Wrapper): """Return only every `skip`-th frame.""" @@ -109,7 +109,7 @@ def reset(self): return obs -@experimental() +@to_review() class ProcessFrame84(ObservationWrapper): """preprocessing images from env.""" @@ -140,7 +140,7 @@ def process(frame): return x_t.astype(np.uint8) -@experimental() +@to_review() class ImageToPyTorch(ObservationWrapper): """converts image to pytorch format.""" @@ -159,7 +159,7 @@ def observation(observation): return np.moveaxis(observation, 2, 0) -@experimental() +@to_review() class ScaledFloatFrame(ObservationWrapper): """scales the pixels.""" @@ -168,7 +168,7 @@ def observation(obs): return np.array(obs).astype(np.float32) / 255.0 -@experimental() +@to_review() class BufferWrapper(ObservationWrapper): """Wrapper for image stacking.""" @@ -195,7 +195,7 @@ def observation(self, observation): return self.buffer -@experimental() +@to_review() class DataAugmentation(ObservationWrapper): """Carries out basic data augmentation on the env observations. @@ -216,7 +216,7 @@ def observation(self, obs): return ProcessFrame84.process(obs) -@experimental() +@to_review() def make_environment(env_name): """Convert environment with wrappers.""" env = gym_make(env_name) diff --git a/pl_bolts/models/rl/common/memory.py b/pl_bolts/models/rl/common/memory.py index 611c13fb65..2b2d932da7 100644 --- a/pl_bolts/models/rl/common/memory.py +++ b/pl_bolts/models/rl/common/memory.py @@ -7,12 +7,12 @@ import numpy as np -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review Experience = namedtuple("Experience", field_names=["state", "action", "reward", "done", "new_state"]) -@experimental() +@to_review() class Buffer: """Basic Buffer for storing a single experience at a time.""" @@ -54,7 +54,7 @@ def sample(self, *args) -> Union[Tuple, List[Tuple]]: ) -@experimental() +@to_review() class ReplayBuffer(Buffer): """Replay Buffer for storing past experiences allowing the agent to learn from them.""" @@ -80,7 +80,7 @@ def sample(self, batch_size: int) -> Tuple: ) -@experimental() +@to_review() class MultiStepBuffer(ReplayBuffer): """N Step Replay Buffer.""" @@ -189,7 +189,7 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float: return total_reward -@experimental() +@to_review() class MeanBuffer: """Stores a deque of items and calculates the mean.""" @@ -212,7 +212,7 @@ def mean(self) -> float: return self.sum / len(self.deque) -@experimental() +@to_review() class PERBuffer(ReplayBuffer): """simple list based Prioritized Experience Replay Buffer Based on implementation found here: diff --git a/pl_bolts/models/rl/common/networks.py b/pl_bolts/models/rl/common/networks.py index 533a34e162..42de384e54 100644 --- a/pl_bolts/models/rl/common/networks.py +++ b/pl_bolts/models/rl/common/networks.py @@ -9,10 +9,10 @@ from torch.nn import functional as F from pl_bolts.models.rl.common.distributions import TanhMultivariateNormal -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class CNN(nn.Module): """Simple MLP network.""" @@ -59,7 +59,7 @@ def forward(self, input_x) -> Tensor: return self.head(conv_out) -@experimental() +@to_review() class MLP(nn.Module): """Simple MLP network.""" @@ -89,7 +89,7 @@ def forward(self, input_x): return self.net(input_x.float()) -@experimental() +@to_review() class ContinuousMLP(nn.Module): """MLP network that outputs continuous value via Gaussian distribution.""" @@ -148,7 +148,7 @@ def get_action(self, x: FloatTensor) -> Tensor: return self.action_scale * torch.tanh(batch_mean) + self.action_bias -@experimental() +@to_review() class ActorCriticMLP(nn.Module): """MLP network with heads for actor and critic.""" @@ -180,7 +180,7 @@ def forward(self, x) -> Tuple[Tensor, Tensor]: return a, c -@experimental() +@to_review() class DuelingMLP(nn.Module): """MLP network with duel heads for val and advantage.""" @@ -233,7 +233,7 @@ def adv_val(self, input_x) -> Tuple[Tensor, Tensor]: return self.fc_adv(base_out), self.fc_val(base_out) -@experimental() +@to_review() class DuelingCNN(nn.Module): """CNN network with duel heads for val and advantage.""" @@ -302,7 +302,7 @@ def adv_val(self, input_x): return self.head_adv(base_out), self.head_val(base_out) -@experimental() +@to_review() class NoisyCNN(nn.Module): """CNN with Noisy Linear layers for exploration.""" @@ -356,7 +356,7 @@ def forward(self, input_x) -> Tensor: ################### -@experimental() +@to_review() class NoisyLinear(nn.Linear): """Noisy Layer using Independent Gaussian Noise. @@ -413,7 +413,7 @@ def forward(self, input_x: Tensor) -> Tensor: return F.linear(input_x, noisy_weights, bias) -@experimental() +@to_review() class ActorCategorical(nn.Module): """Policy network, for discrete action spaces, which returns a distribution and an action given an observation.""" @@ -447,7 +447,7 @@ def get_log_prob(self, pi: Categorical, actions: Tensor): return pi.log_prob(actions) -@experimental() +@to_review() class ActorContinous(nn.Module): """Policy network, for continous action spaces, which returns a distribution and an action given an observation.""" diff --git a/pl_bolts/models/rl/double_dqn_model.py b/pl_bolts/models/rl/double_dqn_model.py index ad0f53b0bc..5aa2c8b745 100644 --- a/pl_bolts/models/rl/double_dqn_model.py +++ b/pl_bolts/models/rl/double_dqn_model.py @@ -8,10 +8,10 @@ from pl_bolts.losses.rl import double_dqn_loss from pl_bolts.models.rl.dqn_model import DQN -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class DoubleDQN(DQN): """Double Deep Q-network (DDQN) PyTorch Lightning implementation of `Double DQN`_. @@ -81,7 +81,7 @@ def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict: ) -@experimental() +@to_review() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py index ce73ba9ccb..9f62c74214 100644 --- a/pl_bolts/models/rl/dqn_model.py +++ b/pl_bolts/models/rl/dqn_model.py @@ -19,7 +19,7 @@ from pl_bolts.models.rl.common.memory import MultiStepBuffer from pl_bolts.models.rl.common.networks import CNN from pl_bolts.utils import _GYM_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -29,7 +29,7 @@ Env = object -@experimental() +@to_review() class DQN(LightningModule): """Basic DQN Model. @@ -410,7 +410,7 @@ def _use_dp_or_ddp2(trainer: Trainer) -> bool: return isinstance(trainer.training_type_plugin, (DataParallelPlugin, DDP2Plugin)) -@experimental() +@to_review() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/dueling_dqn_model.py b/pl_bolts/models/rl/dueling_dqn_model.py index 018ba8d83f..1485c54f82 100644 --- a/pl_bolts/models/rl/dueling_dqn_model.py +++ b/pl_bolts/models/rl/dueling_dqn_model.py @@ -5,10 +5,10 @@ from pl_bolts.models.rl.common.networks import DuelingCNN from pl_bolts.models.rl.dqn_model import DQN -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class DuelingDQN(DQN): """PyTorch Lightning implementation of `Dueling DQN `_ @@ -38,7 +38,7 @@ def build_networks(self) -> None: self.target_net = DuelingCNN(self.obs_shape, self.n_actions) -@experimental() +@to_review() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/noisy_dqn_model.py b/pl_bolts/models/rl/noisy_dqn_model.py index b4aa5b6aed..88785dc5ff 100644 --- a/pl_bolts/models/rl/noisy_dqn_model.py +++ b/pl_bolts/models/rl/noisy_dqn_model.py @@ -9,10 +9,10 @@ from pl_bolts.datamodules.experience_source import Experience from pl_bolts.models.rl.common.networks import NoisyCNN from pl_bolts.models.rl.dqn_model import DQN -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class NoisyDQN(DQN): """PyTorch Lightning implementation of `Noisy DQN `_ @@ -91,7 +91,7 @@ def train_batch( break -@experimental() +@to_review() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/per_dqn_model.py b/pl_bolts/models/rl/per_dqn_model.py index 34db230eb7..7fc09701f4 100644 --- a/pl_bolts/models/rl/per_dqn_model.py +++ b/pl_bolts/models/rl/per_dqn_model.py @@ -12,10 +12,10 @@ from pl_bolts.losses.rl import per_dqn_loss from pl_bolts.models.rl.common.memory import Experience, PERBuffer from pl_bolts.models.rl.dqn_model import DQN -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class PERDQN(DQN): """PyTorch Lightning implementation of `DQN With Prioritized Experience Replay`_. @@ -147,7 +147,7 @@ def _dataloader(self) -> DataLoader: return DataLoader(dataset=self.dataset, batch_size=self.batch_size) -@experimental() +@to_review() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/ppo_model.py b/pl_bolts/models/rl/ppo_model.py index eb71903a50..dc0e7ce773 100644 --- a/pl_bolts/models/rl/ppo_model.py +++ b/pl_bolts/models/rl/ppo_model.py @@ -10,7 +10,7 @@ from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.models.rl.common.networks import MLP, ActorCategorical, ActorContinous from pl_bolts.utils import _GYM_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -19,7 +19,7 @@ warn_missing_pkg("gym") -@experimental() +@to_review() class PPO(LightningModule): """PyTorch Lightning implementation of `Proximal Policy Optimization. @@ -358,7 +358,7 @@ def add_model_specific_args(parent_parser): # pragma: no cover return parser -@experimental() +@to_review() def cli_main() -> None: parent_parser = argparse.ArgumentParser(add_help=False) parent_parser = Trainer.add_argparse_args(parent_parser) diff --git a/pl_bolts/models/rl/reinforce_model.py b/pl_bolts/models/rl/reinforce_model.py index c288ad3ee8..ae188d6786 100644 --- a/pl_bolts/models/rl/reinforce_model.py +++ b/pl_bolts/models/rl/reinforce_model.py @@ -15,7 +15,7 @@ from pl_bolts.models.rl.common.agents import PolicyAgent from pl_bolts.models.rl.common.networks import MLP from pl_bolts.utils import _GYM_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -24,7 +24,7 @@ warn_missing_pkg("gym") -@experimental() +@to_review() class Reinforce(LightningModule): r"""PyTorch Lightning implementation of REINFORCE_. @@ -304,7 +304,7 @@ def add_model_specific_args(arg_parser) -> argparse.ArgumentParser: return arg_parser -@experimental() +@to_review() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/sac_model.py b/pl_bolts/models/rl/sac_model.py index c6d65a4123..cca676e5c7 100644 --- a/pl_bolts/models/rl/sac_model.py +++ b/pl_bolts/models/rl/sac_model.py @@ -16,7 +16,7 @@ from pl_bolts.models.rl.common.memory import MultiStepBuffer from pl_bolts.models.rl.common.networks import MLP, ContinuousMLP from pl_bolts.utils import _GYM_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -26,7 +26,7 @@ Env = object -@experimental() +@to_review() class SAC(LightningModule): def __init__( self, @@ -386,7 +386,7 @@ def add_model_specific_args( return arg_parser -@experimental() +@to_review() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/vanilla_policy_gradient_model.py b/pl_bolts/models/rl/vanilla_policy_gradient_model.py index 68a01acfa2..a450f12714 100644 --- a/pl_bolts/models/rl/vanilla_policy_gradient_model.py +++ b/pl_bolts/models/rl/vanilla_policy_gradient_model.py @@ -15,7 +15,7 @@ from pl_bolts.models.rl.common.agents import PolicyAgent from pl_bolts.models.rl.common.networks import MLP from pl_bolts.utils import _GYM_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -24,7 +24,7 @@ warn_missing_pkg("gym") -@experimental() +@to_review() class VanillaPolicyGradient(LightningModule): r"""PyTorch Lightning implementation of `Vanilla Policy Gradient`_. @@ -287,7 +287,7 @@ def add_model_specific_args(arg_parser) -> argparse.ArgumentParser: return arg_parser -@experimental() +@to_review() def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/self_supervised/amdim/amdim_module.py b/pl_bolts/models/self_supervised/amdim/amdim_module.py index 39dd05785f..d050666cdd 100644 --- a/pl_bolts/models/self_supervised/amdim/amdim_module.py +++ b/pl_bolts/models/self_supervised/amdim/amdim_module.py @@ -11,10 +11,10 @@ from pl_bolts.models.self_supervised.amdim.datasets import AMDIMPretraining from pl_bolts.models.self_supervised.amdim.networks import AMDIMEncoder from pl_bolts.utils.self_supervised import torchvision_ssl_encoder -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def generate_power_seq(lr, nb): half = int(nb / 2) coefs = [2**pow for pow in range(half, -half - 1, -1)] @@ -61,7 +61,7 @@ def generate_power_seq(lr, nb): } -@experimental() +@to_review() class AMDIM(LightningModule): """PyTorch Lightning implementation of Augmented Multiscale Deep InfoMax (AMDIM_) @@ -321,7 +321,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def cli_main(): parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) diff --git a/pl_bolts/models/self_supervised/amdim/datasets.py b/pl_bolts/models/self_supervised/amdim/datasets.py index 9fe8c4b0a5..51d7d94d64 100644 --- a/pl_bolts/models/self_supervised/amdim/datasets.py +++ b/pl_bolts/models/self_supervised/amdim/datasets.py @@ -5,7 +5,7 @@ from pl_bolts.datasets import CIFAR10Mixed, UnlabeledImagenet from pl_bolts.models.self_supervised.amdim import transforms as amdim_transforms from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -14,7 +14,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class AMDIMPretraining: """For pretraining we use the train transform for both train and val.""" @@ -76,7 +76,7 @@ def get_dataset(datamodule: str, data_dir, split: str = "train", **kwargs): return datasets[datamodule](dataset_root=data_dir, split=split, **kwargs) -@experimental() +@to_review() class AMDIMPatchesPretraining: """For pretraining we use the train transform for both train and val.""" diff --git a/pl_bolts/models/self_supervised/amdim/networks.py b/pl_bolts/models/self_supervised/amdim/networks.py index 70a6109019..2510e01a2f 100644 --- a/pl_bolts/models/self_supervised/amdim/networks.py +++ b/pl_bolts/models/self_supervised/amdim/networks.py @@ -5,10 +5,10 @@ from torch import nn from torch.nn import functional as F -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class AMDIMEncoder(nn.Module): def __init__( self, @@ -149,7 +149,7 @@ def forward(self, x): return r1, r5, r7 -@experimental() +@to_review() class Conv3x3(nn.Module): def __init__(self, n_in, n_out, n_kern, n_stride, n_pad, use_bn=True, pad_mode="constant"): super().__init__() @@ -173,7 +173,7 @@ def forward(self, x): return out -@experimental() +@to_review() class ConvResBlock(nn.Module): def __init__(self, n_in, n_out, width, stride, pad, depth, use_bn): super().__init__() @@ -193,7 +193,7 @@ def forward(self, x): return x_out -@experimental() +@to_review() class ConvResNxN(nn.Module): def __init__(self, n_in, n_out, width, stride, pad, use_bn=False): super().__init__() @@ -238,7 +238,7 @@ def forward(self, x): return h23 -@experimental() +@to_review() class MaybeBatchNorm2d(nn.Module): def __init__(self, n_ftr, affine, use_bn): super().__init__() @@ -251,7 +251,7 @@ def forward(self, x): return x -@experimental() +@to_review() class NopNet(nn.Module): def __init__(self, norm_dim=None): super().__init__() @@ -265,7 +265,7 @@ def forward(self, x): return x -@experimental() +@to_review() class FakeRKHSConvNet(nn.Module): def __init__(self, n_input, n_output, use_bn=False): super().__init__() diff --git a/pl_bolts/models/self_supervised/amdim/transforms.py b/pl_bolts/models/self_supervised/amdim/transforms.py index a31c67c606..83ee7681bb 100644 --- a/pl_bolts/models/self_supervised/amdim/transforms.py +++ b/pl_bolts/models/self_supervised/amdim/transforms.py @@ -1,6 +1,6 @@ from pl_bolts.transforms.self_supervised import RandomTranslateWithReflect from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -9,7 +9,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class AMDIMTrainTransformsCIFAR10: """Transforms applied to AMDIM. @@ -54,7 +54,7 @@ def __call__(self, inp): return out1, out2 -@experimental() +@to_review() class AMDIMEvalTransformsCIFAR10: """Transforms applied to AMDIM. @@ -91,7 +91,7 @@ def __call__(self, inp): return out1 -@experimental() +@to_review() class AMDIMTrainTransformsSTL10: """Transforms applied to AMDIM. @@ -132,7 +132,7 @@ def __call__(self, inp): return out1, out2 -@experimental() +@to_review() class AMDIMEvalTransformsSTL10: """Transforms applied to AMDIM. @@ -175,7 +175,7 @@ def __call__(self, inp): return out1 -@experimental() +@to_review() class AMDIMTrainTransformsImageNet128: """Transforms applied to AMDIM. @@ -219,7 +219,7 @@ def __call__(self, inp): return out1, out2 -@experimental() +@to_review() class AMDIMEvalTransformsImageNet128: """Transforms applied to AMDIM. diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index 8e75081a03..c7c8895b30 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -10,10 +10,10 @@ from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate from pl_bolts.models.self_supervised.byol.models import SiameseArm from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class BYOL(LightningModule): """PyTorch Lightning implementation of Bootstrap Your Own Latent (BYOL_)_ @@ -178,7 +178,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/byol/models.py b/pl_bolts/models/self_supervised/byol/models.py index 7bbec2b0b0..2b32da9395 100644 --- a/pl_bolts/models/self_supervised/byol/models.py +++ b/pl_bolts/models/self_supervised/byol/models.py @@ -1,10 +1,10 @@ from torch import nn from pl_bolts.utils.self_supervised import torchvision_ssl_encoder -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class MLP(nn.Module): def __init__(self, input_dim=2048, hidden_size=4096, output_dim=256): super().__init__() @@ -22,7 +22,7 @@ def forward(self, x): return x -@experimental() +@to_review() class SiameseArm(nn.Module): def __init__(self, encoder="resnet50", encoder_out_dim=2048, projector_hidden_size=4096, projector_out_dim=256): super().__init__() diff --git a/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py b/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py index d5de41537c..f3ff458a08 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py @@ -10,10 +10,10 @@ CPCTrainTransformsCIFAR10, CPCTrainTransformsSTL10, ) -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def cli_main(): # pragma: no cover from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/cpc/cpc_module.py b/pl_bolts/models/self_supervised/cpc/cpc_module.py index 4c3759a572..01a9385fc5 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -24,12 +24,12 @@ ) from pl_bolts.utils.pretrained_weights import load_pretrained from pl_bolts.utils.self_supervised import torchvision_ssl_encoder -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review __all__ = ["CPC_v2"] -@experimental() +@to_review() class CPC_v2(LightningModule): def __init__( self, @@ -204,7 +204,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.datamodules import CIFAR10DataModule diff --git a/pl_bolts/models/self_supervised/cpc/networks.py b/pl_bolts/models/self_supervised/cpc/networks.py index 283b701d56..3005021349 100644 --- a/pl_bolts/models/self_supervised/cpc/networks.py +++ b/pl_bolts/models/self_supervised/cpc/networks.py @@ -1,10 +1,10 @@ from torch import nn from torch.nn import functional as F -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class CPCResNet(nn.Module): def __init__( self, @@ -131,17 +131,17 @@ def forward(self, x): return x -@experimental() +@to_review() def cpc_resnet101(sample_batch, **kwargs): return CPCResNet(sample_batch, LNBottleneck, [3, 4, 46, 3], **kwargs) -@experimental() +@to_review() def cpc_resnet50(sample_batch, **kwargs): return CPCResNet(sample_batch, LNBottleneck, [3, 4, 6, 3], **kwargs) -@experimental() +@to_review() class LNBottleneck(nn.Module): def __init__( self, @@ -207,7 +207,7 @@ def forward(self, x): return out -@experimental() +@to_review() def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding.""" return nn.Conv2d( @@ -222,7 +222,7 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): ) -@experimental() +@to_review() def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution.""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) diff --git a/pl_bolts/models/self_supervised/cpc/transforms.py b/pl_bolts/models/self_supervised/cpc/transforms.py index 70e930a787..4caa22fa00 100644 --- a/pl_bolts/models/self_supervised/cpc/transforms.py +++ b/pl_bolts/models/self_supervised/cpc/transforms.py @@ -1,6 +1,6 @@ from pl_bolts.transforms.self_supervised import Patchify, RandomTranslateWithReflect from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -9,7 +9,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class CPCTrainTransformsCIFAR10: """Transforms used for CPC: @@ -71,7 +71,7 @@ def __call__(self, inp): return out1 -@experimental() +@to_review() class CPCEvalTransformsCIFAR10: """Transforms used for CPC: @@ -124,7 +124,7 @@ def __call__(self, inp): return out1 -@experimental() +@to_review() class CPCTrainTransformsSTL10: """Transforms used for CPC: @@ -185,7 +185,7 @@ def __call__(self, inp): return out1 -@experimental() +@to_review() class CPCEvalTransformsSTL10: """Transforms used for CPC: @@ -236,7 +236,7 @@ def __call__(self, inp): return out1 -@experimental() +@to_review() class CPCTrainTransformsImageNet128: """Transforms used for CPC: @@ -290,7 +290,7 @@ def __call__(self, inp): return out1 -@experimental() +@to_review() class CPCEvalTransformsImageNet128: """Transforms used for CPC: diff --git a/pl_bolts/models/self_supervised/evaluator.py b/pl_bolts/models/self_supervised/evaluator.py index 7fe737b5b7..b30837e573 100644 --- a/pl_bolts/models/self_supervised/evaluator.py +++ b/pl_bolts/models/self_supervised/evaluator.py @@ -1,9 +1,9 @@ from torch import nn -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class SSLEvaluator(nn.Module): def __init__(self, n_input, n_classes, n_hidden=512, p=0.1): super().__init__() @@ -30,7 +30,7 @@ def forward(self, x): return logits -@experimental() +@to_review() class Flatten(nn.Module): def __init__(self): super().__init__() diff --git a/pl_bolts/models/self_supervised/moco/callbacks.py b/pl_bolts/models/self_supervised/moco/callbacks.py index 0b799c7709..332257276d 100644 --- a/pl_bolts/models/self_supervised/moco/callbacks.py +++ b/pl_bolts/models/self_supervised/moco/callbacks.py @@ -2,10 +2,10 @@ from pytorch_lightning import Callback -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class MocoLRScheduler(Callback): def __init__(self, initial_lr=0.03, use_cosine_scheduler=False, schedule=(120, 160), max_epochs=200): super().__init__() diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index 0b35b38027..8a0457e515 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -27,7 +27,7 @@ Moco2TrainSTL10Transforms, ) from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -36,7 +36,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() class Moco_v2(LightningModule): """PyTorch Lightning implementation of `Moco `_ @@ -343,7 +343,7 @@ def _use_ddp_or_ddp2(trainer: Trainer) -> bool: # utils @torch.no_grad() -@experimental() +@to_review() def concat_all_gather(tensor): """Performs all_gather operation on the provided tensors. @@ -356,7 +356,7 @@ def concat_all_gather(tensor): return output -@experimental() +@to_review() def cli_main(): from pl_bolts.datamodules import CIFAR10DataModule, SSLImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/moco/transforms.py b/pl_bolts/models/self_supervised/moco/transforms.py index a7aa13a26b..be637e2790 100644 --- a/pl_bolts/models/self_supervised/moco/transforms.py +++ b/pl_bolts/models/self_supervised/moco/transforms.py @@ -6,7 +6,7 @@ stl10_normalization, ) from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -20,7 +20,7 @@ warn_missing_pkg("PIL", pypi_name="Pillow") -@experimental() +@to_review() class Moco2TrainCIFAR10Transforms: """Moco 2 augmentation: @@ -50,7 +50,7 @@ def __call__(self, inp): return q, k -@experimental() +@to_review() class Moco2EvalCIFAR10Transforms: """Moco 2 augmentation: @@ -76,7 +76,7 @@ def __call__(self, inp): return q, k -@experimental() +@to_review() class Moco2TrainSTL10Transforms: """Moco 2 augmentation: @@ -106,7 +106,7 @@ def __call__(self, inp): return q, k -@experimental() +@to_review() class Moco2EvalSTL10Transforms: """Moco 2 augmentation: @@ -132,7 +132,7 @@ def __call__(self, inp): return q, k -@experimental() +@to_review() class Moco2TrainImagenetTransforms: """Moco 2 augmentation: @@ -162,7 +162,7 @@ def __call__(self, inp): return q, k -@experimental() +@to_review() class Moco2EvalImagenetTransforms: """Moco 2 augmentation: @@ -188,7 +188,7 @@ def __call__(self, inp): return q, k -@experimental() +@to_review() class GaussianBlur: """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709.""" diff --git a/pl_bolts/models/self_supervised/resnets.py b/pl_bolts/models/self_supervised/resnets.py index 8da862c84f..9b9c73020d 100644 --- a/pl_bolts/models/self_supervised/resnets.py +++ b/pl_bolts/models/self_supervised/resnets.py @@ -2,7 +2,7 @@ from torch import nn from torch.utils.model_zoo import load_url as load_state_dict_from_url -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review __all__ = [ "ResNet", @@ -30,7 +30,7 @@ } -@experimental() +@to_review() def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding.""" return nn.Conv2d( @@ -45,13 +45,13 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): ) -@experimental() +@to_review() def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution.""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) -@experimental() +@to_review() class BasicBlock(nn.Module): expansion = 1 @@ -93,7 +93,7 @@ def forward(self, x): return out -@experimental() +@to_review() class Bottleneck(nn.Module): expansion = 4 @@ -138,7 +138,7 @@ def forward(self, x): return out -@experimental() +@to_review() class ResNet(nn.Module): def __init__( self, @@ -276,7 +276,7 @@ def forward(self, x): return [x0] -@experimental() +@to_review() def _resnet(arch, block, layers, pretrained, progress, **kwargs): model = ResNet(block, layers, **kwargs) if pretrained: @@ -285,7 +285,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs): return model -@experimental() +@to_review() def resnet18(pretrained: bool = False, progress: bool = True, **kwargs): r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ @@ -297,7 +297,7 @@ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) -@experimental() +@to_review() def resnet34(pretrained=False, progress=True, **kwargs): r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_ @@ -309,7 +309,7 @@ def resnet34(pretrained=False, progress=True, **kwargs): return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) -@experimental() +@to_review() def resnet50(pretrained: bool = False, progress: bool = True, **kwargs): r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ @@ -321,7 +321,7 @@ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) -@experimental() +@to_review() def resnet101(pretrained: bool = False, progress: bool = True, **kwargs): r"""ResNet-101 model from `"Deep Residual Learning for Image Recognition" `_ @@ -333,7 +333,7 @@ def resnet101(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) -@experimental() +@to_review() def resnet152(pretrained: bool = False, progress: bool = True, **kwargs): r"""ResNet-152 model from `"Deep Residual Learning for Image Recognition" `_ @@ -345,7 +345,7 @@ def resnet152(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) -@experimental() +@to_review() def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs): r"""ResNeXt-50 32x4d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ @@ -359,7 +359,7 @@ def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) -@experimental() +@to_review() def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs): r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ @@ -373,7 +373,7 @@ def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) -@experimental() +@to_review() def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs): r"""Wide ResNet-50-2 model from `"Wide Residual Networks" `_ @@ -390,7 +390,7 @@ def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs): return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) -@experimental() +@to_review() def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs): r"""Wide ResNet-101-2 model from `"Wide Residual Networks" `_ diff --git a/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py b/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py index e5e50de55d..6b8bc36c93 100644 --- a/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py +++ b/pl_bolts/models/self_supervised/simclr/simclr_finetuner.py @@ -11,10 +11,10 @@ imagenet_normalization, stl10_normalization, ) -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def cli_main(): # pragma: no cover from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/simclr/simclr_module.py b/pl_bolts/models/self_supervised/simclr/simclr_module.py index 1cb8308cc5..4b3cc5106c 100644 --- a/pl_bolts/models/self_supervised/simclr/simclr_module.py +++ b/pl_bolts/models/self_supervised/simclr/simclr_module.py @@ -15,10 +15,10 @@ imagenet_normalization, stl10_normalization, ) -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class SyncFunction(torch.autograd.Function): @staticmethod def forward(ctx, tensor): @@ -41,7 +41,7 @@ def backward(ctx, grad_output): return grad_input[idx_from:idx_to] -@experimental() +@to_review() class Projection(nn.Module): def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128): super().__init__() @@ -61,7 +61,7 @@ def forward(self, x): return F.normalize(x, dim=1) -@experimental() +@to_review() class SimCLR(LightningModule): def __init__( self, @@ -304,7 +304,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/simclr/transforms.py b/pl_bolts/models/self_supervised/simclr/transforms.py index c74ad6fc82..40f5d304a3 100644 --- a/pl_bolts/models/self_supervised/simclr/transforms.py +++ b/pl_bolts/models/self_supervised/simclr/transforms.py @@ -1,7 +1,7 @@ import numpy as np from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -15,7 +15,7 @@ warn_missing_pkg("cv2", pypi_name="opencv-python") -@experimental() +@to_review() class SimCLRTrainDataTransform: """Transforms for SimCLR. @@ -93,7 +93,7 @@ def __call__(self, sample): return xi, xj, self.online_transform(sample) -@experimental() +@to_review() class SimCLREvalDataTransform(SimCLRTrainDataTransform): """Transforms for SimCLR. @@ -129,7 +129,7 @@ def __init__( ) -@experimental() +@to_review() class SimCLRFinetuneTransform: def __init__( self, input_height: int = 224, jitter_strength: float = 1.0, normalize=None, eval_transform: bool = False @@ -171,7 +171,7 @@ def __call__(self, sample): return self.transform(sample) -@experimental() +@to_review() class GaussianBlur: # Implements Gaussian blur as described in the SimCLR paper def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0): diff --git a/pl_bolts/models/self_supervised/simsiam/models.py b/pl_bolts/models/self_supervised/simsiam/models.py index 4ea855e952..ccf373c9cf 100644 --- a/pl_bolts/models/self_supervised/simsiam/models.py +++ b/pl_bolts/models/self_supervised/simsiam/models.py @@ -3,10 +3,10 @@ from torch import Tensor, nn from pl_bolts.utils.self_supervised import torchvision_ssl_encoder -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class MLP(nn.Module): def __init__(self, input_dim: int = 2048, hidden_size: int = 4096, output_dim: int = 256) -> None: super().__init__() @@ -24,7 +24,7 @@ def forward(self, x: Tensor) -> Tensor: return x -@experimental() +@to_review() class SiameseArm(nn.Module): def __init__( self, diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index 046615e11e..2ae61f3e50 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -14,10 +14,10 @@ imagenet_normalization, stl10_normalization, ) -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class SimSiam(LightningModule): """PyTorch Lightning implementation of Exploring Simple Siamese Representation Learning (SimSiam_) @@ -264,7 +264,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/ssl_finetuner.py b/pl_bolts/models/self_supervised/ssl_finetuner.py index 4a2d9300a5..3e950bf42e 100644 --- a/pl_bolts/models/self_supervised/ssl_finetuner.py +++ b/pl_bolts/models/self_supervised/ssl_finetuner.py @@ -6,10 +6,10 @@ from torchmetrics import Accuracy from pl_bolts.models.self_supervised import SSLEvaluator -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class SSLFineTuner(LightningModule): """Finetunes a self-supervised learning backbone using the standard evaluation protocol of a singler layer MLP with 1024 units. diff --git a/pl_bolts/models/self_supervised/swav/swav_finetuner.py b/pl_bolts/models/self_supervised/swav/swav_finetuner.py index 5fc66da6e6..8ebb8bfb36 100644 --- a/pl_bolts/models/self_supervised/swav/swav_finetuner.py +++ b/pl_bolts/models/self_supervised/swav/swav_finetuner.py @@ -7,10 +7,10 @@ from pl_bolts.models.self_supervised.swav.swav_module import SwAV from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform from pl_bolts.transforms.dataset_normalizations import imagenet_normalization, stl10_normalization -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def cli_main(): # pragma: no cover from pl_bolts.datamodules import ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/swav/swav_module.py b/pl_bolts/models/self_supervised/swav/swav_module.py index 56d5c173d0..cace97b3e7 100644 --- a/pl_bolts/models/self_supervised/swav/swav_module.py +++ b/pl_bolts/models/self_supervised/swav/swav_module.py @@ -17,10 +17,10 @@ imagenet_normalization, stl10_normalization, ) -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class SwAV(LightningModule): def __init__( self, @@ -446,7 +446,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule diff --git a/pl_bolts/models/self_supervised/swav/swav_resnet.py b/pl_bolts/models/self_supervised/swav/swav_resnet.py index de9ae477b2..bc5419c073 100644 --- a/pl_bolts/models/self_supervised/swav/swav_resnet.py +++ b/pl_bolts/models/self_supervised/swav/swav_resnet.py @@ -2,10 +2,10 @@ import torch from torch import nn -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding.""" return nn.Conv2d( @@ -20,13 +20,13 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): ) -@experimental() +@to_review() def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution.""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) -@experimental() +@to_review() class BasicBlock(nn.Module): expansion = 1 __constants__ = ["downsample"] @@ -77,7 +77,7 @@ def forward(self, x): return out -@experimental() +@to_review() class Bottleneck(nn.Module): expansion = 4 __constants__ = ["downsample"] @@ -131,7 +131,7 @@ def forward(self, x): return out -@experimental() +@to_review() class ResNet(nn.Module): def __init__( self, @@ -343,7 +343,7 @@ def forward(self, inputs): return self.forward_head(output) -@experimental() +@to_review() class MultiPrototypes(nn.Module): def __init__(self, output_dim, nmb_prototypes): super().__init__() @@ -358,26 +358,26 @@ def forward(self, x): return out -@experimental() +@to_review() def resnet18(**kwargs): return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) -@experimental() +@to_review() def resnet50(**kwargs): return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) -@experimental() +@to_review() def resnet50w2(**kwargs): return ResNet(Bottleneck, [3, 4, 6, 3], widen=2, **kwargs) -@experimental() +@to_review() def resnet50w4(**kwargs): return ResNet(Bottleneck, [3, 4, 6, 3], widen=4, **kwargs) -@experimental() +@to_review() def resnet50w5(**kwargs): return ResNet(Bottleneck, [3, 4, 6, 3], widen=5, **kwargs) diff --git a/pl_bolts/models/self_supervised/swav/transforms.py b/pl_bolts/models/self_supervised/swav/transforms.py index c5c211f395..7cecc45f05 100644 --- a/pl_bolts/models/self_supervised/swav/transforms.py +++ b/pl_bolts/models/self_supervised/swav/transforms.py @@ -3,7 +3,7 @@ import numpy as np from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -17,7 +17,7 @@ warn_missing_pkg("cv2", pypi_name="opencv-python") -@experimental() +@to_review() class SwAVTrainDataTransform: def __init__( self, @@ -100,7 +100,7 @@ def __call__(self, sample): return multi_crops -@experimental() +@to_review() class SwAVEvalDataTransform(SwAVTrainDataTransform): def __init__( self, @@ -135,7 +135,7 @@ def __init__( self.transform[-1] = test_transform -@experimental() +@to_review() class SwAVFinetuneTransform: def __init__( self, input_height: int = 224, jitter_strength: float = 1.0, normalize=None, eval_transform: bool = False @@ -177,7 +177,7 @@ def __call__(self, sample): return self.transform(sample) -@experimental() +@to_review() class GaussianBlur: # Implements Gaussian blur as described in the SimCLR paper def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0): diff --git a/pl_bolts/models/vision/image_gpt/gpt2.py b/pl_bolts/models/vision/image_gpt/gpt2.py index 4e2709bd47..3b97f33472 100644 --- a/pl_bolts/models/vision/image_gpt/gpt2.py +++ b/pl_bolts/models/vision/image_gpt/gpt2.py @@ -2,10 +2,10 @@ from pytorch_lightning import LightningModule from torch import nn -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class Block(nn.Module): def __init__(self, embed_dim, heads): super().__init__() @@ -30,7 +30,7 @@ def forward(self, x): return x -@experimental() +@to_review() class GPT2(LightningModule): """GPT-2 from `language Models are Unsupervised Multitask Learners `_ diff --git a/pl_bolts/models/vision/image_gpt/igpt_module.py b/pl_bolts/models/vision/image_gpt/igpt_module.py index 8ae07bfff9..a36d5ee146 100644 --- a/pl_bolts/models/vision/image_gpt/igpt_module.py +++ b/pl_bolts/models/vision/image_gpt/igpt_module.py @@ -6,10 +6,10 @@ from torch import nn from pl_bolts.models.vision.image_gpt.gpt2 import GPT2 -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def _shape_input(x): """shape batch of images for input into GPT2 model.""" x = x.view(x.shape[0], -1) # flatten images into sequences @@ -17,7 +17,7 @@ def _shape_input(x): return x -@experimental() +@to_review() class ImageGPT(LightningModule): """ **Paper**: `Generative Pretraining from Pixels @@ -241,7 +241,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def cli_main(): from pl_bolts.datamodules import FashionMNISTDataModule, ImagenetDataModule diff --git a/pl_bolts/models/vision/pixel_cnn.py b/pl_bolts/models/vision/pixel_cnn.py index a93059f009..146d69ae27 100644 --- a/pl_bolts/models/vision/pixel_cnn.py +++ b/pl_bolts/models/vision/pixel_cnn.py @@ -7,10 +7,10 @@ from torch import nn from torch.nn import functional as F -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class PixelCNN(nn.Module): """Implementation of `Pixel CNN `_. diff --git a/pl_bolts/models/vision/segmentation.py b/pl_bolts/models/vision/segmentation.py index 1bd3be8e43..6f72b7d6f9 100644 --- a/pl_bolts/models/vision/segmentation.py +++ b/pl_bolts/models/vision/segmentation.py @@ -5,10 +5,10 @@ from torch.nn import functional as F from pl_bolts.models.vision.unet import UNet -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class SemSegment(LightningModule): def __init__( self, @@ -92,7 +92,7 @@ def add_model_specific_args(parent_parser): return parser -@experimental() +@to_review() def cli_main(): from pl_bolts.datamodules import KittiDataModule diff --git a/pl_bolts/models/vision/unet.py b/pl_bolts/models/vision/unet.py index 5619c8bb3c..3ed88aaa3e 100644 --- a/pl_bolts/models/vision/unet.py +++ b/pl_bolts/models/vision/unet.py @@ -2,10 +2,10 @@ from torch import nn from torch.nn import functional as F -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class UNet(nn.Module): """ Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation @@ -67,7 +67,7 @@ def forward(self, x): return self.layers[-1](xi[-1]) -@experimental() +@to_review() class DoubleConv(nn.Module): """[ Conv2d => BatchNorm (optional) => ReLU ] x 2.""" @@ -86,7 +86,7 @@ def forward(self, x): return self.net(x) -@experimental() +@to_review() class Down(nn.Module): """Downscale with MaxPool => DoubleConvolution block.""" @@ -98,7 +98,7 @@ def forward(self, x): return self.net(x) -@experimental() +@to_review() class Up(nn.Module): """Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature map from contracting path, followed by DoubleConv.""" diff --git a/pl_bolts/optimizers/lars.py b/pl_bolts/optimizers/lars.py index de03c6dcd6..d51a8862f6 100644 --- a/pl_bolts/optimizers/lars.py +++ b/pl_bolts/optimizers/lars.py @@ -6,10 +6,10 @@ import torch from torch.optim.optimizer import Optimizer, required -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class LARS(Optimizer): """Extends SGD in PyTorch with LARS scaling from the paper `Large batch training of Convolutional Networks `_. diff --git a/pl_bolts/optimizers/lr_scheduler.py b/pl_bolts/optimizers/lr_scheduler.py index 5e01bafcd3..61303bb352 100644 --- a/pl_bolts/optimizers/lr_scheduler.py +++ b/pl_bolts/optimizers/lr_scheduler.py @@ -6,10 +6,10 @@ from torch.optim import Adam, Optimizer from torch.optim.lr_scheduler import _LRScheduler -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() class LinearWarmupCosineAnnealingLR(_LRScheduler): """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr and base_lr followed by a cosine annealing schedule between base_lr and eta_min. @@ -124,7 +124,7 @@ def _get_closed_form_lr(self) -> List[float]: # warmup + decay as a function -@experimental() +@to_review() def linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False): """Linear warmup for warmup_steps, optionally with cosine annealing or linear decay to 0 at total_steps.""" assert not (linear and cosine) diff --git a/pl_bolts/transforms/dataset_normalizations.py b/pl_bolts/transforms/dataset_normalizations.py index 3b1a859f1e..17fa1d50fb 100644 --- a/pl_bolts/transforms/dataset_normalizations.py +++ b/pl_bolts/transforms/dataset_normalizations.py @@ -1,5 +1,5 @@ from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -8,7 +8,7 @@ warn_missing_pkg("torchvision") -@experimental() +@to_review() def imagenet_normalization(): if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( @@ -19,7 +19,7 @@ def imagenet_normalization(): return normalize -@experimental() +@to_review() def cifar10_normalization(): if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( @@ -33,7 +33,7 @@ def cifar10_normalization(): return normalize -@experimental() +@to_review() def stl10_normalization(): if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( @@ -44,7 +44,7 @@ def stl10_normalization(): return normalize -@experimental() +@to_review() def emnist_normalization(split: str): if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( diff --git a/pl_bolts/transforms/self_supervised/ssl_transforms.py b/pl_bolts/transforms/self_supervised/ssl_transforms.py index 492a41a53b..ffbd57ac5a 100644 --- a/pl_bolts/transforms/self_supervised/ssl_transforms.py +++ b/pl_bolts/transforms/self_supervised/ssl_transforms.py @@ -2,7 +2,7 @@ from torch.nn import functional as F from pl_bolts.utils import _PIL_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -11,7 +11,7 @@ warn_missing_pkg("PIL", pypi_name="Pillow") -@experimental() +@to_review() class RandomTranslateWithReflect: """Translate image randomly Translate vertically and horizontally by n pixels where n is integer drawn uniformly independently for each axis from [-max_translation, max_translation]. @@ -55,7 +55,7 @@ def __call__(self, old_image): return new_image -@experimental() +@to_review() class Patchify: def __init__(self, patch_size, overlap_size): self.patch_size = patch_size diff --git a/pl_bolts/utils/arguments.py b/pl_bolts/utils/arguments.py index ed16117e76..004d945815 100644 --- a/pl_bolts/utils/arguments.py +++ b/pl_bolts/utils/arguments.py @@ -5,10 +5,10 @@ from pytorch_lightning import LightningDataModule, LightningModule -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() @dataclass(frozen=True) class LitArg: """Dataclass to represent init args of an object.""" @@ -20,7 +20,7 @@ class LitArg: context: Optional[str] = None -@experimental() +@to_review() class LightningArgumentParser(ArgumentParser): """Extension of argparse.ArgumentParser that lets you parse arbitrary object init args. @@ -76,7 +76,7 @@ def parse_lit_args(self, *args: Any, **kwargs: Any) -> Namespace: return lit_args -@experimental() +@to_review() def gather_lit_args(cls: Any, root_cls: Optional[Any] = None) -> List[LitArg]: if root_cls is None: diff --git a/pl_bolts/utils/pretrained_weights.py b/pl_bolts/utils/pretrained_weights.py index 438570d18a..cbdd560abf 100644 --- a/pl_bolts/utils/pretrained_weights.py +++ b/pl_bolts/utils/pretrained_weights.py @@ -2,7 +2,7 @@ from pytorch_lightning import LightningModule -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review vae_imagenet2012 = ( "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/" "vae/imagenet_06_22_2019/checkpoints/epoch%3D63.ckpt" @@ -12,7 +12,7 @@ urls = {"vae-imagenet2012": vae_imagenet2012, "CPC_v2-resnet18": cpcv2_resnet18} -@experimental() +@to_review() def load_pretrained(model: LightningModule, class_name: Optional[str] = None) -> None: # pragma: no cover if class_name is None: class_name = model.__class__.__name__ diff --git a/pl_bolts/utils/self_supervised.py b/pl_bolts/utils/self_supervised.py index 1b1c9c6550..a862b6448f 100644 --- a/pl_bolts/utils/self_supervised.py +++ b/pl_bolts/utils/self_supervised.py @@ -2,10 +2,10 @@ from pl_bolts.models.self_supervised import resnets from pl_bolts.utils.semi_supervised import Identity -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def torchvision_ssl_encoder( name: str, pretrained: bool = False, diff --git a/pl_bolts/utils/semi_supervised.py b/pl_bolts/utils/semi_supervised.py index 68c4ada0c0..ef79e91287 100644 --- a/pl_bolts/utils/semi_supervised.py +++ b/pl_bolts/utils/semi_supervised.py @@ -6,7 +6,7 @@ from torch import Tensor from pl_bolts.utils import _SKLEARN_AVAILABLE -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review from pl_bolts.utils.warnings import warn_missing_pkg if _SKLEARN_AVAILABLE: @@ -15,7 +15,7 @@ warn_missing_pkg("sklearn", pypi_name="scikit-learn") -@experimental() +@to_review() class Identity(torch.nn.Module): """An identity class to replace arbitrary layers in pretrained models. @@ -34,7 +34,7 @@ def forward(self, x: Tensor) -> Tensor: return x -@experimental() +@to_review() def balance_classes( X: Union[Tensor, np.ndarray], Y: Union[Tensor, np.ndarray, Sequence[int]], batch_size: int ) -> Tuple[np.ndarray, np.ndarray]: @@ -98,7 +98,7 @@ def balance_classes( return final_batches_x, final_batches_y -@experimental() +@to_review() def generate_half_labeled_batches( smaller_set_X: np.ndarray, smaller_set_Y: np.ndarray, diff --git a/pl_bolts/utils/shaping.py b/pl_bolts/utils/shaping.py index dabb49df60..da09fafbcb 100644 --- a/pl_bolts/utils/shaping.py +++ b/pl_bolts/utils/shaping.py @@ -2,10 +2,10 @@ import torch from torch import Tensor -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review -@experimental() +@to_review() def tile(a: Tensor, dim: int, n_tile: int) -> Tensor: init_dim = a.size(dim) repeat_idx = [1] * a.dim() diff --git a/pl_bolts/utils/stability.py b/pl_bolts/utils/stability.py index ed66f9b3bb..137ac3afbb 100644 --- a/pl_bolts/utils/stability.py +++ b/pl_bolts/utils/stability.py @@ -13,61 +13,77 @@ # limitations under the License. import functools import inspect -from typing import Callable, Type, Union +from typing import Callable, Optional, Type, Union +from warnings import filterwarnings from pytorch_lightning.utilities import rank_zero_warn -def _raise_experimental_warning(message: str, stacklevel: int = 6): - rank_zero_warn( - f"{message} The compatibility with other Lightning projects is not guaranteed and API may change at any time." +class ReviewNeededWarning(Warning): + pass + + +def _create_full_message(message: str) -> str: + return ( + f"{message} The compatibility with other Lightning projects is not guaranteed and API may change at any time. " "The API and functionality may change without warning in future releases. " - "More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html", - stacklevel=stacklevel, - category=UserWarning, + "More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html" ) -def experimental( - message: str = "This feature is currently marked as experimental.", -): - """The experimental decorator is used to indicate that a particular feature is not properly reviewed and tested yet. - A callable or type that has been marked as experimental will give a ``UserWarning`` when it is called or +def _create_docstring_message(docstring: str, message: str) -> str: + rst_warning = ".. warning:: " + _create_full_message(message) + if docstring is None: + return rst_warning + return rst_warning + "\n\n " + docstring + + +def _add_message_to_docstring(callable: Union[Callable, Type], message: str) -> Union[Callable, Type]: + callable.__doc__ = _create_docstring_message(callable.__doc__, message) + return callable + + +def _raise_review_warning(message: str, stacklevel: int = 6) -> None: + rank_zero_warn(_create_full_message(message), stacklevel=stacklevel, category=ReviewNeededWarning) + + +def to_review(): + """The to_review decorator is used to indicate that a particular feature is not properly reviewed and tested yet. + A callable or type that has been marked as to_review will give a ``ReviewNeededWarning`` when it is called or instantiated. This designation should be used following the description given in :ref:`stability`. Args: message: The message to include in the warning. Examples ________ - >>> import pytest - >>> from pl_bolts.utils.stability import experimental - >>> @experimental() + >>> from pytest import warns + >>> from pl_bolts.utils.stability import to_review, ReviewNeededWarning + >>> @to_review() ... class MyExperimentalFeature: ... pass ... - >>> with pytest.warns(UserWarning, match="This feature is currently marked as experimental."): + >>> with warns(ReviewNeededWarning, match="The feature MyExperimentalFeature is currently marked for review."): ... MyExperimentalFeature() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE ... <...> - >>> @experimental("This feature is currently marked as experimental with a message.") - ... class MyExperimentalFeatureWithCustomMessage: - ... pass - ... - >>> with pytest.warns(UserWarning, match="This feature is currently marked as experimental with a message."): - ... MyExperimentalFeatureWithCustomMessage() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - ... - <...> """ - def decorator(callable: Union[Callable, Type]): + def decorator(callable: Union[Callable, Type], message_name: Optional[str] = None): + if message_name is None: + message_name = callable.__qualname__ + + message = f"The feature {message_name} is currently marked for review." + filterwarnings("once", message, ReviewNeededWarning) if inspect.isclass(callable): - callable.__init__ = decorator(callable.__init__) + if not hasattr(callable.__init__, "__wrapped__"): + callable.__init__ = decorator(callable.__init__, message_name=message_name) + return _add_message_to_docstring(callable, message) return callable @functools.wraps(callable) def wrapper(*args, **kwargs): - _raise_experimental_warning(message) + _raise_review_warning(message) return callable(*args, **kwargs) - return wrapper + return _add_message_to_docstring(wrapper, message) return decorator diff --git a/pl_bolts/utils/warnings.py b/pl_bolts/utils/warnings.py index d60c5e1b8c..dc6a7fc0b1 100644 --- a/pl_bolts/utils/warnings.py +++ b/pl_bolts/utils/warnings.py @@ -2,14 +2,14 @@ import warnings from typing import Callable, Dict, Optional -from pl_bolts.utils.stability import experimental +from pl_bolts.utils.stability import to_review MISSING_PACKAGE_WARNINGS: Dict[str, int] = {} WARN_MISSING_PACKAGE = int(os.environ.get("WARN_MISSING_PACKAGE", False)) -@experimental() +@to_review() def warn_missing_pkg( pkg_name: str, pypi_name: Optional[str] = None,