Trainer cleanup #934

Merged · 8 commits · Feb 27, 2020
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/auto_mix_precision.py
@@ -12,8 +12,9 @@

 class TrainerAMPMixin(ABC):

-    def __init__(self):
-        self.use_amp = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    use_amp: bool

     def init_amp(self, use_amp):
         self.use_amp = use_amp and APEX_AVAILABLE
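The hunk above is the pattern this PR applies across the trainer mixins: the placeholder __init__ is dropped and the attributes the mixin relies on are declared as bare class-level annotations, leaving the real initialisation to the concrete Trainer. A minimal, runnable sketch of the idea, with illustrative names that are not part of the PR:

from abc import ABC

APEX_AVAILABLE = False  # stand-in for the real apex import check


class AMPMixinSketch(ABC):
    # summary of the attributes this mixin relies on; a bare annotation
    # creates no attribute at runtime, it only documents the expectation
    use_amp: bool

    def init_amp(self, use_amp):
        # the concrete class calls this during its own __init__
        self.use_amp = use_amp and APEX_AVAILABLE


class TrainerSketch(AMPMixinSketch):
    def __init__(self, use_amp=False):
        self.init_amp(use_amp)


trainer = TrainerSketch(use_amp=True)
print(trainer.use_amp)  # False here because APEX_AVAILABLE is False

Bare annotations only populate __annotations__, which is exactly why the concrete class must still assign the value itself.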
23 changes: 16 additions & 7 deletions pytorch_lightning/trainer/callback_config.py
@@ -1,17 +1,26 @@
 import os
-from abc import ABC
+from abc import ABC, abstractmethod
+from typing import Union

 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.loggers import LightningLoggerBase


 class TrainerCallbackConfigMixin(ABC):

-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.default_save_path = None
-        self.save_checkpoint = None
-        self.slurm_job_id = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    default_save_path: str
+    logger: Union[LightningLoggerBase, bool]
+
+    @property
+    @abstractmethod
+    def slurm_job_id(self) -> int:
+        """Warning: this is just empty shell for code implemented in other class."""
+
+    @abstractmethod
+    def save_checkpoint(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""

     def configure_checkpoint_callback(self):
         """
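Here the names that are really methods or properties on the Trainer (save_checkpoint, slurm_job_id) become @abstractmethod stubs instead of annotations, so the mixin states the contract and Python refuses to instantiate any subclass that forgets to implement it. A small sketch of that contract; the class names and the SLURM fallback below are illustrative, not taken from the PR:

import os
from abc import ABC, abstractmethod


class CallbackConfigSketch(ABC):
    @property
    @abstractmethod
    def slurm_job_id(self) -> int:
        """Empty shell; the concrete class supplies the value."""


class TrainerSketch(CallbackConfigSketch):
    @property
    def slurm_job_id(self) -> int:
        # fall back to -1 when not running under SLURM
        return int(os.environ.get("SLURM_JOB_ID", -1))


print(TrainerSketch().slurm_job_id)  # -1 outside a SLURM job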
47 changes: 25 additions & 22 deletions pytorch_lightning/trainer/data_loading.py
@@ -1,4 +1,4 @@
-from abc import ABC
+from abc import ABC, abstractmethod

 import torch.distributed as dist
 from torch.utils.data import SequentialSampler, DataLoader
@@ -8,40 +8,43 @@

 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True

 try:
     import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
-
-    XLA_AVAILABLE = True
 except ImportError:
     XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True


 class TrainerDataLoadingMixin(ABC):

-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.proc_rank = None
-        self.use_ddp = None
-        self.use_ddp2 = None
-        self.shown_warnings = None
-        self.val_check_interval = None
-        self.use_tpu = None
-        self.tpu_local_core_rank = None
-        self.train_dataloader = None
-        self.num_training_batches = None
-        self.val_check_batch = None
-        self.val_dataloaders = None
-        self.num_val_batches = None
-        self.test_dataloaders = None
-        self.num_test_batches = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    proc_rank: int
+    use_ddp: bool
+    use_ddp2: bool
+    shown_warnings: ...
+    val_check_interval: float
+    use_tpu: bool
+    tpu_local_core_rank: int
+    train_dataloader: DataLoader
+    num_training_batches: int
+    val_check_batch: ...
+    val_dataloaders: DataLoader
+    num_val_batches: int
+    test_dataloaders: DataLoader
+    num_test_batches: int
+
+    @abstractmethod
+    def is_overriden(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""

     def _percent_range_check(self, name):
         value = getattr(self, name)
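Besides the same annotation cleanup, the optional-import guards are restructured: the availability flag moves from the try: body into an else: branch, so the guarded block is limited to the import itself and the flag assignment can never be the statement that raises. A tiny sketch of the idiom with a hypothetical module name:

try:
    import some_optional_backend  # hypothetical optional dependency
except ImportError:
    BACKEND_AVAILABLE = False
else:
    # only reached when the import above succeeded
    BACKEND_AVAILABLE = True

print(BACKEND_AVAILABLE)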
48 changes: 25 additions & 23 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -118,48 +118,50 @@ def train_fx(trial_hparams, cluster_manager, _):
 import re
 import warnings
 from abc import ABC, abstractmethod
+from typing import Union

 import torch
+from pytorch_lightning.loggers import LightningLoggerBase

 from pytorch_lightning.utilities.debugging import MisconfigurationException

 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True


 class TrainerDDPMixin(ABC):

-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.num_gpus = None
-        self.on_gpu = None
-        self.num_gpu_nodes = None
-        self.logger = None
-        self.data_parallel_device_ids = None
-        self.distributed_backend = None
-        self.use_amp = None
-        self.amp_level = None
-        self.use_tpu = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    on_gpu: bool
+    num_gpu_nodes: int
+    logger: Union[LightningLoggerBase, bool]
+    data_parallel_device_ids: ...
+    distributed_backend: str
+    use_amp: bool
+    amp_level: str
+    use_tpu: bool
+
+    @property
+    @abstractmethod
+    def num_gpus(self) -> int:
+        """Warning: this is just empty shell for code implemented in other class."""

     @abstractmethod
-    def copy_trainer_model_properties(self, model):
-        # this is just empty shell for code from other class
-        pass
+    def copy_trainer_model_properties(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""

     @abstractmethod
-    def run_pretrain_routine(self, model):
-        # this is just empty shell for code from other class
-        pass
+    def run_pretrain_routine(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""

     @abstractmethod
-    def init_optimizers(self, optimizers):
-        # this is just empty shell for code from other class
-        pass
+    def init_optimizers(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""

     def init_tpu(self):
         # turn off all the GPU stuff
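These mixins are only ever combined into the single concrete Trainer class, which is why each one merely summarises the attributes it reads and declares empty-shell stubs for the methods it calls but does not own. A condensed sketch of how the pieces meet in the concrete class; all names here are simplified and illustrative:

from abc import ABC, abstractmethod


class DDPMixinSketch(ABC):
    # summary of attributes this mixin reads; assigned by the concrete Trainer
    use_amp: bool

    @abstractmethod
    def run_pretrain_routine(self, *args):
        """Empty shell; the concrete Trainer implements it."""

    def ddp_train_sketch(self, model):
        # the mixin freely uses attributes and hooks it never defined itself
        if self.use_amp:
            print("would initialise AMP here")
        self.run_pretrain_routine(model)


class TrainerSketch(DDPMixinSketch):
    def __init__(self, use_amp=False):
        self.use_amp = use_amp

    def run_pretrain_routine(self, model):
        print(f"running pretrain routine for {model}")


TrainerSketch(use_amp=True).ddp_train_sketch("my_model")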
54 changes: 26 additions & 28 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -348,49 +348,47 @@

 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True

 try:
     import torch_xla.core.xla_model as xm
-    XLA_AVAILABLE = True
-
 except ImportError:
     XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True


 class TrainerDPMixin(ABC):

-    def __init__(self):
-        # this is just a summary on variables used in this abstract class,
-        # the proper values/initialisation should be done in child class
-        self.on_gpu = None
-        self.use_dp = None
-        self.use_ddp2 = None
-        self.use_ddp = None
-        self.use_amp = None
-        self.testing = None
-        self.single_gpu = None
-        self.root_gpu = None
-        self.amp_level = None
-        self.precision = None
-        self.current_tpu_idx = None
-        self.proc_rank = None
-        self.tpu_local_core_rank = None
-        self.tpu_global_core_rank = None
-        self.use_tpu = None
+    # this is just a summary on variables used in this abstract class,
+    # the proper values/initialisation should be done in child class
+    on_gpu: bool
+    use_dp: bool
+    use_ddp2: bool
+    use_ddp: bool
+    use_amp: bool
+    testing: bool
+    single_gpu: bool
+    root_gpu: ...
+    amp_level: str
+    precision: ...
+    current_tpu_idx: ...
+    proc_rank: int
+    tpu_local_core_rank: int
+    tpu_global_core_rank: int
+    use_tpu: bool
+    data_parallel_device_ids: ...

     @abstractmethod
-    def run_pretrain_routine(self, model):
-        # this is just empty shell for code from other class
-        pass
+    def run_pretrain_routine(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""

     @abstractmethod
-    def init_optimizers(self, optimizers):
-        # this is just empty shell for code from other class
-        pass
+    def init_optimizers(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""

     def copy_trainer_model_properties(self, model):
         if isinstance(model, LightningDataParallel):