diff --git a/typhon/retrieval/qrnn/backends/keras.py b/typhon/retrieval/qrnn/backends/keras.py index 4a7830fc..f8d83188 100644 --- a/typhon/retrieval/qrnn/backends/keras.py +++ b/typhon/retrieval/qrnn/backends/keras.py @@ -18,6 +18,8 @@ from keras.models import Sequential, clone_model, Model from keras.layers import Dense, Activation, Dropout from keras.optimizers import SGD + if int(keras.__version__.split('.')[0]) != 2: + raise ImportError() except ImportError: raise ImportError( "Could not import the required Keras modules. The QRNN " diff --git a/typhon/retrieval/qrnn/models/keras.py b/typhon/retrieval/qrnn/models/keras.py index e049809a..0131bb8d 100644 --- a/typhon/retrieval/qrnn/models/keras.py +++ b/typhon/retrieval/qrnn/models/keras.py @@ -7,11 +7,18 @@ """ import logging import numpy as np -import keras -from keras.models import Sequential -from keras.layers import Dense, Activation, deserialize -from keras.optimizers import SGD -import keras.backend as K +try: + import keras + from keras.models import Sequential + from keras.layers import Dense, Activation, deserialize + from keras.optimizers import SGD + import keras.backend as K + if int(keras.__version__.split('.')[0]) != 2: + raise ImportError() +except ImportError: + raise ImportError( + "Could not import the required Keras modules. The QRNN " + "implementation was developed for use with Keras version 2.0.9.") def save_model(f, model): diff --git a/typhon/retrieval/qrnn/models/pytorch/common.py b/typhon/retrieval/qrnn/models/pytorch/common.py index 17300fa0..ae8b1ec6 100644 --- a/typhon/retrieval/qrnn/models/pytorch/common.py +++ b/typhon/retrieval/qrnn/models/pytorch/common.py @@ -51,7 +51,7 @@ def load_model(f, quantiles): Returns: The loaded pytorch model. """ - model = torch.load(f) + model = torch.load(f, weights_only=False) return model @@ -92,8 +92,8 @@ class BatchedDataset(Dataset): def __init__(self, training_data, batch_size): x, y = training_data - self.x = torch.tensor(x, dtype=torch.float) - self.y = torch.tensor(y, dtype=torch.float) + self.x = x if isinstance(x, torch.Tensor) else torch.tensor(x, dtype=torch.float) + self.y = y if isinstance(y, torch.Tensor) else torch.tensor(y, dtype=torch.float) self.batch_size = batch_size def __len__(self): diff --git a/typhon/retrieval/qrnn/qrnn.py b/typhon/retrieval/qrnn/qrnn.py index 5be1cb08..0e5ecaa0 100644 --- a/typhon/retrieval/qrnn/qrnn.py +++ b/typhon/retrieval/qrnn/qrnn.py @@ -23,12 +23,12 @@ ################################################################################ try: - import typhon.retrieval.qrnn.models.keras as keras - backend = keras + import typhon.retrieval.qrnn.models.pytorch as pytorch + backend = pytorch except Exception as e: try: - import typhon.retrieval.qrnn.models.pytorch as pytorch - backend = pytorch + import typhon.retrieval.qrnn.models.keras as keras + backend = keras except: raise Exception("Couldn't import neither Keras nor Pytorch " "one of them must be available to use the QRNN" @@ -600,11 +600,14 @@ def load(path): The loaded QRNN object. """ - with open(path, 'rb') as f: + with open(path + ".pkl", 'rb') as f: qrnn = pickle.load(f) + + with open(path + ".model", 'rb') as f: backend = importlib.import_module(qrnn.backend) model = backend.load_model(f, qrnn.quantiles) qrnn.model = model + return qrnn def save(self, path): @@ -621,11 +624,12 @@ def save(self, path): store the model. """ - f = open(path, "wb") - pickle.dump(self, f) - backend = importlib.import_module(self.backend) - backend.save_model(f, self.model) - f.close() + with open(path + ".pkl", 'wb') as f: + pickle.dump(self, f) + + with open(path + ".model", 'wb') as f: + backend = importlib.import_module(self.backend) + backend.save_model(f, self.model) def __getstate__(self): diff --git a/typhon/tests/retrieval/qrnn/test_qrnn.py b/typhon/tests/retrieval/qrnn/test_qrnn.py index 327b55a1..a3822020 100644 --- a/typhon/tests/retrieval/qrnn/test_qrnn.py +++ b/typhon/tests/retrieval/qrnn/test_qrnn.py @@ -14,12 +14,12 @@ # backends = [] -try: - import typhon.retrieval.qrnn.models.keras - - backends += ["keras"] -except: - pass +# try: +# import typhon.retrieval.qrnn.models.keras +# +# backends += ["keras"] +# except: +# pass try: import typhon.retrieval.qrnn.models.pytorch @@ -87,9 +87,10 @@ def test_save_qrnn(self, backend): """ set_backend(backend) qrnn = QRNN(self.x_train.shape[1], np.linspace(0.05, 0.95, 10)) - f = tempfile.NamedTemporaryFile() - qrnn.save(f.name) - qrnn_loaded = QRNN.load(f.name) + with tempfile.TemporaryDirectory() as d: + f = os.path.join(d, "qrnn") + qrnn.save(f) + qrnn_loaded = QRNN.load(f) x_pred = qrnn.predict(self.x_train) x_pred_loaded = qrnn.predict(self.x_train)