Skip to content

Commit

Permalink
add network quicktests to main test file, fix bug with averaged corre…
Browse files Browse the repository at this point in the history
…lation in ContextualizedCorrelationNetworks
  • Loading branch information
cnellington committed Apr 7, 2024
1 parent 1576003 commit 49aff48
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 3 deletions.
6 changes: 3 additions & 3 deletions contextualized/easy/ContextualizedNetworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _split_train_data(
def predict_networks(
self,
C: np.ndarray,
with_offsets: bool,
with_offsets: bool = False,
individual_preds: bool = False,
**kwargs,
) -> Union[
Expand All @@ -61,7 +61,7 @@ def predict_networks(
Returns:
Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]: The predicted network parameters (and offsets if with_offsets is True). Returned as lists of individual bootstraps if individual_preds is True.
"""
betas, mus = self.predict_params(C, uses_y=False, **kwargs)
betas, mus = self.predict_params(C, individual_preds=individual_preds, uses_y=False, **kwargs)
if with_offsets:
return betas, mus
return betas
Expand Down Expand Up @@ -132,7 +132,7 @@ def predict_correlation(
else:
if squared:
return np.square(np.mean(rhos, axis=0))
return np.mean(rhos)
return np.mean(rhos, axis=0)

def measure_mses(
self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False
Expand Down
4 changes: 4 additions & 0 deletions contextualized/easy/tests/test_correlation_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ def test_correlation(self):
self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
rho = model.predict_correlation(self.C, squared=False)
assert np.min(rho) < 0
assert rho.shape == (1, self.n_samples, self.x_dim, self.x_dim)
rho = model.predict_correlation(self.C, individual_preds=False, squared=False)
assert rho.shape == (self.n_samples, self.x_dim, self.x_dim), rho.shape
rho_squared = model.predict_correlation(self.C, squared=True)
assert np.min(rho_squared) >= 0
assert rho_squared.shape == (1, self.n_samples, self.x_dim, self.x_dim)


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions contextualized/easy/tests/test_markov_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def test_markov(self):
self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
omegas = model.predict_precisions(self.C, individual_preds=False)
assert np.shape(omegas) == (self.n_samples, self.x_dim, self.x_dim)
omegas = model.predict_precisions(self.C, individual_preds=True)
assert np.shape(omegas) == (1, self.n_samples, self.x_dim, self.x_dim)


if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from contextualized.dags.tests import *
from contextualized.easy.tests.test_regressor import *
from contextualized.easy.tests.test_classifier import *
from contextualized.easy.tests.test_markov_networks import *
from contextualized.easy.tests.test_correlation_networks import *
from contextualized.easy.tests.test_bayesian_networks import *

if __name__ == '__main__':
unittest.main()

0 comments on commit 49aff48

Please sign in to comment.