Skip to content

Commit

Permalink
Merge pull request #429 from olemke/fix-qrnn
Browse files Browse the repository at this point in the history
Fix issue due to newer pytorch/keras versions
  • Loading branch information
olemke authored Nov 14, 2024
2 parents 772047c + 856c579 commit 4b69247
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 27 deletions.
2 changes: 2 additions & 0 deletions typhon/retrieval/qrnn/backends/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from keras.models import Sequential, clone_model, Model
from keras.layers import Dense, Activation, Dropout
from keras.optimizers import SGD
if int(keras.__version__.split('.')[0]) != 2:
raise ImportError()
except ImportError:
raise ImportError(
"Could not import the required Keras modules. The QRNN "
Expand Down
17 changes: 12 additions & 5 deletions typhon/retrieval/qrnn/models/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@
"""
import logging
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense, Activation, deserialize
from keras.optimizers import SGD
import keras.backend as K
try:
import keras
from keras.models import Sequential
from keras.layers import Dense, Activation, deserialize
from keras.optimizers import SGD
import keras.backend as K
if int(keras.__version__.split('.')[0]) != 2:
raise ImportError()
except ImportError:
raise ImportError(
"Could not import the required Keras modules. The QRNN "
"implementation was developed for use with Keras version 2.0.9.")


def save_model(f, model):
Expand Down
6 changes: 3 additions & 3 deletions typhon/retrieval/qrnn/models/pytorch/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def load_model(f, quantiles):
Returns:
The loaded pytorch model.
"""
model = torch.load(f)
model = torch.load(f, weights_only=False)
return model


Expand Down Expand Up @@ -92,8 +92,8 @@ class BatchedDataset(Dataset):

def __init__(self, training_data, batch_size):
x, y = training_data
self.x = torch.tensor(x, dtype=torch.float)
self.y = torch.tensor(y, dtype=torch.float)
self.x = x if isinstance(x, torch.Tensor) else torch.tensor(x, dtype=torch.float)
self.y = y if isinstance(y, torch.Tensor) else torch.tensor(y, dtype=torch.float)
self.batch_size = batch_size

def __len__(self):
Expand Down
24 changes: 14 additions & 10 deletions typhon/retrieval/qrnn/qrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
################################################################################

try:
import typhon.retrieval.qrnn.models.keras as keras
backend = keras
import typhon.retrieval.qrnn.models.pytorch as pytorch
backend = pytorch
except Exception as e:
try:
import typhon.retrieval.qrnn.models.pytorch as pytorch
backend = pytorch
import typhon.retrieval.qrnn.models.keras as keras
backend = keras
except:
raise Exception("Couldn't import neither Keras nor Pytorch "
"one of them must be available to use the QRNN"
Expand Down Expand Up @@ -600,11 +600,14 @@ def load(path):
The loaded QRNN object.
"""
with open(path, 'rb') as f:
with open(path + ".pkl", 'rb') as f:
qrnn = pickle.load(f)

with open(path + ".model", 'rb') as f:
backend = importlib.import_module(qrnn.backend)
model = backend.load_model(f, qrnn.quantiles)
qrnn.model = model

return qrnn

def save(self, path):
Expand All @@ -621,11 +624,12 @@ def save(self, path):
store the model.
"""
f = open(path, "wb")
pickle.dump(self, f)
backend = importlib.import_module(self.backend)
backend.save_model(f, self.model)
f.close()
with open(path + ".pkl", 'wb') as f:
pickle.dump(self, f)

with open(path + ".model", 'wb') as f:
backend = importlib.import_module(self.backend)
backend.save_model(f, self.model)


def __getstate__(self):
Expand Down
19 changes: 10 additions & 9 deletions typhon/tests/retrieval/qrnn/test_qrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
#

backends = []
try:
import typhon.retrieval.qrnn.models.keras

backends += ["keras"]
except:
pass
# try:
# import typhon.retrieval.qrnn.models.keras
#
# backends += ["keras"]
# except:
# pass

try:
import typhon.retrieval.qrnn.models.pytorch
Expand Down Expand Up @@ -87,9 +87,10 @@ def test_save_qrnn(self, backend):
"""
set_backend(backend)
qrnn = QRNN(self.x_train.shape[1], np.linspace(0.05, 0.95, 10))
f = tempfile.NamedTemporaryFile()
qrnn.save(f.name)
qrnn_loaded = QRNN.load(f.name)
with tempfile.TemporaryDirectory() as d:
f = os.path.join(d, "qrnn")
qrnn.save(f)
qrnn_loaded = QRNN.load(f)

x_pred = qrnn.predict(self.x_train)
x_pred_loaded = qrnn.predict(self.x_train)
Expand Down

0 comments on commit 4b69247

Please sign in to comment.