
Merge pull request #64 from justusschock/backend_choosing
refactor backend choosing
ORippler authored Feb 21, 2019
2 parents 2064ed6 + 9edbad1 commit 01f25de
Showing 35 changed files with 484 additions and 401 deletions.
89 changes: 50 additions & 39 deletions delira/__init__.py
@@ -1,71 +1,82 @@
__version__ = '0.3.0'

# from .models import AbstractNetwork
# from .logging import TrixiHandler, MultiStreamHandler
# from .data_loading import BaseCacheDataset, BaseLazyDataset, BaseDataManager, \
# RandomSampler, SequentialSampler
import json
import os
import warnings
warnings.simplefilter('default', DeprecationWarning)
warnings.simplefilter('ignore', ImportWarning)

import os
import json
# to register new possible backends, they have to be added to this list.
# each backend should consist of a tuple of length 2 with the first entry
# being the package import name and the second being the backend abbreviation.
# E.g. TensorFlow's package is named 'tensorflow' but if the package is found,
# it will be considered as 'tf' later on
__POSSIBLE_BACKENDS = [("torch", "torch"), ("tensorflow", "tf")]
__BACKENDS = []


def _determine_backends():

if "DELIRA_BACKEND" not in os.environ:
_config_file = __file__.replace("__init__.py", ".delira")
# look for config file to determine backend
# if file exists: load config into environment variables

if not os.path.isfile(_config_file):
_backends = {}
# try to import backends to determine valid backends
try:
import torch
_backends["torch"] = True
del torch
except ImportError:
_backends["torch"] = False
try:
import tensorflow
_backends["tf"] = True
del tensorflow
except ImportError:
_backends["tf"] = False
# try to import all possible backends to determine valid backends

import importlib
for curr_backend in __POSSIBLE_BACKENDS:
try:
assert len(curr_backend) == 2
assert all([isinstance(_tmp, str) for _tmp in curr_backend]), \
"All entries in current backend must be strings"

# check if backend can be imported
bcknd = importlib.util.find_spec(curr_backend[0])

if bcknd is not None:
_backends[curr_backend[1]] = True
else:
_backends[curr_backend[1]] = False
del bcknd

except ValueError:
_backends[curr_backend[1]] = False

with open(_config_file, "w") as f:
json.dump({ "version": __version__, "backend": _backends}, f, sort_keys=True, indent=4)
json.dump({"version": __version__, "backend": _backends},
f, sort_keys=True, indent=4)

del _backends

# set values from config file to environment variables
# set values from config file to variable
with open(_config_file) as f:
_config_dict = json.load(f)
_backend_str = ""
for key, val in _config_dict.pop("backend").items():
if val:
_backend_str += "%s," % key
_config_dict["backend"] = _backend_str
for key, val in _config_dict.items():
if isinstance(val, str):
val = val.lower()
os.environ["DELIRA_%s" % key.upper()] = val

del _backend_str
__BACKENDS.append(key.upper())
del _config_dict

del _config_file

from .data_loading import BaseCacheDataset, BaseLazyDataset, BaseDataManager, \
RandomSampler, SequentialSampler

from .logging import TrixiHandler, MultiStreamHandler

from .models import AbstractNetwork

def get_backends():
"""
Return List of current backends
Return List of currently available backends
"""
return os.environ["DELIRA_BACKEND"].split(",")[:-1]

if "torch" in os.environ["DELIRA_BACKEND"]:
from .io import torch_load_checkpoint, torch_save_checkpoint
from .models import AbstractPyTorchNetwork
from .data_loading import TorchvisionClassificationDataset
if not __BACKENDS:
_determine_backends()
return __BACKENDS


# if "TORCH" in get_backends():
# from .io import torch_load_checkpoint, torch_save_checkpoint
# from .models import AbstractPyTorchNetwork
# from .data_loading import TorchvisionClassificationDataset
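The refactored `delira/__init__.py` above replaces the hard-coded `try: import torch` probes with a generic loop: each entry of `__POSSIBLE_BACKENDS` is checked via `importlib.util.find_spec`, the result is cached in a `.delira` JSON file next to the package, exported as `DELIRA_*` environment variables, and exposed through `get_backends()`. Below is a minimal, self-contained sketch of that detection idea; the constant names and the config path are illustrative, not delira's exact values.

```python
# Standalone sketch of the backend probing shown above (assumed names and paths).
import importlib.util
import json
import os

POSSIBLE_BACKENDS = [("torch", "torch"), ("tensorflow", "tf")]
CONFIG_FILE = os.path.expanduser("~/.delira_demo")  # demo path, not delira's


def detect_backends():
    """Check which backend packages are installed without importing them."""
    found = {}
    for package_name, abbreviation in POSSIBLE_BACKENDS:
        # find_spec returns None when the package cannot be located
        found[abbreviation] = importlib.util.find_spec(package_name) is not None
    return found


def write_config(backends, version="0.3.0"):
    """Cache the probe result in the same JSON layout the diff writes."""
    with open(CONFIG_FILE, "w") as f:
        json.dump({"version": version, "backend": backends},
                  f, sort_keys=True, indent=4)


if __name__ == "__main__":
    backends = detect_backends()
    write_config(backends)
    print(backends)  # e.g. {'torch': True, 'tf': False} with only PyTorch installed
```

Note that `get_backends()` returns the abbreviations upper-cased (e.g. `["TORCH", "TF"]`, built via `__BACKENDS.append(key.upper())`), which is why the guards in the remaining files change from `"torch" in os.environ["DELIRA_BACKEND"]` to `"TORCH" in get_backends()`.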
5 changes: 3 additions & 2 deletions delira/data_loading/__init__.py
@@ -12,6 +12,7 @@
SequentialSampler
from .sampler import __all__ as __all_sampling

import os
if "torch" in os.environ["DELIRA_BACKEND"]:
from delira import get_backends

if "TORCH" in get_backends():
from .dataset import TorchvisionClassificationDataset
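With the new API, code outside the package can use the same guard instead of reading `DELIRA_BACKEND` directly. A hedged usage sketch of that pattern, mirroring the conditional import above:

```python
# Usage sketch of the get_backends() guard from the diff above.
from delira import get_backends

if "TORCH" in get_backends():
    # only importable when the torch backend was detected at package import
    from delira.data_loading import TorchvisionClassificationDataset
else:
    TorchvisionClassificationDataset = None  # make the missing backend explicit
```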
3 changes: 2 additions & 1 deletion delira/data_loading/dataset.py
@@ -7,6 +7,7 @@
import typing
from ..utils import subdirs
from ..utils.decorators import make_deprecated
from delira import get_backends


class AbstractDataset:
@@ -658,7 +659,7 @@ def _make_dataset(self, path):
return data


if "torch" in os.environ["DELIRA_BACKEND"]:
if "TORCH" in get_backends():
from torchvision.datasets import CIFAR10, CIFAR100, EMNIST, MNIST, FashionMNIST

class TorchvisionClassificationDataset(AbstractDataset):
7 changes: 4 additions & 3 deletions delira/io/__init__.py
@@ -1,8 +1,9 @@
import os
if "torch" in os.environ["DELIRA_BACKEND"]:
from delira import get_backends

if "TORCH" in get_backends():
from .torch import save_checkpoint as torch_save_checkpoint
from .torch import load_checkpoint as torch_load_checkpoint

if "tf" in os.environ["DELIRA_BACKEND"]:
if "TF" in get_backends():
from .tf import save_checkpoint as tf_save_checkpoint
from .tf import load_checkpoint as tf_load_checkpoint
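Because `delira/io/__init__.py` aliases each backend's checkpoint helpers under a prefixed name (`torch_save_checkpoint`, `tf_save_checkpoint`, ...), calling code can dispatch on the detected backends at runtime. A small sketch of such a dispatcher, under the assumption that at least one backend is installed:

```python
# Sketch: pick a checkpoint saver based on the backends detected by delira.
from delira import get_backends


def get_checkpoint_saver():
    backends = get_backends()
    if "TORCH" in backends:
        from delira.io import torch_save_checkpoint
        return torch_save_checkpoint
    if "TF" in backends:
        from delira.io import tf_save_checkpoint
        return tf_save_checkpoint
    raise RuntimeError("no supported backend (torch / tensorflow) available")
```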
26 changes: 12 additions & 14 deletions delira/io/torch.py
@@ -4,10 +4,11 @@
import importlib
from collections import OrderedDict
from itertools import islice
from delira import get_backends

logger = logging.getLogger(__name__)

if "torch" in os.environ["DELIRA_BACKEND"]:
if "TORCH" in get_backends():

import torch

@@ -18,9 +19,6 @@

from ..models import AbstractPyTorchNetwork




def save_checkpoint(file: str, model=None, optimizers={},
epoch=None, weights_only=True, **kwargs):
"""
@@ -65,8 +63,8 @@ def save_checkpoint(file: str, model=None, optimizers={},
epoch = 0

state = {"optimizer": optim_state,
"model": model_state,
"epoch": epoch}
"model": model_state,
"epoch": epoch}

if not weights_only:

@@ -85,12 +83,11 @@ def save_checkpoint(file: str, model=None, optimizers={},
torch.save({'source': source, 'cls_name_model': class_name_model,
'parent_class': parent_class, 'init_kwargs': init_kwargs,
'state_dict': state, 'cls_name_optim': class_names_optim},
file)
file)

else:
torch.save(state, file)


def load_checkpoint(file, weights_only=True, **kwargs):
"""
Loads a saved model
@@ -128,7 +125,7 @@ def load_checkpoint(file, weights_only=True, **kwargs):

# create class instance (default device: CPU)
exec("model = " + loaded_dict["cls_name_model"] +
"(**loaded_dict['init_kwargs'])")
"(**loaded_dict['init_kwargs'])")

# check for "map_location" kwarg and use device of first weight tensor
# as default argument (weight tensors should be all on same device)
@@ -141,20 +138,21 @@ def load_checkpoint(file, weights_only=True, **kwargs):
default_device = torch.device("cpu")

map_location = kwargs.get("map_location",
# use slicing instead of converting to list
# to avoid memory overhead
default_device)
# use slicing instead of converting to list
# to avoid memory overhead
default_device)

# push created class from CPU to suitable device
locals()['model'].to(map_location)

locals()['model'].load_state_dict(loaded_dict["state_dict"]["model"])
locals()['model'].load_state_dict(
loaded_dict["state_dict"]["model"])

optims = OrderedDict()

for key in loaded_dict["cls_name_optim"].keys():
exec("_optim = optim.%s(models.parameters())" %
loaded_dict["cls_name_optim"][key])
loaded_dict["cls_name_optim"][key])

optims[key] = locals()['_optim']

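The reindented `save_checkpoint` above keeps the same state layout: a dict with `"model"`, `"optimizer"` and `"epoch"` keys passed to `torch.save`. A simplified round-trip sketch of that weights-only layout follows, using a single optimizer instead of delira's dict of optimizers; function names are illustrative, not delira's API.

```python
# Hedged sketch of the weights-only checkpoint layout used in delira/io/torch.py.
import torch


def save_weights_only(file, model, optimizer=None, epoch=0):
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else {},
        "epoch": epoch,
    }
    torch.save(state, file)


def load_weights_only(file, model, optimizer=None):
    state = torch.load(file, map_location=torch.device("cpu"))
    model.load_state_dict(state["model"])
    if optimizer is not None and state["optimizer"]:
        optimizer.load_state_dict(state["optimizer"])
    return state["epoch"]
```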
7 changes: 4 additions & 3 deletions delira/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .abstract_network import AbstractNetwork

import os
if "torch" in os.environ["DELIRA_BACKEND"]:
from delira import get_backends

if "TORCH" in get_backends():
from .abstract_network import AbstractPyTorchNetwork
from .classification import VGG3DClassificationNetworkPyTorch, \
ClassificationNetworkBasePyTorch
@@ -10,6 +11,6 @@

from .gan import GenerativeAdversarialNetworkBasePyTorch

if "tf" in os.environ["DELIRA_BACKEND"]:
if "TF" in get_backends():
from .abstract_network import AbstractTfNetwork
from .classification import ClassificationNetworkBaseTf
6 changes: 3 additions & 3 deletions delira/models/abstract_network.py
@@ -1,6 +1,6 @@
import abc
import logging
import os
from delira import get_backends

file_logger = logging.getLogger(__name__)

@@ -134,7 +134,7 @@ def init_kwargs(self):
"""
return self._init_kwargs

if "torch" in os.environ["DELIRA_BACKEND"]:
if "TORCH" in get_backends():
import torch

class AbstractPyTorchNetwork(AbstractNetwork, torch.nn.Module):
@@ -228,7 +228,7 @@ def prepare_batch(batch: dict, input_device, output_device):

return return_dict

if "tf" in os.environ["DELIRA_BACKEND"]:
if "TF" in get_backends():
import tensorflow as tf

class AbstractTfNetwork(AbstractNetwork):
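The hunk above touches the `prepare_batch(batch, input_device, output_device)` hook of `AbstractPyTorchNetwork`. For readers unfamiliar with it, a hedged sketch of what such a hook typically does; the `"data"` key and the float conversion are assumptions, not necessarily delira's exact behaviour.

```python
# Illustrative prepare_batch for a PyTorch network; delira's version may differ.
import torch


def prepare_batch(batch: dict, input_device, output_device):
    """Convert numpy arrays to tensors and move them to the right devices."""
    return_dict = {}
    for key, value in batch.items():
        tensor = torch.from_numpy(value).float()
        # network inputs go to the compute device, targets to the output device
        device = input_device if key == "data" else output_device
        return_dict[key] = tensor.to(device)
    return return_dict
```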
8 changes: 5 additions & 3 deletions delira/models/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
if "torch" in os.environ["DELIRA_BACKEND"]:
from delira import get_backends

if "TORCH" in get_backends():
from .classification_network import ClassificationNetworkBasePyTorch
from .classification_network_3D import VGG3DClassificationNetworkPyTorch
if "tf" in os.environ["DELIRA_BACKEND"]:

if "TF" in get_backends():
from .classification_network_tf import ClassificationNetworkBaseTf
from .ResNet18 import ResNet18
5 changes: 3 additions & 2 deletions delira/models/classification/classification_network.py
@@ -1,8 +1,9 @@
import logging
file_logger = logging.getLogger(__name__)

import os
if "torch" in os.environ["DELIRA_BACKEND"]:
from delira import get_backends

if "TORCH" in get_backends():
import torch
from torchvision import models as t_models
from delira.models.abstract_network import AbstractPyTorchNetwork
5 changes: 3 additions & 2 deletions delira/models/classification/classification_network_3D.py
@@ -2,8 +2,9 @@

file_logger = logging.getLogger(__name__)

import os
if "torch" in os.environ["DELIRA_BACKEND"]:
from delira import get_backends

if "TORCH" in get_backends():
import torch.nn as nn
import torch.nn.functional as F
from .classification_network import ClassificationNetworkBasePyTorch
5 changes: 3 additions & 2 deletions delira/models/gan/__init__.py
@@ -1,4 +1,5 @@
import os
if "torch" in os.environ["DELIRA_BACKEND"]:
from delira import get_backends

if "TORCH" in get_backends():
from .generative_adversarial_network import \
GenerativeAdversarialNetworkBasePyTorch
5 changes: 3 additions & 2 deletions delira/models/gan/generative_adversarial_network.py
@@ -1,8 +1,9 @@
import logging
logger = logging.getLogger(__name__)

import os
if "torch" in os.environ["DELIRA_BACKEND"]:
from delira import get_backends

if "TORCH" in get_backends():
import torch
from torchvision import models as t_models

5 changes: 3 additions & 2 deletions delira/models/segmentation/__init__.py
@@ -1,4 +1,5 @@
import os
if "torch" in os.environ["DELIRA_BACKEND"]:
from delira import get_backends

if "TORCH" in get_backends():
from .unet import UNet2dPyTorch, UNet3dPyTorch

5 changes: 3 additions & 2 deletions delira/models/segmentation/unet.py
@@ -1,7 +1,8 @@
# Adapted from https://github.com/jaxony/unet-pytorch/blob/master/model.py

import os
if "torch" in os.environ["DELIRA_BACKEND"]:
from delira import get_backends

if "TORCH" in get_backends():
import torch
import torch.nn.functional as F
from torch.nn import init
7 changes: 4 additions & 3 deletions delira/training/__init__.py
@@ -2,12 +2,13 @@
from .experiment import AbstractExperiment
from .abstract_trainer import AbstractNetworkTrainer

import os
if "torch" in os.environ["DELIRA_BACKEND"]:
from delira import get_backends

if "TORCH" in get_backends():
from .experiment import PyTorchExperiment
from .pytorch_trainer import PyTorchNetworkTrainer
from .metrics import AccuracyMetricPyTorch, AurocMetricPyTorch

if "tf" in os.environ["DELIRA_BACKEND"]:
if "TF" in get_backends():
from .experiment import TfExperiment
from .tf_trainer import TfNetworkTrainer
5 changes: 3 additions & 2 deletions delira/training/callbacks/__init__.py
@@ -1,8 +1,9 @@
from .abstract_callback import AbstractCallback
from .early_stopping import EarlyStopping

import os
if "torch" in os.environ["DELIRA_BACKEND"]:
from delira import get_backends

if "TORCH" in get_backends():
from .pytorch_schedulers import DefaultPyTorchSchedulerCallback
from .pytorch_schedulers import CosineAnnealingLRCallback as \
CosineAnnealingLRCallbackPyTorch