Skip to content

Commit

Permalink
remove dill dependency for issue in python 3.8 with dill>=0.3.6, use …
Browse files Browse the repository at this point in the history
…partials for pickling. Add torch version cap for segfault issue
  • Loading branch information
cnellington committed Apr 21, 2024
1 parent 69fcb66 commit 38cc018
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
49 changes: 49 additions & 0 deletions contextualized/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,17 @@ def test_save_load(self):
C = np.random.uniform(0, 1, size=(100, 2))
X = np.random.uniform(0, 1, size=(100, 2))
Y = np.random.uniform(0, 1, size=(100, 2))
C2 = np.random.uniform(0, 1, size=(100, 2))
X2 = np.random.uniform(0, 1, size=(100, 2))
Y2 = np.random.uniform(0, 1, size=(100, 2))
mlp = MLP(2, 2, 50, 5)
Y_pred = mlp(torch.Tensor(X)).detach().numpy()
save(mlp, 'unittest_model.pt')
del mlp
mlp_loaded = load('unittest_model.pt')
Y_pred_loaded = mlp_loaded(torch.Tensor(X)).detach().numpy()
assert np.all(Y_pred == Y_pred_loaded)
os.remove('unittest_model.pt')

model = ContextualizedRegressor()
model.fit(C, X, Y)
Expand All @@ -107,6 +111,15 @@ def test_save_load(self):
Y_pred_loaded = model_loaded.predict(C, X)
assert np.all(Y_pred == Y_pred_loaded)
os.remove('unittest_model.pt')
model_loaded.fit(C2, X2, Y2)
Y_pred2 = model_loaded.predict(C2, X2)
assert not np.all(Y_pred_loaded == Y_pred2)
save(model_loaded, 'unittest_model.pt')
del model_loaded
model_loaded2 = load('unittest_model.pt')
Y_pred_loaded2 = model_loaded2.predict(C2, X2)
assert np.all(Y_pred2 == Y_pred_loaded2)
os.remove('unittest_model.pt')

model = ContextualizedBayesianNetworks()
model.fit(C, X)
Expand All @@ -117,6 +130,15 @@ def test_save_load(self):
pred_loaded = model_loaded.predict_networks(C)
assert np.all(np.array(pred) == np.array(pred_loaded))
os.remove('unittest_model.pt')
model_loaded.fit(C2, X2)
pred2 = model_loaded.predict_networks(C2)
assert not np.all(np.array(pred_loaded) == np.array(pred2))
save(model_loaded, 'unittest_model.pt')
del model_loaded
model_loaded2 = load('unittest_model.pt')
pred_loaded2 = model_loaded2.predict_networks(C2)
assert np.all(np.array(pred2) == np.array(pred_loaded2))
os.remove('unittest_model.pt')

model = ContextualizedCorrelationNetworks()
model.fit(C, X)
Expand All @@ -127,6 +149,15 @@ def test_save_load(self):
pred_loaded = model_loaded.predict_correlation(C)
assert np.all(np.array(pred) == np.array(pred_loaded))
os.remove('unittest_model.pt')
model_loaded.fit(C2, X2)
pred2 = model_loaded.predict_correlation(C2)
assert not np.all(np.array(pred_loaded) == np.array(pred2))
save(model_loaded, 'unittest_model.pt')
del model_loaded
model_loaded2 = load('unittest_model.pt')
pred_loaded2 = model_loaded2.predict_correlation(C2)
assert np.all(np.array(pred2) == np.array(pred_loaded2))
os.remove('unittest_model.pt')

model = BayesianNetwork()
model.fit(X)
Expand All @@ -137,6 +168,15 @@ def test_save_load(self):
pred_loaded = model_loaded.measure_mses(X)
assert np.all(np.array(pred) == np.array(pred_loaded))
os.remove('unittest_model.pt')
model_loaded.fit(X2)
pred2 = model_loaded.measure_mses(X2)
assert not np.all(np.array(pred_loaded) == np.array(pred2))
save(model_loaded, 'unittest_model.pt')
del model_loaded
model_loaded2 = load('unittest_model.pt')
pred_loaded2 = model_loaded2.measure_mses(X2)
assert np.all(np.array(pred2) == np.array(pred_loaded2))
os.remove('unittest_model.pt')

model = CorrelationNetwork()
model.fit(X)
Expand All @@ -147,6 +187,15 @@ def test_save_load(self):
pred_loaded = model_loaded.measure_mses(X)
assert np.all(np.array(pred) == np.array(pred_loaded))
os.remove('unittest_model.pt')
model_loaded.fit(X2)
pred2 = model_loaded.measure_mses(X2)
assert not np.all(np.array(pred_loaded) == np.array(pred2))
save(model_loaded, 'unittest_model.pt')
del model_loaded
model_loaded2 = load('unittest_model.pt')
pred_loaded2 = model_loaded2.measure_mses(X2)
assert np.all(np.array(pred2) == np.array(pred_loaded2))
os.remove('unittest_model.pt')


if __name__ == '__main__':
Expand Down
5 changes: 2 additions & 3 deletions contextualized/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

import torch
import dill


def save(model, path):
Expand All @@ -14,7 +13,7 @@ def save(model, path):
"""
with open(path, "wb") as out_file:
torch.save(model, out_file, pickle_module=dill)
torch.save(model, out_file)


def load(path):
Expand All @@ -24,7 +23,7 @@ def load(path):
"""
with open(path, "rb") as in_file:
model = torch.load(in_file, pickle_module=dill)
model = torch.load(in_file)
return model


Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,12 @@ keywords = [
]
dependencies = [
'lightning>=2.0.0',
'torch>=2.1.0',
'torch>=2.0.0,<2.2.0',
'torchvision>=0.8.0',
'numpy>=1.19.0',
'pandas>=2.0.0',
'scikit-learn>=1.0.0',
'igraph>=0.11.0',
'dill>=0.3.3',
'matplotlib>=3.3.0',
]

Expand Down

0 comments on commit 38cc018

Please sign in to comment.