diff --git a/pytorch_lightning/trainer/auto_mix_precision.py b/pytorch_lightning/trainer/auto_mix_precision.py
index 135a0bce35d2c..a84f44a508163 100644
--- a/pytorch_lightning/trainer/auto_mix_precision.py
+++ b/pytorch_lightning/trainer/auto_mix_precision.py
@@ -12,8 +12,9 @@
 
 class TrainerAMPMixin(ABC):
 
-    def __init__(self):
-        self.use_amp = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    use_amp: bool
 
     def init_amp(self, use_amp):
         self.use_amp = use_amp and APEX_AVAILABLE
diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py
index 3756b19e433c0..8a17698e82e31 100644
--- a/pytorch_lightning/trainer/callback_config.py
+++ b/pytorch_lightning/trainer/callback_config.py
@@ -1,17 +1,26 @@
 import os
-from abc import ABC
+from abc import ABC, abstractmethod
+from typing import Union
 
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.loggers import LightningLoggerBase
 
 
 class TrainerCallbackConfigMixin(ABC):
 
-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.default_save_path = None
-        self.save_checkpoint = None
-        self.slurm_job_id = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    default_save_path: str
+    logger: Union[LightningLoggerBase, bool]
+
+    @property
+    @abstractmethod
+    def slurm_job_id(self) -> int:
+        """Warning: this is just empty shell for code implemented in other class."""
+
+    @abstractmethod
+    def save_checkpoint(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     def configure_checkpoint_callback(self):
         """
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 6861fbf33b278..a868b04980099 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -1,4 +1,4 @@
-from abc import ABC
+from abc import ABC, abstractmethod
 
 import torch.distributed as dist
 from torch.utils.data import SequentialSampler, DataLoader
@@ -8,40 +8,43 @@
 
 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
 
 try:
     import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
-
-    XLA_AVAILABLE = True
 except ImportError:
     XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
 
 
 class TrainerDataLoadingMixin(ABC):
 
-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.proc_rank = None
-        self.use_ddp = None
-        self.use_ddp2 = None
-        self.shown_warnings = None
-        self.val_check_interval = None
-        self.use_tpu = None
-        self.tpu_local_core_rank = None
-        self.train_dataloader = None
-        self.num_training_batches = None
-        self.val_check_batch = None
-        self.val_dataloaders = None
-        self.num_val_batches = None
-        self.test_dataloaders = None
-        self.num_test_batches = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    proc_rank: int
+    use_ddp: bool
+    use_ddp2: bool
+    shown_warnings: ...
+    val_check_interval: float
+    use_tpu: bool
+    tpu_local_core_rank: int
+    train_dataloader: DataLoader
+    num_training_batches: int
+    val_check_batch: ...
+    val_dataloaders: DataLoader
+    num_val_batches: int
+    test_dataloaders: DataLoader
+    num_test_batches: int
+
+    @abstractmethod
+    def is_overriden(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     def _percent_range_check(self, name):
         value = getattr(self, name)
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index 01bcdb1d1ded2..90c1deb9b44db 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -118,48 +118,50 @@ def train_fx(trial_hparams, cluster_manager, _):
 import re
 import warnings
 from abc import ABC, abstractmethod
+from typing import Union
 
 import torch
 
+from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 
 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
 
 
 class TrainerDDPMixin(ABC):
 
-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.num_gpus = None
-        self.on_gpu = None
-        self.num_gpu_nodes = None
-        self.logger = None
-        self.data_parallel_device_ids = None
-        self.distributed_backend = None
-        self.use_amp = None
-        self.amp_level = None
-        self.use_tpu = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    on_gpu: bool
+    num_gpu_nodes: int
+    logger: Union[LightningLoggerBase, bool]
+    data_parallel_device_ids: ...
+    distributed_backend: str
+    use_amp: bool
+    amp_level: str
+    use_tpu: bool
+
+    @property
+    @abstractmethod
+    def num_gpus(self) -> int:
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def copy_trainer_model_properties(self, model):
-        # this is just empty shell for code from other class
-        pass
+    def copy_trainer_model_properties(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def run_pretrain_routine(self, model):
-        # this is just empty shell for code from other class
-        pass
+    def run_pretrain_routine(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def init_optimizers(self, optimizers):
-        # this is just empty shell for code from other class
-        pass
+    def init_optimizers(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     def init_tpu(self):
         # turn off all the GPU stuff
diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 29bc8178b8525..ee5e48338cb04 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -348,49 +348,47 @@
 
 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
 
 try:
     import torch_xla.core.xla_model as xm
-    XLA_AVAILABLE = True
-
 except ImportError:
     XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
 
 
 class TrainerDPMixin(ABC):
 
-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.on_gpu = None
-        self.use_dp = None
-        self.use_ddp2 = None
-        self.use_ddp = None
-        self.use_amp = None
-        self.testing = None
-        self.single_gpu = None
-        self.root_gpu = None
-        self.amp_level = None
-        self.precision = None
-        self.current_tpu_idx = None
-        self.proc_rank = None
-        self.tpu_local_core_rank = None
-        self.tpu_global_core_rank = None
-        self.use_tpu = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    on_gpu: bool
+    use_dp: bool
+    use_ddp2: bool
+    use_ddp: bool
+    use_amp: bool
+    testing: bool
+    single_gpu: bool
+    root_gpu: ...
+    amp_level: str
+    precision: ...
+    current_tpu_idx: ...
+    proc_rank: int
+    tpu_local_core_rank: int
+    tpu_global_core_rank: int
+    use_tpu: bool
+    data_parallel_device_ids: ...
 
     @abstractmethod
-    def run_pretrain_routine(self, model):
-        # this is just empty shell for code from other class
-        pass
+    def run_pretrain_routine(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def init_optimizers(self, optimizers):
-        # this is just empty shell for code from other class
-        pass
+    def init_optimizers(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     def copy_trainer_model_properties(self, model):
         if isinstance(model, LightningDataParallel):
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index bca62836bfdc9..7a837ee09fe42 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -129,100 +129,92 @@
 from abc import ABC, abstractmethod
 
 import torch
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
+from pytorch_lightning import LightningModule
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 
 try:
     import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.core.xla_model as xm
-
-    XLA_AVAILABLE = True
 except ImportError:
     XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
 
 
 class TrainerEvaluationLoopMixin(ABC):
 
-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.test_progress_bar = None
-        self.val_progress_bar = None
-        self.main_progress_bar = None
-        self.use_ddp = None
-        self.use_dp = None
-        self.use_ddp2 = None
-        self.single_gpu = None
-        self.data_parallel_device_ids = None
-        self.model = None
-        self.num_test_batches = None
-        self.num_val_batches = None
-        self.fast_dev_run = None
-        self.process_position = None
-        self.show_progress_bar = None
-        self.process_output = None
-        self.training_tqdm_dict = None
-        self.proc_rank = None
-        self.checkpoint_callback = None
-        self.current_epoch = None
-        self.callback_metrics = None
-        self.test_dataloaders = None
-        self.val_dataloaders = None
-        self.use_tpu = None
-        self.reload_dataloaders_every_epoch = None
-        self.progress_bar_refresh_rate = None
-
-        # Callback system
-        self.on_validation_start: Callable = ...
-        self.on_validation_end: Callable = ...
-        self.on_test_start: Callable = ...
-        self.on_test_end: Callable = ...
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    test_progress_bar: ...
+    val_progress_bar: ...
+    main_progress_bar: ...
+    use_ddp: bool
+    use_dp: bool
+    use_ddp2: bool
+    single_gpu: bool
+    data_parallel_device_ids: ...
+    model: LightningModule
+    num_test_batches: int
+    num_val_batches: int
+    fast_dev_run: ...
+    process_position: ...
+    show_progress_bar: ...
+    process_output: ...
+    training_tqdm_dict: ...
+    proc_rank: int
+    checkpoint_callback: ...
+    current_epoch: int
+    callback_metrics: ...
+    test_dataloaders: DataLoader
+    val_dataloaders: DataLoader
+    use_tpu: bool
+    reload_dataloaders_every_epoch: ...
+    progress_bar_refresh_rate: ...
+
+    # Callback system
+    on_validation_start: Callable
+    on_validation_end: Callable
+    on_test_start: Callable
+    on_test_end: Callable
 
     @abstractmethod
-    def copy_trainer_model_properties(self, model):
-        # this is just empty shell for code from other class
-        pass
+    def copy_trainer_model_properties(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
     def get_model(self):
-        # this is just empty shell for code from other class
-        pass
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def is_overriden(self, m):
-        # this is just empty shell for code from other class
-        pass
+    def is_overriden(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def transfer_batch_to_tpu(self, batch):
-        # this is just empty shell for code from other class
-        pass
+    def transfer_batch_to_tpu(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def transfer_batch_to_gpu(self, batch, gpu):
-        # this is just empty shell for code from other class
-        pass
+    def transfer_batch_to_gpu(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def add_tqdm_metrics(self, metrics):
-        # this is just empty shell for code from other class
-        pass
+    def add_tqdm_metrics(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def log_metrics(self, metrics, grad_norm_dic):
-        # this is just empty shell for code from other class
-        pass
+    def log_metrics(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def reset_test_dataloader(self, model):
-        # this is just empty shell for code from other class
-        pass
+    def reset_test_dataloader(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def reset_val_dataloader(self, model):
-        # this is just empty shell for code from other class
-        pass
+    def reset_val_dataloader(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     def evaluate(self, model, dataloaders, max_batches, test_mode: bool = False):
         """Run evaluation code.
diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py
index 20a6673d69aa6..091ab02465ff1 100644
--- a/pytorch_lightning/trainer/logging.py
+++ b/pytorch_lightning/trainer/logging.py
@@ -1,27 +1,28 @@
-from abc import ABC
-from typing import Iterable
+from abc import ABC, abstractmethod
+from typing import Union, Iterable
 
 import torch
 
 from pytorch_lightning.core import memory
-from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection
+from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerBase, LoggerCollection
 
 
 class TrainerLoggingMixin(ABC):
 
-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.current_epoch = None
-        self.on_gpu = None
-        self.log_gpu_memory = None
-        self.logger = None
-        self.tqdm_metrics = None
-        self.global_step = None
-        self.proc_rank = None
-        self.use_dp = None
-        self.use_ddp2 = None
-        self.num_gpus = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    current_epoch: int
+    on_gpu: bool
+    log_gpu_memory: ...
+    logger: Union[LightningLoggerBase, bool]
+    tqdm_metrics: ...
+    global_step: int
+    proc_rank: int
+    use_dp: bool
+    use_ddp2: bool
+    default_save_path: str
+    slurm_job_id: int
+    num_gpus: int
 
     def configure_logger(self, logger):
         if logger is True:
diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py
index eb0d529d2681b..2894cc6e11736 100644
--- a/pytorch_lightning/trainer/model_hooks.py
+++ b/pytorch_lightning/trainer/model_hooks.py
@@ -27,5 +27,4 @@ def has_arg(self, f_name, arg_name):
 
     @abstractmethod
     def get_model(self):
-        # this is just empty shell for code from other class
-        pass
+        """Warning: this is just empty shell for code implemented in other class."""
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 562c5bcfa334c..d36da054a50f6 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -38,19 +38,19 @@
 
 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
 
 try:
     import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
-
-    XLA_AVAILABLE = True
 except ImportError:
     XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
 
 
 class Trainer(TrainerIOMixin,
@@ -98,7 +98,7 @@ def __init__(
             train_percent_check: float = 1.0,
             val_percent_check: float = 1.0,
             test_percent_check: float = 1.0,
-            val_check_interval: Union[float] = 1.0,
+            val_check_interval: float = 1.0,
             log_save_interval: int = 100,
             row_log_interval: int = 10,
             add_row_log_interval=None,  # backward compatible, todo: remove in v0.8.0
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 569c571838aa6..68645fdfbe650 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -96,11 +96,13 @@
 import warnings
 from abc import ABC
 from subprocess import call
-from argparse import Namespace
+from typing import Union
 
 import torch
 import torch.distributed as dist
 
+from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import (
     LightningDistributedDataParallel,
     LightningDataParallel,
@@ -110,33 +112,32 @@
     import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
-
-    XLA_AVAILABLE = True
 except ImportError:
     XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
 
 
 class TrainerIOMixin(ABC):
 
-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.model = None
-        self.on_gpu = None
-        self.root_gpu = None
-        self.resume_from_checkpoint = None
-        self.use_ddp = None
-        self.use_ddp2 = None
-        self.checkpoint_callback = None
-        self.proc_rank = None
-        self.weights_save_path = None
-        self.logger = None
-        self.early_stop_callback = None
-        self.lr_schedulers = None
-        self.optimizers = None
-        self.on_tpu = None
-        self.num_training_batches = None
-        self.accumulate_grad_batches = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    model: LightningModule
+    on_gpu: bool
+    root_gpu: ...
+    resume_from_checkpoint: ...
+    use_ddp: bool
+    use_ddp2: bool
+    checkpoint_callback: ...
+    proc_rank: int
+    weights_save_path: str
+    logger: Union[LightningLoggerBase, bool]
+    early_stop_callback: ...
+    lr_schedulers: ...
+    optimizers: ...
+    on_tpu: bool
+    num_training_batches: int
+    accumulate_grad_batches: int
 
     def get_model(self):
         is_dp_module = isinstance(self.model, (LightningDistributedDataParallel,
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d2be9894a4fef..9a93923214c13 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -156,91 +156,93 @@ def training_step(self, batch, batch_idx):
 
 import copy
 import warnings
-from abc import ABC, abstractmethod
 import logging as log
+from abc import ABC, abstractmethod
+from typing import Union, List
 
 import numpy as np
+from torch.utils.data import DataLoader
 
+from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning import LightningModule
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 from pytorch_lightning.callbacks.base import Callback
 
 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
 
 try:
     import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.core.xla_model as xm
-
-    XLA_AVAILABLE = True
-
 except ImportError:
     XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
 
 
 class TrainerTrainLoopMixin(ABC):
 
-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.max_epochs = None
-        self.min_epochs = None
-        self.use_ddp = None
-        self.use_dp = None
-        self.use_ddp2 = None
-        self.single_gpu = None
-        self.use_tpu = None
-        self.data_parallel_device_ids = None
-        self.check_val_every_n_epoch = None
-        self.num_training_batches = None
-        self.val_check_batch = None
-        self.num_val_batches = None
-        self.disable_validation = None
-        self.fast_dev_run = None
-        self.main_progress_bar = None
-        self.accumulation_scheduler = None
-        self.lr_schedulers = None
-        self.enable_early_stop = None
-        self.early_stop_callback = None
-        self.callback_metrics = None
-        self.logger = None
-        self.global_step = None
-        self.testing = None
-        self.log_save_interval = None
-        self.proc_rank = None
-        self.row_log_interval = None
-        self.total_batches = None
-        self.truncated_bptt_steps = None
-        self.optimizers = None
-        self.accumulate_grad_batches = None
-        self.use_amp = None
-        self.print_nan_grads = None
-        self.track_grad_norm = None
-        self.model = None
-        self.running_loss = None
-        self.training_tqdm_dict = None
-        self.reduce_lr_on_plateau_scheduler = None
-        self.profiler = None
-        self.batch_idx = None
-        self.precision = None
-        self.train_dataloader = None
-        self.reload_dataloaders_every_epoch = None
-        self.progress_bar_refresh_rate = None
-        self.max_steps = ...
-        self.max_steps = ...
-
-        # Callback system
-        self.callbacks: list[Callback] = []
-        self.max_steps = None
-        self.on_train_start: Callable = ...
-        self.on_train_end: Callable = ...
-        self.on_batch_start: Callable = ...
-        self.on_batch_end: Callable = ...
-        self.on_epoch_start: Callable = ...
-        self.on_epoch_end: Callable = ...
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    max_epochs: int
+    min_epochs: int
+    use_ddp: bool
+    use_dp: bool
+    use_ddp2: bool
+    single_gpu: bool
+    use_tpu: bool
+    data_parallel_device_ids: ...
+    check_val_every_n_epoch: ...
+    num_training_batches: int
+    val_check_batch: ...
+    num_val_batches: int
+    disable_validation: bool
+    fast_dev_run: ...
+    main_progress_bar: ...
+    accumulation_scheduler: ...
+    lr_schedulers: ...
+    enable_early_stop: ...
+    early_stop_callback: ...
+    callback_metrics: ...
+    logger: Union[LightningLoggerBase, bool]
+    global_step: int
+    testing: bool
+    log_save_interval: float
+    proc_rank: int
+    row_log_interval: float
+    total_batches: int
+    truncated_bptt_steps: ...
+    optimizers: ...
+    accumulate_grad_batches: int
+    use_amp: bool
+    print_nan_grads: ...
+    track_grad_norm: ...
+    model: LightningModule
+    running_loss: ...
+    training_tqdm_dict: ...
+    reduce_lr_on_plateau_scheduler: ...
+    profiler: ...
+    batch_idx: int
+    precision: ...
+    train_dataloader: DataLoader
+    reload_dataloaders_every_epoch: bool
+    progress_bar_refresh_rate: ...
+    max_steps: int
+    max_steps: int
+    total_batch_idx: int
+
+    # Callback system
+    callbacks: List[Callback]
+    on_train_start: Callable
+    on_train_end: Callable
+    on_batch_start: Callable
+    on_batch_end: Callable
+    on_epoch_start: Callable
+    on_epoch_end: Callable
 
     @property
     def max_nb_epochs(self):
@@ -262,78 +264,63 @@ def min_nb_epochs(self):
 
     @abstractmethod
     def get_model(self):
-        # this is just empty shell for code from other class
-        pass
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def is_function_implemented(self, m):
-        # this is just empty shell for code from other class
-        pass
+    def is_function_implemented(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def is_infinite_dataloader(self, dataloader):
-        # this is just empty shell for code from other class
-        pass
+    def is_infinite_dataloader(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def run_evaluation(self, test_mode):
-        # this is just empty shell for code from other class
-        pass
+    def run_evaluation(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def transfer_batch_to_gpu(self, batch, gpu):
-        # this is just empty shell for code from other class
-        pass
+    def transfer_batch_to_gpu(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def transfer_batch_to_tpu(self, batch):
-        # this is just empty shell for code from other class
-        pass
+    def transfer_batch_to_tpu(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
     def clip_gradients(self):
-        # this is just empty shell for code from other class
-        pass
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
     def print_nan_gradients(self):
-        # this is just empty shell for code from other class
-        pass
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def is_overriden(self, m):
-        # this is just empty shell for code from other class
-        pass
+    def is_overriden(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def add_tqdm_metrics(self, metrics):
-        # this is just empty shell for code from other class
-        pass
+    def add_tqdm_metrics(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def log_metrics(self, metrics, grad_norm_dic):
-        # this is just empty shell for code from other class
-        pass
+    def log_metrics(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def process_output(self, output, train):
-        # this is just empty shell for code from other class
-        pass
+    def process_output(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def reset_train_dataloader(self, model):
-        # this is just empty shell for code from other class
-        pass
+    def reset_train_dataloader(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
     def reset_val_dataloader(self, model):
-        # this is just empty shell for code from other class
-        pass
+        """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def has_arg(self, f_name, arg_name):
-        # this is just empty shell for code from other class
-        pass
+    def has_arg(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
 
     def train(self):
         warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 7fa4059afc3e2..6e4ea506c3d62 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -8,15 +8,13 @@
 
 class TrainerTrainingTricksMixin(ABC):
 
-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.gradient_clip_val = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    gradient_clip_val: ...
 
     @abstractmethod
     def get_model(self):
-        # this is just empty shell for code from other class
-        pass
+        """Warning: this is just empty shell for code implemented in other class."""
 
     def clip_gradients(self):
         if self.gradient_clip_val > 0:
diff --git a/tests/models/utils.py b/tests/models/utils.py
index 75be02ef0c836..adf1265ef46ba 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -166,7 +166,7 @@ def load_model(exp, root_weights_dir, module_class=LightningTemplateModel, path_
     return trained_model
 
 
-def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
+def run_prediction(dataloader, trained_model, dp=False, min_acc=0.45):
    # run prediction on 1 batch
     for batch in dataloader:
         break
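
Taken together, the hunks above apply a single refactor pattern: every Trainer mixin drops its __init__ and instead documents the attributes it borrows from the composed Trainer as class-level type annotations, while methods owned by sibling mixins become abstractmethods whose only body is a docstring. The following minimal sketch uses hypothetical classes (ExampleTrainer, LoggingMixin, ModelMixin, not Lightning's real ones) to show why this composes: the annotations create no attributes at runtime, and the abstract stubs are satisfied once all the pieces are combined into one concrete class that performs the real initialisation.

from abc import ABC, abstractmethod


class LoggingMixin(ABC):
    # summary of attributes this mixin reads; the concrete child assigns the real values
    global_step: int
    proc_rank: int

    @abstractmethod
    def get_model(self):
        """Empty shell; the implementation lives in another mixin of the composition."""

    def log_step(self):
        # only the rank-zero process emits output
        if self.proc_rank == 0:
            print(f'step={self.global_step} model={type(self.get_model()).__name__}')


class ModelMixin(ABC):
    model: object

    def get_model(self):
        return self.model


class ExampleTrainer(ModelMixin, LoggingMixin):
    def __init__(self, model):
        # the child class performs the initialisation that the annotations only describe
        self.model = model
        self.global_step = 0
        self.proc_rank = 0


if __name__ == '__main__':
    ExampleTrainer(model=object()).log_step()  # prints: step=0 model=object

Because annotations such as use_amp: bool add nothing at runtime, they only document the contract that the real Trainer.__init__ has to fulfil, which is exactly the role the removed placeholder __init__ methods used to play.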
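The hunks also rewrite every optional-dependency guard so the availability flag is assigned in an else branch instead of as the last statement inside try: the else clause runs only when the whole import block succeeded, keeping the flag assignment outside the guarded region. A minimal sketch of the same idiom, with numpy standing in for apex or torch_xla:

# availability flag for an optional dependency (numpy is only a stand-in here)
try:
    import numpy
except ImportError:
    NUMPY_AVAILABLE = False
else:
    # reached only if the import above raised nothing
    NUMPY_AVAILABLE = True

print(f'optional dependency available: {NUMPY_AVAILABLE}')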