This repository has been archived by the owner on Jun 26, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #64 from justusschock/backend_choosing
refactor backend choosing
- Loading branch information
Showing
35 changed files
with
484 additions
and
401 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,82 @@ | ||
__version__ = '0.3.0' | ||
|
||
# from .models import AbstractNetwork | ||
# from .logging import TrixiHandler, MultiStreamHandler | ||
# from .data_loading import BaseCacheDataset, BaseLazyDataset, BaseDataManager, \ | ||
# RandomSampler, SequentialSampler | ||
import json | ||
import os | ||
import warnings | ||
warnings.simplefilter('default', DeprecationWarning) | ||
warnings.simplefilter('ignore', ImportWarning) | ||
|
||
import os | ||
import json | ||
# to register new pssible backends, they have to be added to this list. | ||
# each backend should consist of a tuple of length 2 with the first entry | ||
# being the package import name and the second being the backend abbreviation. | ||
# E.g. TensorFlow's package is named 'tensorflow' but if the package is found, | ||
# it will be considered as 'tf' later on | ||
__POSSIBLE_BACKENDS = [("torch", "torch"), ("tensorflow", "tf")] | ||
__BACKENDS = [] | ||
|
||
|
||
def _determine_backends(): | ||
|
||
if "DELIRA_BACKEND" not in os.environ: | ||
_config_file = __file__.replace("__init__.py", ".delira") | ||
# look for config file to determine backend | ||
# if file exists: load config into environment variables | ||
|
||
if not os.path.isfile(_config_file): | ||
_backends = {} | ||
# try to import backends to determine valid backends | ||
try: | ||
import torch | ||
_backends["torch"] = True | ||
del torch | ||
except ImportError: | ||
_backends["torch"] = False | ||
try: | ||
import tensorflow | ||
_backends["tf"] = True | ||
del tensorflow | ||
except ImportError: | ||
_backends["tf"] = False | ||
# try to import all possible backends to determine valid backends | ||
|
||
import importlib | ||
for curr_backend in __POSSIBLE_BACKENDS: | ||
try: | ||
assert len(curr_backend) == 2 | ||
assert all([isinstance(_tmp, str) for _tmp in curr_backend]), \ | ||
"All entries in current backend must be strings" | ||
|
||
# check if backend can be imported | ||
bcknd = importlib.util.find_spec(curr_backend[0]) | ||
|
||
if bcknd is not None: | ||
_backends[curr_backend[1]] = True | ||
else: | ||
_backends[curr_backend[1]] = False | ||
del bcknd | ||
|
||
except ValueError: | ||
_backends[curr_backend[1]] = False | ||
|
||
with open(_config_file, "w") as f: | ||
json.dump({ "version": __version__, "backend": _backends}, f, sort_keys=True, indent=4) | ||
json.dump({"version": __version__, "backend": _backends}, | ||
f, sort_keys=True, indent=4) | ||
|
||
del _backends | ||
|
||
# set values from config file to environment variables | ||
# set values from config file to variable | ||
with open(_config_file) as f: | ||
_config_dict = json.load(f) | ||
_backend_str = "" | ||
for key, val in _config_dict.pop("backend").items(): | ||
if val: | ||
_backend_str += "%s," % key | ||
_config_dict["backend"] = _backend_str | ||
for key, val in _config_dict.items(): | ||
if isinstance(val, str): | ||
val = val.lower() | ||
os.environ["DELIRA_%s" % key.upper()] = val | ||
|
||
del _backend_str | ||
__BACKENDS.append(key.upper()) | ||
del _config_dict | ||
|
||
del _config_file | ||
|
||
from .data_loading import BaseCacheDataset, BaseLazyDataset, BaseDataManager, \ | ||
RandomSampler, SequentialSampler | ||
|
||
from .logging import TrixiHandler, MultiStreamHandler | ||
|
||
from .models import AbstractNetwork | ||
|
||
def get_backends(): | ||
""" | ||
Return List of current backends | ||
Return List of currently available backends | ||
""" | ||
return os.environ["DELIRA_BACKEND"].split(",")[:-1] | ||
|
||
if "torch" in os.environ["DELIRA_BACKEND"]: | ||
from .io import torch_load_checkpoint, torch_save_checkpoint | ||
from .models import AbstractPyTorchNetwork | ||
from .data_loading import TorchvisionClassificationDataset | ||
if not __BACKENDS: | ||
_determine_backends() | ||
return __BACKENDS | ||
|
||
|
||
# if "TORCH" in get_backends(): | ||
# from .io import torch_load_checkpoint, torch_save_checkpoint | ||
# from .models import AbstractPyTorchNetwork | ||
# from .data_loading import TorchvisionClassificationDataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
import os | ||
if "torch" in os.environ["DELIRA_BACKEND"]: | ||
from delira import get_backends | ||
|
||
if "TORCH" in get_backends(): | ||
from .torch import save_checkpoint as torch_save_checkpoint | ||
from .torch import load_checkpoint as torch_load_checkpoint | ||
|
||
if "tf" in os.environ["DELIRA_BACKEND"]: | ||
if "TF" in get_backends(): | ||
from .tf import save_checkpoint as tf_save_checkpoint | ||
from .tf import load_checkpoint as tf_load_checkpoint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
import os | ||
if "torch" in os.environ["DELIRA_BACKEND"]: | ||
from delira import get_backends | ||
|
||
if "TORCH" in get_backends(): | ||
from .classification_network import ClassificationNetworkBasePyTorch | ||
from .classification_network_3D import VGG3DClassificationNetworkPyTorch | ||
if "tf" in os.environ["DELIRA_BACKEND"]: | ||
|
||
if "TF" in get_backends(): | ||
from .classification_network_tf import ClassificationNetworkBaseTf | ||
from .ResNet18 import ResNet18 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import os | ||
if "torch" in os.environ["DELIRA_BACKEND"]: | ||
from delira import get_backends | ||
|
||
if "TORCH" in get_backends(): | ||
from .generative_adversarial_network import \ | ||
GenerativeAdversarialNetworkBasePyTorch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import os | ||
if "torch" in os.environ["DELIRA_BACKEND"]: | ||
from delira import get_backends | ||
|
||
if "TORCH" in get_backends(): | ||
from .unet import UNet2dPyTorch, UNet3dPyTorch | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.